diff --git a/segmentation/dist_train.sh b/segmentation/dist_train.sh index edb94687..efa7bbfb 100755 --- a/segmentation/dist_train.sh +++ b/segmentation/dist_train.sh @@ -2,8 +2,9 @@ CONFIG=$1 GPUS=$2 +NODES=$3 PORT=${PORT:-29300} PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ -python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ +python -m torch.distributed.launch --nproc_per_node=$GPUS --nnodes=$NODES --master_port=$PORT \ $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} diff --git a/segmentation/slurm_train.sh b/segmentation/slurm_train.sh index 043d1ef3..2433644b 100644 --- a/segmentation/slurm_train.sh +++ b/segmentation/slurm_train.sh @@ -6,6 +6,7 @@ PARTITION=$1 JOB_NAME=$2 CONFIG=$3 GPUS=${GPUS:-8} +NUM_NODE=$4 GPUS_PER_NODE=${GPUS_PER_NODE:-8} CPUS_PER_TASK=${CPUS_PER_TASK:-5} SRUN_ARGS=${SRUN_ARGS:-""} @@ -16,6 +17,7 @@ srun -p ${PARTITION} \ --job-name=${JOB_NAME} \ --gres=gpu:${GPUS_PER_NODE} \ --ntasks=${GPUS} \ + --nnodes=${NUM_NODE} --ntasks-per-node=${GPUS_PER_NODE} \ --cpus-per-task=${CPUS_PER_TASK} \ --quotatype=spot \