Skip to content

Commit

Permalink
update save last
Browse files Browse the repository at this point in the history
  • Loading branch information
Ultimate-Storm committed Jul 4, 2023
1 parent 8511a3a commit cf3a747
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 22 deletions.
15 changes: 7 additions & 8 deletions vis_breast_mri_discard_localorbest_model_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
'SL local best ckpt': [0.60, 0.60, 0.61, 0.62],
'SL global best ckpt': [0.60, 0.60, 0.61, 0.62, 0.61],
'Local merged - 80% data': [0.59, 0.58, 0.60, 0.60, 0.62]},
'DenseNet121-3D': {'Localhost1 - 40% data': [0.47, 0.79, 0.72, 0.66, 0.63],
'Localhost2 - 30% data': [0.38, 0.45, 0.65, 0.50, 0.51],
'Localhost3 - 10% data': [0.62, 0.67, 0.66, 0.50, 0.55],
'SL local best ckpt': [0.71, 0.71, 0.78, 0.79],
'SL global best ckpt': [0.73, 0.52, 0.70, 0.69, 0.60],
'Local merged - 80% data': [0.75, 0.72, 0.68, 0.68, 0.69]},

'Att-MIL': {'Localhost1 - 40% data': [0.71, 0.74, 0.69, 0.71, 0.72],
'Localhost2 - 30% data': [0.62, 0.62, 0.58, 0.66, 0.54],
'Localhost3 - 10% data': [0.46, 0.51, 0.55, 0.5, 0.48],
Expand All @@ -41,7 +36,12 @@
'SL global best ckpt': [0.76, 0.74, 0.75, 0.75, 0.74],
'Local merged - 80% data': [0.75, 0.75, 0.74, 0.76, 0.76]},


'DenseNet121-3D': {'Localhost1 - 40% data': [0.47, 0.79, 0.72, 0.66, 0.63],
'Localhost2 - 30% data': [0.38, 0.45, 0.65, 0.50, 0.51],
'Localhost3 - 10% data': [0.62, 0.67, 0.66, 0.50, 0.55],
'SL local best ckpt': [0.71, 0.71, 0.78, 0.79],
'SL global best ckpt': [0.73, 0.52, 0.70, 0.69, 0.60],
'Local merged - 80% data': [0.75, 0.72, 0.68, 0.68, 0.69]},
'ResNet18-3D': {'Localhost1 - 40% data': [0.69, 0.71, 0.63, 0.68, 0.67],
'Localhost2 - 30% data': [0.77, 0.76, 0.73, 0.75, 0.70],
'Localhost3 - 10% data': [0.67, 0.64, 0.67, 0.69, 0.63],
Expand All @@ -61,7 +61,6 @@
'SL local best ckpt': [0.82, 0.82, 0.83],
'SL global best ckpt': [0.86, 0.78, 0.77, 0.74, 0.81],
'Local merged - 80% data': [0.82, 0.81, 0.83, 0.82, 0.84]},

}

# convert string values to float
Expand Down
2 changes: 1 addition & 1 deletion workspace/automate_scripts/launch_sl/run_swci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,4 @@ sudo "$script_dir/../../swarm_learning_scripts/run-swci" \
--cert="cert/swci-$host_index-cert.pem" \
--capath="cert/ca/capath" \
-e "http_proxy=" -e "https_proxy=" --apls-ip="$sentinel" --apls-port=5000 \
-e "SWCI_TASK_MAX_WAIT_TIME=500"
-e "SWCI_TASK_MAX_WAIT_TIME=5000"
5 changes: 1 addition & 4 deletions workspace/odelia-breast-mri/model/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def cal_weightage(train_size):
monitor=to_monitor,
#every_n_train_steps=log_every_n_steps,
save_last=True,
#save_top_k=2,
save_top_k=2,
#filename='odelia-epoch{epoch:02d}-val_AUC_ROC{val/AUC_ROC:.2f}',
mode=min_max,
)
Expand All @@ -181,8 +181,6 @@ def cal_weightage(train_size):
logger=TensorBoardLogger(save_dir=path_run_dir)
)
trainer.fit(model, datamodule=dm)

