VAE training using test loss to decay beta #31

@Ajoo

Description

Hi,

I noticed that while training the VAE for TabSyn, a loss computed on the test set seems to be used to decay the learning rate and $\beta$. Is this intentional?

The issue is in the following snippet:

tabsyn/tabsyn/vae/main.py

Lines 165 to 187 in cb5ac0f

```python
with torch.no_grad():
    Recon_X_num, Recon_X_cat, mu_z, std_z = model(X_test_num, X_test_cat)

    val_mse_loss, val_ce_loss, val_kl_loss, val_acc = compute_loss(X_test_num, X_test_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)
    val_loss = val_mse_loss.item() * 0 + val_ce_loss.item()

    scheduler.step(val_loss)
    new_lr = optimizer.param_groups[0]['lr']

    if new_lr != current_lr:
        current_lr = new_lr
        print(f"Learning rate updated: {current_lr}")

    train_loss = val_loss
    if train_loss < best_train_loss:
        best_train_loss = train_loss
        patience = 0
        torch.save(model.state_dict(), model_save_path)
    else:
        patience += 1
        if patience == 10:
            if beta > min_beta:
                beta = beta * lambd
```

The loss used to decay the learning rate and $\beta$ is computed on the test set in lines 168 and 169, which would leak test information into training decisions.
Furthermore, this loss includes only the cross-entropy component for the categorical features, since the MSE loss is zeroed out in line 169 by the `* 0` factor. This seems like an odd choice given that some datasets may have very few categorical columns.
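To make the effect of the `* 0` factor concrete, here is a minimal standalone sketch (not the TabSyn code itself; the loss values are made up) comparing the loss as currently tracked against a version that keeps both reconstruction terms:

```python
# Hypothetical per-epoch loss values, standing in for the tensors returned
# by compute_loss(...).item() in the snippet above.
val_mse_loss = 0.75  # numeric-feature reconstruction MSE
val_ce_loss = 0.25   # categorical cross-entropy

# As written in line 169: the MSE term is multiplied by zero, so only the
# categorical cross-entropy ever reaches scheduler.step() and the beta decay.
val_loss_ce_only = val_mse_loss * 0 + val_ce_loss

# Sketch of the alternative: keep both reconstruction terms, so numeric
# reconstruction quality also influences the schedule.
val_loss_both = val_mse_loss + val_ce_loss

print(val_loss_ce_only)  # 0.25
print(val_loss_both)     # 1.0
```

With the current formulation, a dataset dominated by numeric columns could see its reconstruction quality degrade without any effect on the learning-rate or $\beta$ schedule.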
