diff --git a/veloproj/util.py b/veloproj/util.py
index a9e0d41..b3f7b2d 100644
--- a/veloproj/util.py
+++ b/veloproj/util.py
@@ -279,7 +279,7 @@ def fit_model(args, adata, model, inputs, tensor_v=None, xyids=None, device=None
         i = 1 + iter
         if scaler is None:
-            loss = train_step_AE(
+            loss, info = train_step_AE(
                 inputs, tensor_v,
                 model, optimizer,
@@ -293,7 +293,7 @@ def fit_model(args, adata, model, inputs, tensor_v=None, xyids=None, device=None
                 mask=mask,
             )
         else:
-            loss = train_step_AE_half(
+            loss, info = train_step_AE_half(
                 inputs, tensor_v,
                 model, optimizer,
@@ -311,7 +311,7 @@ def fit_model(args, adata, model, inputs, tensor_v=None, xyids=None, device=None
             # torch.cuda.empty_cache()
         losses.append(loss)
-        pbar.set_description(f"Loss: {losses[-1]:.6f}")
+        pbar.set_description(info)
         if i % args.log_interval == 0:
             if (not np.isnan(losses[-1])) and (losses[-1] < min_loss):
                 min_loss = losses[-1]
@@ -516,19 +516,28 @@ def train_step_AE(Xs,
                 device=device,
                 norm=norm_lr
                 )
-        vloss = torch.sum(vloss) * aux_weight
+        vloss = torch.sum(vloss)
+        loss = loss + vloss * aux_weight
         if v_rg_wt > 0:
             v = model.encoder(Xs[xyids[0]] + tensor_v) - s
             preds = mask * (u - gamma * s - offset)
-            vloss = vloss + v_rg_wt * torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta) * v_rg_wt
+            rg_loss = v_rg_wt * torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta)
+            loss = loss + rg_loss * v_rg_wt
+            rg_loss = rg_loss.item()
+        lr_loss = vloss.item()
 
-        loss += vloss
     loss.backward()
     optimizer.step()
     if rt_all_loss:
         return loss.item(), ae_loss, lr_loss
-    return loss.item()
+
+    if v_rg_wt > 0:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}, (RG) {v_rg_wt:.2f} * {rg_loss:.6f}"
+    else:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}"
+
+    return loss.item(), info
 
 
 def train_step_AE_half(Xs,
                        tensor_v,
@@ -579,20 +588,29 @@ def train_step_AE_half(Xs,
                 device=device,
                 norm=norm_lr
                 )
-        vloss = torch.sum(vloss) * aux_weight
+        vloss = torch.sum(vloss)
+        loss = loss + vloss * aux_weight
         if v_rg_wt > 0:
             v = model.encoder(Xs[xyids[0]] + tensor_v) - s
             preds = mask * (u - gamma * s - offset)
-            vloss = vloss + v_rg_wt * torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta) * v_rg_wt
+            rg_loss = v_rg_wt * torch.nn.functional.smooth_l1_loss(preds, v * mask, beta=smoothl1_beta)
+            loss = loss + rg_loss * v_rg_wt
+            rg_loss = rg_loss.item()
+        lr_loss = vloss.item()
 
-        loss += vloss
     half_scaler.scale(loss).backward()
     half_scaler.step(optimizer)
     if rt_all_loss:
         return loss.item(), ae_loss, lr_loss
-    return loss.item()
+
+    if v_rg_wt > 0:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}, (RG) {v_rg_wt:.2f} * {rg_loss:.6f}"
+    else:
+        info = f"Loss: (Total) {loss.item():.6f}, (AE) {ae_loss:.6f}, (LR) {aux_weight:.2f} * {lr_loss:.6f}"
+
+    return loss.item(), info
 
 
 def sklearn_decompose(method, X, S, U, V, use_leastsq=True, norm_lr=False):