Batch size limitations for 3D/video training #76

Open
ff98li opened this issue Dec 10, 2024 · 1 comment

Comments

@ff98li

ff98li commented Dec 10, 2024

Thank you for sharing the awesome work!

I'm trying to fine-tune it for 3D/video segmentation but ran into an issue. The code seems to have the batch size hard-coded to 1:

if args.dataset == 'btcv':
    '''btcv data'''
    btcv_train_dataset = BTCV(args, args.data_path, transform = None, transform_msk= None, mode = 'Training', prompt=args.prompt)
    btcv_test_dataset = BTCV(args, args.data_path, transform = None, transform_msk= None, mode = 'Test', prompt=args.prompt)
    nice_train_loader = DataLoader(btcv_train_dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(btcv_test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
    '''end'''
elif args.dataset == 'amos':
    '''amos data'''
    amos_train_dataset = AMOS(args, args.data_path, transform = None, transform_msk= None, mode = 'Training', prompt=args.prompt)
    amos_test_dataset = AMOS(args, args.data_path, transform = None, transform_msk= None, mode = 'Test', prompt=args.prompt)
    nice_train_loader = DataLoader(amos_train_dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)
    nice_test_loader = DataLoader(amos_test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)
    '''end'''

When I tried bumping it up to 2, it threw an error:
[image: screenshot of the error]

Quick question: In your 3D/video model experiments, did you also train with just one video per batch? Would love to know if this is expected behavior or if I might be missing something.

Thanks in advance for any insights!

@CodeHarcourt

Traceback (most recent call last):
  File "/root/Medical-SAM2/train_3d.py", line 112, in <module>
    main()
  File "/root/Medical-SAM2/train_3d.py", line 95, in main
    loss, prompt_loss, non_prompt_loss = function.train_sam(args, net, optimizer1, optimizer2, nice_train_loader, epoch)
                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/Medical-SAM2/func_3d/function.py", line 115, in train_sam
    _, _, _ = net.train_add_new_bbox(
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 439, in train_add_new_bbox
    out_frame_idx, out_obj_ids, out_mask_logits = self.train_add_new_points(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 523, in train_add_new_points
    current_out, _ = self._run_single_frame_inference(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/Medical-SAM2/sam2_train/sam2_video_predictor.py", line 1351, in _run_single_frame_inference
    pred_masks_gpu = fill_holes_in_mask_scores(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/Medical-SAM2/sam2_train/utils/misc.py", line 255, in fill_holes_in_mask_scores
    is_hole = (labels > 0) & (areas <= max_area)
               ^^^^^^^^^^
RuntimeError: CUDA error: no kernel image is available for execution on the device
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Hello, may I ask if you encountered the error mentioned above while adjusting the code? How did you handle it?
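
A note on the traceback above: "no kernel image is available for execution on the device" generally means that some CUDA code was not compiled for the compute capability of the GPU it ran on. CUDA errors are reported asynchronously, so the failing kernel may have been launched before the line shown in the traceback. The sketch below is one way to check whether the installed PyTorch build covers the local GPU. It does not inspect any separately compiled extension, which would need its own check. `parse_arch`, `build_covers_device`, and `report` are hypothetical helper names, and the compatibility rule is a simplification (tags with a suffix such as `sm_90a` are skipped).

```python
import re


def parse_arch(tag):
    """Parse 'sm_86' or 'compute_90' into (kind, (major, minor)).

    Returns None for tags this sketch does not understand (e.g. 'sm_90a').
    """
    match = re.fullmatch(r"(sm|compute)_(\d{2,})", tag)
    if match is None:
        return None
    kind, digits = match.groups()
    # The last digit is the minor version, everything before it is the major.
    return kind, (int(digits[:-1]), int(digits[-1]))


def build_covers_device(capability, arch_list):
    """Return True if any entry in arch_list can run on a GPU with `capability`.

    Simplified rule (an assumption of this sketch):
      * 'sm_XY' binaries run on devices with the same major version and an
        equal or higher minor version.
      * 'compute_XY' PTX can be JIT-compiled for any equal or higher capability.
    """
    device = tuple(capability)
    for tag in arch_list:
        parsed = parse_arch(tag)
        if parsed is None:
            continue
        kind, (major, minor) = parsed
        if kind == "sm" and major == device[0] and minor <= device[1]:
            return True
        if kind == "compute" and (major, minor) <= device:
            return True
    return False


def report():
    """Print whether the installed PyTorch build targets the first GPU."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed in this environment.")
        return
    if not torch.cuda.is_available():
        print("CUDA is not available to PyTorch.")
        return
    capability = torch.cuda.get_device_capability(0)
    arch_list = torch.cuda.get_arch_list()
    print(f"GPU compute capability: {capability}")
    print(f"PyTorch build targets:  {arch_list}")
    if build_covers_device(capability, arch_list):
        print("The PyTorch build appears to cover this GPU.")
    else:
        print("The PyTorch build does not appear to cover this GPU.")


if __name__ == "__main__":
    report()
```

If the check reports a mismatch, installing a PyTorch build that targets the GPU is the usual next step. If PyTorch itself looks fine, any separately compiled CUDA extension in the project may need rebuilding for that architecture.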
