# lit_models.py
import lightning.pytorch as pl
import torch
import torchmetrics

from utils.runner.metric import youden_j, c_index

# import shap  # Optional: only needed for the commented-out SHAP sketch below.


class LitFullModel(pl.LightningModule):
    def __init__(self, models: dict[str, torch.nn.Module], optimizers: dict[str, torch.optim.Optimizer], config: dict):
        super().__init__()
        self.save_hyperparameters(config)
        self.feat_ext = models['feat_ext']
        self.classifier = models['clf']
        self.optimizers_dict = optimizers
        self.step_results = []  # Accumulates per-step outputs for epoch-end metrics; simple, at some memory cost.
        # Disable automatic optimization so training_step can drive the backward
        # pass manually when there is more than one optimizer.
        if 'all' not in self.optimizers_dict:
            self.automatic_optimization = False

    def configure_optimizers(self):
        if 'all' in self.optimizers_dict:
            return self.optimizers_dict['all']
        return [self.optimizers_dict['feat_ext'], self.optimizers_dict['clf']]

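    # Note: when configure_optimizers returns a list, automatic optimization is
    # off (see __init__), so training_step must call zero_grad(),
    # manual_backward(), and step() itself. With the single 'all' optimizer,
    # Lightning handles the backward pass from the returned loss.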
    def training_step(self, batch, batch_idx):
        (genomic, clinical, index, project_id), (overall_survival, survival_time, vital_status) = batch
        optimizers = self.optimizers()
        if isinstance(optimizers, list):
            optimizers[0].zero_grad()
            optimizers[1].zero_grad()
        embedding = self.feat_ext(genomic, clinical, project_id)
        y = self.classifier(embedding, project_id)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y, overall_survival)
        if isinstance(optimizers, list):
            self.manual_backward(loss)
            optimizers[0].step()
            optimizers[1].step()
        self.log('train_loss', loss, on_epoch=True, on_step=False)
        return loss

    def _shared_eval(self, batch, batch_idx):
        (genomic, clinical, index, project_id), (overall_survival, survival_time, vital_status) = batch
        # A SHAP feature-attribution experiment was prototyped here; see the
        # commented sketch after this method.
        y = self.classifier(self.feat_ext(genomic, clinical, project_id), project_id)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(y, overall_survival)
        self.step_results.append({
            'output': y.detach().cpu(),
            'label': overall_survival.detach().cpu().type(torch.int64),
            'survival_time': survival_time.detach().cpu(),
            'vital_status': vital_status.detach().cpu(),
            'project_id': project_id.detach().cpu(),
        })
        return loss

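    # ------------------------------------------------------------------
    # Optional SHAP sketch: the original _shared_eval carried a commented-out
    # shap.DeepExplainer experiment. A minimal version might look like the
    # helper below; `background` and `samples` (lists of training-set input
    # tensors) and the helper name itself are hypothetical, not part of the
    # original API.
    #
    # def explain_with_shap(self, background, samples):
    #     explainer = shap.DeepExplainer(self.feat_ext, background)
    #     shap_values = explainer.shap_values(samples)
    #     shap.summary_plot(shap_values, samples)  # Per-feature attribution plot.
    #     return shap_values
    # ------------------------------------------------------------------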
    def _shared_epoch_end(self) -> None:
        outputs = torch.cat([result['output'] for result in self.step_results])
        outputs = torch.sigmoid(outputs)  # Monotonic, so AUC and PRC are unaffected.
        labels = torch.cat([result['label'] for result in self.step_results])
        survival_time = torch.cat([result['survival_time'] for result in self.step_results])
        vital_status = torch.cat([result['vital_status'] for result in self.step_results])
        project_id = torch.cat([result['project_id'] for result in self.step_results])
        thres = youden_j(outputs, labels).astype('float')  # Threshold is computed over all projects.
        for i in torch.unique(project_id):
            mask = project_id == i
            roc = torchmetrics.functional.auroc(outputs[mask], labels[mask], 'binary')
            prc = torchmetrics.functional.average_precision(outputs[mask], labels[mask], 'binary')
            try:
                cindex = c_index(outputs[mask], survival_time[mask], vital_status[mask])
            except ZeroDivisionError:  # Occurs when vital_status is all 0 or all 1 within a project.
                cindex = 0
            self.log(f'Youden_{i}', thres, on_epoch=True, on_step=False)
            self.log(f'AUC_{i}', roc, on_epoch=True, on_step=False)
            self.log(f'PRC_{i}', prc, on_epoch=True, on_step=False)
            self.log(f'C-Index_{i}', cindex, on_epoch=True, on_step=False)
        self.step_results.clear()

    def validation_step(self, batch, batch_idx):
        loss = self._shared_eval(batch, batch_idx)
        self.log('loss', loss, on_epoch=True, on_step=False, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        self._shared_epoch_end()

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx)

    def on_test_epoch_end(self) -> None:
        self._shared_epoch_end()

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        (genomic, clinical, index, project_id), (overall_survival, survival_time, vital_status) = batch
        y = self.classifier(self.feat_ext(genomic, clinical, project_id), project_id)
        return y.detach().cpu().numpy()
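
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical): wiring LitFullModel to a Trainer with dummy
# modules and data for a quick smoke test. DummyFeatExt, DummyClf, DummyDataset,
# and every tensor shape below are illustrative assumptions, not part of the
# original project; the metric helpers still come from utils.runner.metric.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, Dataset

    class DummyFeatExt(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16 + 4, 8)

        def forward(self, genomic, clinical, project_id):
            # Ignores project_id; concatenates the two feature blocks.
            return self.linear(torch.cat([genomic, clinical], dim=-1))

    class DummyClf(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 1)

        def forward(self, embedding, project_id):
            return self.linear(embedding).squeeze(-1)

    class DummyDataset(Dataset):
        # Emits the nested-tuple batch structure that training_step expects.
        def __len__(self):
            return 32

        def __getitem__(self, idx):
            genomic = torch.randn(16)
            clinical = torch.randn(4)
            project_id = torch.tensor(idx % 2)
            overall_survival = torch.tensor(float((idx // 2) % 2))
            survival_time = torch.tensor(float(idx + 1))
            vital_status = torch.tensor(float(idx % 2))
            return (genomic, clinical, idx, project_id), (overall_survival, survival_time, vital_status)

    feat_ext, clf = DummyFeatExt(), DummyClf()
    # A single 'all' optimizer keeps Lightning's automatic optimization on.
    optimizers = {'all': torch.optim.Adam(list(feat_ext.parameters()) + list(clf.parameters()))}
    lit_model = LitFullModel({'feat_ext': feat_ext, 'clf': clf}, optimizers, config={})

    loader = DataLoader(DummyDataset(), batch_size=8)
    trainer = pl.Trainer(fast_dev_run=True, logger=False, enable_checkpointing=False)
    trainer.fit(lit_model, train_dataloaders=loader, val_dataloaders=loader)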