Skip to content

Commit

Permalink
🐛 Fix git action
Browse files Browse the repository at this point in the history
  • Loading branch information
xiachenrui committed Jun 8, 2024
1 parent e2c3a63 commit a1f13c4
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
matrix:
fast_finish: [false]
python-version: [3.8]
python-version: [3.10]

steps:
- uses: actions/checkout@v2
Expand All @@ -31,7 +31,7 @@ jobs:
curl -sSL https://install.python-poetry.org | python3 -
pip install --upgrade pip
pip install -e ".[docs, dev]"
pip install pyg_lib torch_scatter==2.1.1 torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
install_pyg_dependencies
- name: Build documentation
run: |
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pynvml = "*"
scikit-misc = "*"
opencv-python = "*"
harmonypy = "*"
loguru = "*"

torch = {version = ">=2.0.0"}
torchvision = {version = "*"}
Expand Down Expand Up @@ -85,4 +86,7 @@ dev = ["pytest", "pytest-cov", "papermill", "ipython", "jupyter", "parse", "nbfo
[tool.pyright]
include = ["scSLAT"]
exclude = ["**/conda", "**/__pycache__", "**/.**"]
ignore = ["resource/"]
ignore = ["resource/"]

[tool.poetry.scripts]
install_pyg_dependencies = "scSLAT.utils:install_pyg_dep"
65 changes: 64 additions & 1 deletion scSLAT/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
Miscellaneous utilities
"""
import random
from subprocess import run

from pynvml import *
from anndata import AnnData
from typing import List, Optional, Union
from loguru import logger

import numpy as np
import torch
Expand Down Expand Up @@ -96,4 +98,65 @@ def global_seed(seed: int) -> None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
print(f"Global seed set to {seed}.")
print(f"Global seed set to {seed}.")


def install_pyg_dep(torch_version: str = None, cuda_version: str = None):
r"""
Automatically install PyG dependencies
Parameters
----------
torch_version
torch version, e.g. 2.2.1
cuda_version
cuda version, e.g. 12.1
"""
if torch_version is None:
torch_version = torch.__version__
torch_version = torch_version.split("+")[0]

if cuda_version is None:
cuda_version = torch.version.cuda

if torch_version < "2.0":
raise ValueError(f"PyG only support torch>=2.0, but get {torch_version}")
elif "2.0" <= torch_version < "2.1":
torch_version = "2.0.0"
elif "2.1" <= torch_version < "2.2":
torch_version = "2.1.0"
elif "2.2" <= torch_version < "2.3":
torch_version = "2.2.0"

if "cu" in cuda_version and not torch.cuda.is_available():
logger.warning(
"CUDA is not available, try install CPU version, but may raise error."
)
cuda_version = "cpu"
elif cuda_version >= "12.1":
cuda_version = "cu121"
elif "11.8" <= cuda_version < "12.1":
cuda_version = "cu118"
elif "11.7" <= cuda_version < "11.8":
cuda_version = "cu117"
else:
raise ValueError(f"PyG only support cuda>=11.7, but get {cuda_version}")

if torch_version == "2.2.0" and cuda_version == "cu117":
raise ValueError(
"PyG not support torch-2.2.* with cuda-11.7, please check https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html"
)
if torch_version == "2.1.0" and cuda_version == "cu117":
raise ValueError(
"PyG not support torch-2.1.* with cuda-11.7, please check https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html"
)
if torch_version == "2.0.0" and cuda_version == "cu121":
raise ValueError(
"PyG not support torch-2.0.* with cuda-12.1, please check https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html"
)

logger.info(
f"Installing PyG dependencies for torch-{torch_version} and cuda-{cuda_version}"
)
cmd = f"pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html"
run(cmd, shell=True)

0 comments on commit a1f13c4

Please sign in to comment.