diff --git a/CHANGELOG.md b/CHANGELOG.md
index 68ff995..31405df 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,3 +7,10 @@ and this project adheres to [Semantic Versioning][].
 
 [keep a changelog]: https://keepachangelog.com/en/1.0.0/
 [semantic versioning]: https://semver.org/spec/v2.0.0.html
+
+## [0.1.0] - TBA
+
+### Added
+
+- support for scvi-tools >= 1.0 (#37)
+- support for multiple NB losses (#34)
diff --git a/src/multigrate/model/_multivae.py b/src/multigrate/model/_multivae.py
index 3007a46..13142f6 100644
--- a/src/multigrate/model/_multivae.py
+++ b/src/multigrate/model/_multivae.py
@@ -290,6 +290,7 @@ def train(
         weight_decay: float = 1e-3,
         eps: float = 1e-08,
         early_stopping: bool = True,
+        early_stopping_patience: int = 10,
         # save_best: bool = True,
         check_val_every_n_epoch: int | None = None,
         n_epochs_kl_warmup: int | None = None,
@@ -419,7 +420,7 @@ def train(
             early_stopping=early_stopping,
             check_val_every_n_epoch=check_val_every_n_epoch,
             early_stopping_monitor="reconstruction_loss_validation",
-            early_stopping_patience=10,
+            early_stopping_patience=early_stopping_patience,
             # enable_checkpointing=enable_checkpointing,
             **kwargs,
         )
diff --git a/src/multigrate/module/_multivae_torch.py b/src/multigrate/module/_multivae_torch.py
index a8b65fb..077cd49 100644
--- a/src/multigrate/module/_multivae_torch.py
+++ b/src/multigrate/module/_multivae_torch.py
@@ -181,11 +181,12 @@ def __init__(
 
         # assume for now that can only use nb/zinb once, i.e. for RNA-seq modality
         # TODO: add check for multiple nb/zinb losses given
-        self.theta = None
+        self.theta = []
         for i, loss in enumerate(losses):
             if loss in ["nb", "zinb"]:
-                self.theta = torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups))
-                break
+                self.theta.append(torch.nn.Parameter(torch.randn(self.input_dims[i], num_groups)))
+            else:
+                self.theta.append([])
 
         # modality encoders
         cond_dim_enc = cond_dim * (len(cat_covariate_dims) + len(cont_covariate_dims)) if self.condition_encoders else 0
@@ -307,6 +308,7 @@ def _h_to_x(self, h, i):
         return x
 
     def _product_of_experts(self, mus, logvars, masks):
+        # print(mus, logvars, masks)
         vars = torch.exp(logvars)
         masks = masks.unsqueeze(-1).repeat(1, 1, vars.shape[-1])
         mus_joint = torch.sum(mus * masks / vars, dim=1)
@@ -658,7 +660,7 @@ def _calc_recon_loss(self, xs, rs, losses, group, size_factor, loss_coefs, masks
                 dec_mean = r
                 size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
                 dec_mean = dec_mean * size_factor_view
-                dispersion = self.theta.T[group.squeeze().long()]
+                dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
                 dispersion = torch.exp(dispersion)
                 nb_loss = torch.sum(NegativeBinomial(mu=dec_mean, theta=dispersion).log_prob(x), dim=-1)
                 nb_loss = loss_coefs[str(i)] * nb_loss
@@ -667,9 +669,9 @@
                 dec_mean, dec_dropout = r
                 dec_mean = dec_mean.squeeze()
                 dec_dropout = dec_dropout.squeeze()
-                size_factor_view = size_factor.unsqueeze(1).expand(dec_mean.size(0), dec_mean.size(1))
+                size_factor_view = size_factor.expand(dec_mean.size(0), dec_mean.size(1))
                 dec_mean = dec_mean * size_factor_view
-                dispersion = self.theta.T[group.squeeze().long()]
+                dispersion = self.theta[i].to(self.device).T[group.squeeze().long()]
                 dispersion = torch.exp(dispersion)
                 zinb_loss = torch.sum(
                     ZeroInflatedNegativeBinomial(mu=dec_mean, theta=dispersion, zi_logits=dec_dropout).log_prob(x),
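
Note (illustrative, not part of the patch): the trainer patience, previously hard-coded to 10, is now exposed through the model's train() signature. A hypothetical call, assuming `model` is an already set-up multigrate MultiVAE and that `max_epochs` keeps its usual scvi-tools meaning:

# Hypothetical usage sketch; `model` is assumed to be an initialized MultiVAE.
model.train(
    max_epochs=200,
    early_stopping=True,          # still monitors "reconstruction_loss_validation"
    early_stopping_patience=20,   # now configurable; defaults to the old value of 10
)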
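
Note (illustrative, not part of the patch): `self.theta` changes from a single shared dispersion Parameter to a per-modality list, and both reconstruction losses now index it with the modality index `i`. A minimal standalone sketch of the shape logic, with made-up dimensions; `input_dims`, `num_groups`, and `group` mirror the names in the diff:

import torch

torch.manual_seed(0)

input_dims = [2000, 134]  # e.g. gene and protein feature counts (made up)
num_groups = 3            # number of dispersion groups, e.g. batches
losses = ["nb", "mse"]    # only nb/zinb modalities get a dispersion parameter

# One learnable (input_dim, num_groups) tensor per nb/zinb modality,
# an empty placeholder otherwise -- mirroring the loop in the diff.
theta = [
    torch.nn.Parameter(torch.randn(dim, num_groups)) if loss in ("nb", "zinb") else []
    for dim, loss in zip(input_dims, losses)
]

batch_size = 4
group = torch.randint(0, num_groups, (batch_size, 1))  # group label per cell

i = 0  # index of the NB modality
# theta[i].T is (num_groups, input_dim); indexing by the per-cell group labels
# yields a (batch_size, input_dim) dispersion matrix, then exp() as in the loss.
dispersion = torch.exp(theta[i].T[group.squeeze().long()])
print(dispersion.shape)  # torch.Size([4, 2000])

One design note: parameters kept in a plain Python list are not registered by nn.Module (they do not appear in .parameters() and are not moved by .to(device)), which is presumably why the loss code moves self.theta[i] to self.device explicitly; torch.nn.ParameterList is the usual idiom when each entry should be tracked automatically.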