From a1f13c4e7c57ab098e7882a96f2d8f003b4d7276 Mon Sep 17 00:00:00 2001 From: xiachenrui Date: Sat, 8 Jun 2024 19:46:25 +0800 Subject: [PATCH] :bug: Fix git action --- .github/workflows/build.yml | 4 +-- pyproject.toml | 6 +++- scSLAT/utils.py | 65 ++++++++++++++++++++++++++++++++++++- 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6a27543..06cc31c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: fast_finish: [false] - python-version: [3.8] + python-version: [3.10] steps: - uses: actions/checkout@v2 @@ -31,7 +31,7 @@ jobs: curl -sSL https://install.python-poetry.org | python3 - pip install --upgrade pip pip install -e ".[docs, dev]" - pip install pyg_lib torch_scatter==2.1.1 torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html + install_pyg_dependencies - name: Build documentation run: | diff --git a/pyproject.toml b/pyproject.toml index c9f1c2e..2b97a81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ pynvml = "*" scikit-misc = "*" opencv-python = "*" harmonypy = "*" +loguru = "*" torch = {version = ">=2.0.0"} torchvision = {version = "*"} @@ -85,4 +86,7 @@ dev = ["pytest", "pytest-cov", "papermill", "ipython", "jupyter", "parse", "nbfo [tool.pyright] include = ["scSLAT"] exclude = ["**/conda", "**/__pycache__", "**/.**"] -ignore = ["resource/"] \ No newline at end of file +ignore = ["resource/"] + +[tool.poetry.scripts] +install_pyg_dependencies = "scSLAT.utils:install_pyg_dep" \ No newline at end of file diff --git a/scSLAT/utils.py b/scSLAT/utils.py index da41e75..2843477 100644 --- a/scSLAT/utils.py +++ b/scSLAT/utils.py @@ -2,10 +2,12 @@ Miscellaneous utilities """ import random +from subprocess import run from pynvml import * from anndata import AnnData from typing import List, Optional, Union +from loguru import logger import numpy as np import torch @@ -96,4 +98,65 @@ def global_seed(seed: int) -> None: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - print(f"Global seed set to {seed}.") \ No newline at end of file + print(f"Global seed set to {seed}.") + + +def install_pyg_dep(torch_version: str = None, cuda_version: str = None): + r""" + Automatically install PyG dependencies + + Parameters + ---------- + torch_version + torch version, e.g. 2.2.1 + cuda_version + cuda version, e.g. 12.1 + """ + if torch_version is None: + torch_version = torch.__version__ + torch_version = torch_version.split("+")[0] + + if cuda_version is None: + cuda_version = torch.version.cuda + + if torch_version < "2.0": + raise ValueError(f"PyG only support torch>=2.0, but get {torch_version}") + elif "2.0" <= torch_version < "2.1": + torch_version = "2.0.0" + elif "2.1" <= torch_version < "2.2": + torch_version = "2.1.0" + elif "2.2" <= torch_version < "2.3": + torch_version = "2.2.0" + + if "cu" in cuda_version and not torch.cuda.is_available(): + logger.warning( + "CUDA is not available, try install CPU version, but may raise error." + ) + cuda_version = "cpu" + elif cuda_version >= "12.1": + cuda_version = "cu121" + elif "11.8" <= cuda_version < "12.1": + cuda_version = "cu118" + elif "11.7" <= cuda_version < "11.8": + cuda_version = "cu117" + else: + raise ValueError(f"PyG only support cuda>=11.7, but get {cuda_version}") + + if torch_version == "2.2.0" and cuda_version == "cu117": + raise ValueError( + "PyG not support torch-2.2.* with cuda-11.7, please check https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html" + ) + if torch_version == "2.1.0" and cuda_version == "cu117": + raise ValueError( + "PyG not support torch-2.1.* with cuda-11.7, please check https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html" + ) + if torch_version == "2.0.0" and cuda_version == "cu121": + raise ValueError( + "PyG not support torch-2.0.* with cuda-12.1, please check https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html" + ) + + logger.info( + f"Installing PyG dependencies for torch-{torch_version} and cuda-{cuda_version}" + ) + cmd = f"pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html" + run(cmd, shell=True)