# main_train.py
import sys
sys.path.append('/workspace/Documents')
import os
import torch
import numpy as np
import Diffusion_for_CT_motion.diffusion_models.conditional_DDPM_3D as ddpm_3D
import Diffusion_for_CT_motion.diffusion_models.conditional_EDM_3D as edm
import Diffusion_for_CT_motion.utils.functions_collection as ff
import Diffusion_for_CT_motion.utils.Build_list as Build_list
import Diffusion_for_CT_motion.utils.Generator as Generator
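# Based on how they are used below, the Diffusion_for_CT_motion modules provide:
#   ddpm_3D   : the 3D conditional U-Net backbone (Unet3D)
#   edm       : the EDM diffusion wrapper and its Trainer
#   ff        : a collection of helper functions (imported for convenience, not used directly here)
#   Build_list: parses the patient-list spreadsheet into per-batch file lists
#   Generator : the patch-based dataset / data generator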
########################### important parameters: set the trial name and the pre-trained model path
trial_name = 'portable_EDM_patch_3Dmotion_hist_v1'
pre_trained_model = None  # or the path to a pre-trained model checkpoint
start_step = 0  # for new training, start_step = 0; for continued training, start_step = None
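# Hypothetical example of resuming a previous run (checkpoint path and file name are illustrative only):
#   pre_trained_model = '/mnt/camca_NAS/diffusion_ct_motion/models/portable_EDM_patch_3Dmotion_hist_v1/models/<checkpoint>.pt'
#   start_step = None  # following the continuation convention noted in the comment above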
########################### important parameter: set the data path!
# define training data
build_sheet = Build_list.Build(os.path.join('/mnt/camca_NAS/diffusion_ct_motion/data/Patient_list/Patient_list_train_test_simulated_all_motion_v1.xlsx'))  # spreadsheet listing the simulated-motion cases
_, _, _, _, _, _, x0_list1, _, condition_list1, _, _, _, _ = build_sheet.__build__(batch_list = [0, 1, 2, 3])  # batches 0-3 are used for training
x0_list_train = np.copy(x0_list1); condition_list_train = np.copy(condition_list1)
# define validation data
_, _, _, _, _, _, x0_list2, _, condition_list2, _, _, _, _ = build_sheet.__build__(batch_list = [4])  # batch 4 is used for validation
x0_list_val = np.copy(x0_list2); condition_list_val = np.copy(condition_list2)
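# Judging from their use below, x0_list_* holds the target (motion-free) volumes and condition_list_*
# the corresponding motion-corrupted volumes used as the conditioning input. A quick sanity check
# before training could be:
#   assert len(x0_list_train) == len(condition_list_train)
#   assert len(x0_list_val) == len(condition_list_val)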
# defaults; do not change unless necessary
image_size_3D = [256, 256, 50]
patch_size = 128
slice_number = 50; slice_start = [6, 12]  # if slice_start is an int, it is used as the start slice (no random pick); if it is a range [a, b], a starting slice is picked randomly within that range
val_slice_number = 20; val_slice_start = [20, 21]
histogram_equalization = True
background_cutoff = -1000
maximum_cutoff = 2000
normalize_factor = 'equation'
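# The HU cutoffs and normalize_factor = 'equation' are consumed inside Generator.py; the exact formula
# is repo-specific. As an illustrative sketch only (an assumption, not necessarily the repo's
# implementation), clipping to the cutoffs followed by min-max scaling would look like:
#   img = np.clip(img, background_cutoff, maximum_cutoff)
#   img = (img - background_cutoff) / (maximum_cutoff - background_cutoff)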
# main code
model = ddpm_3D.Unet3D(
    init_dim = 64,
    channels = 1,
    dim_mults = (1, 2, 4, 8),
    flash_attn = False,
    conditional_diffusion = True,
    full_attn = (None, None, False, True),
)
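# With dim_mults = (1, 2, 4, 8) the U-Net has four resolution levels starting from 64 channels;
# full_attn = (None, None, False, True) presumably enables full self-attention only at the deepest
# level, and conditional_diffusion = True makes the network accept the motion-corrupted volume as a
# conditioning input (the exact conditioning mechanism is defined in conditional_DDPM_3D.py).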
diffusion_model = edm.EDM(
    model,
    image_size = [patch_size, patch_size, image_size_3D[-1]],
    num_sample_steps = 50,
    clip_or_not = False,
)
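# The EDM wrapper is trained on 128 x 128 x 50 patches (patch_size in-plane, full slice extent in z).
# num_sample_steps = 50 presumably sets the number of sampling steps used when generating volumes,
# and clip_or_not = False leaves the denoised estimates unclipped.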
generator_train = Generator.Dataset_dual_patch(
    x0_list_train,
    condition_list_train,
    image_size_3D = image_size_3D,
    patch_size = patch_size,
    patch_stride = patch_size,
    original_patch_num = 1,
    random_sampled_patch_num = 2,
    patch_selection = None,
    slice_number = slice_number,
    slice_start = slice_start,
    histogram_equalization = histogram_equalization,
    background_cutoff = background_cutoff,
    maximum_cutoff = maximum_cutoff,
    normalize_factor = normalize_factor,
    shuffle = True,
    augment = True,  # only translation
    augment_frequency = 0.2,
)
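# For each training volume the generator presumably extracts 1 grid-aligned patch plus 2 randomly
# sampled 128 x 128 patches (original_patch_num / random_sampled_patch_num), takes 50 slices starting
# at a slice drawn from [6, 12], and applies translation-only augmentation to about 20% of samples.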
generator_val = Generator.Dataset_dual_patch(
    x0_list_val,
    condition_list_val,
    image_size_3D = [image_size_3D[0], image_size_3D[1], val_slice_number],
    patch_size = 256,
    patch_stride = 1,
    original_patch_num = 1,
    random_sampled_patch_num = 0,
    patch_selection = None,
    slice_number = val_slice_number,
    slice_start = val_slice_start,
    histogram_equalization = histogram_equalization,
    background_cutoff = background_cutoff,
    maximum_cutoff = maximum_cutoff,
    normalize_factor = normalize_factor,
)
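# Validation uses the full 256 x 256 in-plane field of view as a single patch and a fixed 20-slice
# window (slice_start = [20, 21] leaves essentially no randomness in the starting slice), so the
# validation crops stay nearly identical across evaluations.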
trainer = edm.Trainer(
    diffusion_model = diffusion_model,
    generator_train = generator_train,
    include_validation = True,
    train_batch_size = 1,
    train_num_steps = 10000,  # total number of training steps
    results_folder = os.path.join('/mnt/camca_NAS/diffusion_ct_motion/models', trial_name, 'models'),
    train_lr = 1e-4,
    train_lr_decay_every = 100,
    save_models_every = 1,
    validation_every = None,
)
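# Note that generator_val is constructed above but not passed to the Trainer here, even though
# include_validation = True and validation_every = None; if edm.Trainer expects a validation generator
# or interval, they would need to be supplied (the exact keyword arguments depend on the Trainer
# signature). Checkpoints are written under results_folder, with models saved every step and the
# learning rate decayed every 100 steps.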
trainer.train(pre_trained_model = pre_trained_model, start_step = start_step)
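# To launch training (assuming the repository is checked out under /workspace/Documents, as implied
# by the sys.path line at the top of this script):
#   python main_train.py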