Skip to content

M-LibCity: An Open Source Library for Urban Spatio-temporal Prediction Models Based on MindSpore

License

Notifications You must be signed in to change notification settings

LibCity/M-LibCity

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 

Repository files navigation

M-Libcity

一、介绍

M-libcity 是一个基于华为MindSpore框架实现的开源算法库,专注于城市时空预测领域。它为MindSpore开发人员提供了统一、全面、可扩展的时空预测模型实现方案,同时为城市时空预测研究人员提供了可靠的实验工具和便捷的开发框架。M-LibCity开源算法库涵盖了与城市时空预测相关的所有必要步骤和组件,构建了完整的研究流程,使研究人员能够进行全面的对比实验。这将为研究人员在MindSpore平台上开展城市时空预测研究提供便利和强大的支持。

目前,M-LibCity已支持交通状态预测、轨迹下一跳预测及到达时间预测等多项关键的交通预测任务,涵盖9个深度学习模型,6个交通数据集。除此之外,M-LibCity还支持在GPU/NPU等多种后端上,进行多卡训练/推理加速。

二、安装与配置

1. 配置MindSpore依赖环境,在Mindspore官网寻找安装命令:

  • 如果是在全新的系统上使用pip安装Mindspore,可以使用自动安装脚本进行一键式安装。安装脚本会安装Mindspore以及其所需要的依赖。
  • 如果系统已经安装了部分依赖,如Python,GCC等,可以参照官网手动安装步骤进行。

2. 安装运行的Python环境依赖

matplotlib==3.5.1
networkx==2.5
numpy==1.21.6
pandas==1.3.5
scikit_learn==0.24.0
scipy==1.5.4
tqdm==4.62.3

PS:如在启智平台使用还需安装下列包:

moxing_framework==2.1.7.dc1f3d0b//启智平台需要

3. 下载所需数据集

数据集下载链接:https://drive.google.com/file/d/11a6PyE5KrFK0wnI7RSv4sxLV9SLxqlLs/view?usp=drive_link

raw_data下载链接:https://drive.google.com/file/d/1JqnGsprG2zpJ4tisf0o3liOLy3GR7t0i/view?usp=drive_link

运行需要首先下载需要的数据集,并放到M-LibCity-[x]/raw_data下,其中x={gpu,npu}。

三、使用说明

快速运行代码命令

1.  cd [rootpath for project]
2.  python run_model.py [task] [model_name] [dataset]

PS:目录结构中的M-LibCity-gpu和M-LibCity-npu分别对应着GPU端和NPU端的代码,请移动到对应目录结构下执行命令。

修改模型参数

所有的pipeline默认参数都存放在M_libcity/config文件夹下。 模型配置文件可在M_libcity/config/model文件夹下找到,该文件夹按照model的类别进行分类。 task_config.json记录了模型要加载的具体数据模块配置文件、执行模块配置文件、评估模块配置文件和模型模块配置文件,可通过task_config.json查看对应关系。 如想添加其他参数,可以在run_model.py中的run_model()中通过other_args={key:value}的形式传递。

PS:所有参数的注释以及取值可从https://bigscity-libcity-docs.readthedocs.io/en/latest/user_guide/config_settings.html 搜索得到。

调试任务实现多卡训练

调用 M_libCity_[x]/run_with_multi_devices.sh 文件,其中[x]={gpu, npu}.

多卡训练,启动方式为:

1.  bash run_with_multi_devices.sh 2 [task] [model_name] [dataset]

参数2表示卡数为2。