trainer.fit(model, datamodule=dm)
else:
swarmCallback = SwarmCallback(syncFrequency=512,
minPeers=min_peers,
Expand All @@ -194,7 +192,6 @@ def cal_weightage(train_size):
torch.autograd.set_detect_anomaly(True)
swarmCallback.logger.setLevel(logging.DEBUG)
swarmCallback.on_train_begin()

trainer = Trainer(
accelerator=accelerator,
precision=16,
Expand Down
8 changes: 3 additions & 5 deletions workspace/odelia-breast-mri/model/main_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def cal_weightage(train_size):
monitor=to_monitor,
#every_n_train_steps=log_every_n_steps,
save_last=True,
#save_top_k=2,
save_top_k=2,
#filename='odelia-epoch{epoch:02d}-val_AUC_ROC{val/AUC_ROC:.2f}',
mode=min_max,
)
Expand All @@ -181,8 +181,6 @@ def cal_weightage(train_size):
logger=TensorBoardLogger(save_dir=path_run_dir)
)
trainer.fit(model, datamodule=dm)

trainer.fit(model, datamodule=dm)
else:
swarmCallback = SwarmCallback(syncFrequency=512,
minPeers=min_peers,
Expand All @@ -205,14 +203,14 @@ def cal_weightage(train_size):
min_epochs=5,
log_every_n_steps=log_every_n_steps,
auto_lr_find=False,
max_epochs=10,
max_epochs=120,
num_sanity_val_steps=2,
logger=TensorBoardLogger(save_dir=path_run_dir)
)
trainer.fit(model, datamodule=dm)
swarmCallback.on_train_end()
model.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)
model.save_last_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)
model.save_last_checkpoint(trainer.logger.log_dir, checkpointing.last_model_path)

import subprocess

Expand Down
7 changes: 5 additions & 2 deletions workspace/odelia-breast-mri/model/predict_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def predict_last(model_dir, test_data_dir, model_name):
results['NN_pred'].extend(pred[:, 0].tolist())

df = pd.DataFrame(results)
df.to_csv(path_out / 'results.csv')
df.to_csv(path_out / 'results_last.csv')

# -------------------------- Confusion Matrix -------------------------
cm = confusion_matrix(df['GT'], df['NN'])
Expand Down Expand Up @@ -154,4 +154,7 @@ def predict_last(model_dir, test_data_dir, model_name):
logger.info(f"Malign Objects: {np.sum(y_true_lab)}")
logger.info("Confusion Matrix {}".format(cm))
logger.info("Sensitivity {:.2f}".format(sens))
logger.info("Specificity {:.2f}".format(spec))
logger.info("Specificity {:.2f}".format(spec))

del model
torch.cuda.empty_cache()
7 changes: 5 additions & 2 deletions workspace/odelia-breast-mri/model/predict_last_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def predict_last(model_dir, test_data_dir, model_name):
results['NN_pred'].extend(pred[:, 0].tolist())

df = pd.DataFrame(results)
df.to_csv(path_out / 'results.csv')
df.to_csv(path_out / 'results_last.csv')

# -------------------------- Confusion Matrix -------------------------
cm = confusion_matrix(df['GT'], df['NN'])
Expand Down Expand Up @@ -154,4 +154,7 @@ def predict_last(model_dir, test_data_dir, model_name):
logger.info(f"Malign Objects: {np.sum(y_true_lab)}")
logger.info("Confusion Matrix {}".format(cm))
logger.info("Sensitivity {:.2f}".format(sens))
logger.info("Specificity {:.2f}".format(spec))
logger.info("Specificity {:.2f}".format(spec))

del model
torch.cuda.empty_cache()

0 comments on commit cf3a747

Please sign in to comment.