
Commit 1b19004
update training print
qiaochen committed May 1, 2022
1 parent 2420d4f commit 1b19004
Showing 1 changed file with 29 additions and 11 deletions.
veloproj/util.py (40 changes: 29 additions & 11 deletions)
@@ -279,7 +279,7 @@ def fit_model(args, adata, model, inputs, tensor_v=None, xyids=None, device=None
         i = 1 + iter
 
         if scaler is None:
-            loss = train_step_AE(
+            loss, info = train_step_AE(
                 inputs,
                 tensor_v,
                 model, optimizer,
@@ -293,7 +293,7 @@ def fit_model(args, adata, model, inputs, tensor_v=None, xyids=None, device=None
                 mask=mask,
             )
         else:
-            loss = train_step_AE_half(
+            loss, info = train_step_AE_half(
                 inputs,
                 tensor_v,
                 model, optimizer,
@@ -311,7 +311,7 @@ def fit_model(args, adata, model, inputs, tensor_v=None, xyids=None, device=None

         # torch.cuda.empty_cache()
         losses.append(loss)
-        pbar.set_description(f"Loss: {losses[-1]:.6f}")
+        pbar.set_description(info)
         if i % args.log_interval == 0:
            if (not np.isnan(losses[-1])) and (losses[-1] < min_loss):
                min_loss = losses[-1]
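
With this change, fit_model no longer formats the progress string itself: each train step now returns an info breakdown that is fed straight to the tqdm bar. A minimal sketch of the pattern, assuming tqdm; train_step is a hypothetical stand-in for train_step_AE / train_step_AE_half and the loss values are placeholders:

# Sketch only: hypothetical train_step, placeholder loss values.
from tqdm import tqdm

def train_step(batch):
    total, ae, lr = 1.234567, 1.000000, 0.234567   # placeholder values
    info = f"Loss: (Total) {total:.6f}, (AE) {ae:.6f}, (LR) {lr:.6f}"
    return total, info

pbar = tqdm(range(100))
for it in pbar:
    loss, info = train_step(it)
    pbar.set_description(info)   # replaces the old f"Loss: {losses[-1]:.6f}"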
@@ -516,19 +516,28 @@ def train_step_AE(Xs,
         device=device,
         norm=norm_lr
     )
-    vloss = torch.sum(vloss) * aux_weight
+    vloss = torch.sum(vloss)
+    loss = loss + vloss * aux_weight
     if v_rg_wt > 0:
         v = model.encoder(Xs[xyids[0]] + tensor_v) - s
         preds = mask * (u - gamma * s - offset)
-        vloss = vloss + v_rg_wt * torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta)
+        rg_loss = torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta)
+        loss = loss + rg_loss * v_rg_wt
+        rg_loss = rg_loss.item()
 
     lr_loss = vloss.item()
-    loss += vloss
 
     loss.backward()
     optimizer.step()
-    if rt_all_loss:
-        return loss.item(), ae_loss, lr_loss
-    return loss.item()
+
+    if v_rg_wt > 0:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}, (RG) {v_rg_wt:.2f} * {rg_loss:.6f}"
+    else:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}"
+
+    return loss.item(), info
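
Note on the refactor above: each component loss is now kept unweighted, and its weight is applied exactly once when it is accumulated into the total, which lets the printout report the weight and the raw term side by side. A self-contained sketch of that bookkeeping with made-up tensors and weights (none of these values come from the repository; RG is read here as the velocity-regularization term governed by v_rg_wt):

# Toy reproduction of the loss bookkeeping; all tensors and weights
# are made-up stand-ins for the real model outputs.
import torch

aux_weight, v_rg_wt = 1.0, 0.5
ae_loss_t = (torch.randn(4, requires_grad=True) ** 2).mean()   # stand-in AE term
vloss = torch.sum(torch.tensor([0.1, 0.2]))                    # unweighted LR term
loss = ae_loss_t + vloss * aux_weight                          # weight applied once, here

rg_loss = 0.0
if v_rg_wt > 0:
    rg_loss = torch.nn.functional.smooth_l1_loss(
        torch.zeros(3), torch.ones(3), beta=1.0)               # unweighted RG term
    loss = loss + rg_loss * v_rg_wt
    rg_loss = rg_loss.item()            # float for printing, after joining the graph

ae_loss, lr_loss = ae_loss_t.item(), vloss.item()
info = (f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, "
        f"(LR) {aux_weight:.2f} * {lr_loss:.6f}, (RG) {v_rg_wt:.2f} * {rg_loss:.6f}")
print(info)

Calling .item() on rg_loss only after it has been added into loss keeps the reporting cheap without detaching the term from backpropagation.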

@@ -579,20 +588,29 @@ def train_step_AE_half(Xs,
         device=device,
         norm=norm_lr
     )
-    vloss = torch.sum(vloss) * aux_weight
+    vloss = torch.sum(vloss)
+    loss = loss + vloss * aux_weight
     if v_rg_wt > 0:
         v = model.encoder(Xs[xyids[0]] + tensor_v) - s
         preds = mask * (u - gamma * s - offset)
-        vloss = vloss + v_rg_wt * torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta)
+        rg_loss = torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta)
+        loss = loss + rg_loss * v_rg_wt
+        rg_loss = rg_loss.item()
 
     lr_loss = vloss.item()
-    loss += vloss
 
     half_scaler.scale(loss).backward()
     half_scaler.step(optimizer)
 
-    if rt_all_loss:
-        return loss.item(), ae_loss, lr_loss
-    return loss.item()
+    if v_rg_wt > 0:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}, (RG) {v_rg_wt:.2f} * {rg_loss:.6f}"
+    else:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}"
+
+    return loss.item(), info
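
train_step_AE_half follows PyTorch's mixed-precision recipe: the GradScaler scales the loss before backward and unscales the gradients inside step. A hedged sketch of that recipe with a toy model; the real code receives half_scaler as an argument, and since no update() call appears in this hunk it is presumably issued by the surrounding loop:

# Minimal AMP step mirroring half_scaler.scale(loss).backward() and
# half_scaler.step(optimizer); toy model, not the repository's.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.Adam(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 4, device=device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()   # scaled backward pass
scaler.step(optimizer)          # unscales gradients, skips the step on inf/nan
scaler.update()                 # adjusts the loss scale for the next iteration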


 def sklearn_decompose(method, X, S, U, V, use_leastsq=True, norm_lr=False):
