diff --git a/empanada_napari/finetune.py b/empanada_napari/finetune.py
index 884ef8b..7253f5c 100644
--- a/empanada_napari/finetune.py
+++ b/empanada_napari/finetune.py
@@ -65,9 +65,19 @@ def main(config):
     main_worker(config)
 
 def main_worker(config):
-    config['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-    if str(config['device']) == 'cpu':
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+    elif torch.backends.mps.is_available():
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    config['device'] = device
+
+    if str(config['device']) == 'cuda:0':
+        print("Using GPU for training.")
+    elif str(config['device']) == 'mps':
+        print("Using M1 Mac hardware for training.")
+    elif str(config['device']) == 'cpu':
         print(f"Using CPU for training.")
     else:
         print(f"Using GPU for training.")
diff --git a/empanada_napari/inference.py b/empanada_napari/inference.py
index 13eec12..468c9ee 100644
--- a/empanada_napari/inference.py
+++ b/empanada_napari/inference.py
@@ -177,8 +177,13 @@ def __init__(
         use_gpu=True,
         use_quantized=False
     ):
-        # check whether GPU is available
-        device = torch.device('cuda:0' if torch.cuda.is_available() and use_gpu else 'cpu')
+        # check whether GPU or M1 Mac hardware is available
+        if torch.cuda.is_available() and use_gpu:
+            device = torch.device('cuda:0')
+        elif use_gpu and torch.backends.mps.is_available():
+            device = torch.device('mps')
+        else:
+            device = torch.device('cpu')
         if use_quantized and str(device) == 'cpu' and model_config.get('model_quantized') is not None:
             model_url = model_config['model_quantized']
         else:
@@ -340,9 +345,13 @@ def __init__(
         store_url=None,
         save_panoptic=False
     ):
-        # check whether GPU is available
-        # check whether GPU is available
-        device = torch.device('cuda:0' if torch.cuda.is_available() and use_gpu else 'cpu')
+        # check whether GPU or M1 Mac hardware is available
+        if torch.cuda.is_available() and use_gpu:
+            device = torch.device('cuda:0')
+        elif use_gpu and torch.backends.mps.is_available():
+            device = torch.device('mps')
+        else:
+            device = torch.device('cpu')
         if use_quantized and str(device) == 'cpu' and model_config.get('model_quantized') is not None:
             model_url = model_config['model_quantized']
         else:
diff --git a/empanada_napari/train.py b/empanada_napari/train.py
index fdf3bf6..22afd65 100644
--- a/empanada_napari/train.py
+++ b/empanada_napari/train.py
@@ -66,9 +66,20 @@ def main(config):
     return main_worker(config)
 
 def main_worker(config):
-    config['device'] = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-    if str(config['device']) == 'cpu':
+    # check whether GPU or M1 Mac hardware is available
+    if torch.cuda.is_available():
+        device = torch.device('cuda:0')
+    elif torch.backends.mps.is_available():
+        device = torch.device('mps')
+    else:
+        device = torch.device('cpu')
+    config['device'] = device
+
+    if str(config['device']) == 'cuda:0':
+        print("Using GPU for training.")
+    elif str(config['device']) == 'mps':
+        print("Using M1 Mac hardware for training.")
+    elif str(config['device']) == 'cpu':
         print(f"Using CPU for training.")
     else:
         print(f"Using GPU for training.")