Fix VAE sampling: enable reparameterization by default, return mu for inference #62

Draft
Copilot wants to merge 4 commits into develop from copilot/fix-sampling-issues-during-training

Copilot AI commented May 5, 2026

Encoder defaulted to sampling=False, so forward() returned the raw hidden state h as z instead of the reparameterized sample. Every VAE built this way silently behaved as a plain autoencoder.

Root cause

forward() returned (mu, logvar, h) when sampling=False, which is wrong because h ≠ mu. The build_encoder factory never passed sampling=True. As a result, the KL loss was computed against mu/logvar while the decoder received h.
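For reference, the quantities involved can be sketched as follows. This assumes the usual diagonal-Gaussian posterior and a standard-normal prior; the PR does not state the prior or the exact loss, so treat the KL expression as an assumption rather than the project's actual loss.

```latex
\begin{aligned}
\sigma &= \exp\!\left(\tfrac{1}{2}\,\mathrm{logvar}\right), \qquad \epsilon \sim \mathcal{N}(0, I), \\
z &= \mu + \sigma \odot \epsilon, \\
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \sigma^{2}) \,\|\, \mathcal{N}(0, I)\right)
  &= -\tfrac{1}{2} \sum_{j} \left(1 + \mathrm{logvar}_j - \mu_j^{2} - e^{\mathrm{logvar}_j}\right).
\end{aligned}
```

Under these assumptions the KL term constrains only the distribution defined by mu and logvar. If the decoder is fed h, its input is not drawn from that distribution, so the prior-matching pressure never applies to what the decoder actually consumes.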

Changes

  • encoder.py: Default sampling=True. When sampling=False, return (mu, logvar, mu) — the third element is now always the correct decode input (reparameterized z or deterministic mu).
  • base_vae.py: build_encoder() accepts and forwards sampling=True. encode() returns mu for stable deterministic embeddings. SimpleEncoder likewise returns mu.
  • vae.py: VAE.__init__ exposes sampling: bool = True, passed through to build_encoder. Serialized in to_dict/from_dict.
  • commands/model.py: Added --sampling/--no-sampling flag to train-vae (default True). encode command now uses result[0] (mu) instead of result[2] for inference output.
import torch
from embkit.models.vae.encoder import Encoder  # module path inferred from src/embkit/models/vae/encoder.py

x = torch.randn(8, 64)  # illustrative batch: 8 rows, feature_dim=64

# Before: z was the raw hidden state h, not a VAE sample
enc = Encoder(feature_dim=64, latent_dim=16)  # sampling=False default
mu, logvar, z = enc(x)
# torch.allclose(z, mu) did not hold here: z was h, not even mu

# After: sampling=True by default; sampling=False gives z=mu (deterministic AE)
enc = Encoder(feature_dim=64, latent_dim=16)  # sampling=True
mu, logvar, z = enc(x)
# z = mu + eps*std  ✓

enc_det = Encoder(feature_dim=64, latent_dim=16, sampling=False)
mu, logvar, z = enc_det(x)
assert torch.allclose(z, mu)  # True ✓
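The --sampling/--no-sampling flag and the to_dict/from_dict round trip from the Changes list could look roughly like the sketch below. It uses argparse from the standard library as a stand-in, since the PR does not say which CLI framework train-vae uses. The helper names build_parser, to_dict and from_dict are hypothetical here, and the fallback to True for configs that lack the key is an assumption, not something the PR confirms.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for the train-vae command; the real CLI may differ.
    parser = argparse.ArgumentParser(prog="train-vae")
    parser.add_argument(
        "--sampling",
        action=argparse.BooleanOptionalAction,  # also registers --no-sampling (Python 3.9+)
        default=True,
        help="Decode from the reparameterized sample z (default) or from mu.",
    )
    return parser


def to_dict(sampling: bool) -> dict:
    # Hypothetical serializer: store the flag alongside other model settings.
    return {"sampling": sampling}


def from_dict(cfg: dict) -> bool:
    # Assumption: configs saved before this change have no "sampling" key,
    # so fall back to the new default of True.
    return bool(cfg.get("sampling", True))


args_default = build_parser().parse_args([])
args_off = build_parser().parse_args(["--no-sampling"])
```

Serializing the flag matters because a model trained with --no-sampling would otherwise be reloaded with the new default and start sampling.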

Tests

Four new tests in test_encoder.py cover: sampling enabled produces stochastic z, disabled gives z == mu, default is True, and forward() always returns a 3-tuple.
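A minimal sketch of the four behaviours those tests cover, run against a stand-in module rather than the real Encoder. SketchEncoder, its hidden_dim argument and check_encoder are hypothetical names introduced for illustration; this is not the code in test_encoder.py.

```python
import torch
from torch import nn


class SketchEncoder(nn.Module):
    """Hypothetical stand-in for Encoder; not the embkit implementation."""

    def __init__(self, feature_dim: int, latent_dim: int,
                 hidden_dim: int = 32, sampling: bool = True):
        super().__init__()
        self.sampling = sampling
        self.hidden = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.ReLU())
        self.mu_head = nn.Linear(hidden_dim, latent_dim)
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x: torch.Tensor):
        h = self.hidden(x)
        mu = self.mu_head(h)
        logvar = self.logvar_head(h)
        if self.sampling:
            # Reparameterization: z = mu + eps * std, with eps ~ N(0, I)
            std = torch.exp(0.5 * logvar)
            z = mu + torch.randn_like(std) * std
        else:
            # Deterministic path: the decode input is mu, never h
            z = mu
        return mu, logvar, z


def check_encoder() -> None:
    torch.manual_seed(0)
    x = torch.randn(8, 64)

    enc = SketchEncoder(feature_dim=64, latent_dim=16)
    assert enc.sampling is True                # default is True
    assert len(enc(x)) == 3                    # always a 3-tuple
    _, _, z1 = enc(x)
    _, _, z2 = enc(x)
    assert not torch.allclose(z1, z2)          # sampling gives stochastic z

    det = SketchEncoder(feature_dim=64, latent_dim=16, sampling=False)
    mu, _, z = det(x)
    assert len(det(x)) == 3
    assert torch.equal(z, mu)                  # disabled gives z == mu


check_encoder()
```

Calling the encoder twice on the same input is a simple way to detect stochasticity without fixing the noise, since two independent draws of eps almost surely differ.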

Copilot AI linked an issue May 5, 2026 that may be closed by this pull request
Copilot AI changed the title from "[WIP] Fix sampling behavior in VAE training process" to "Fix VAE sampling: enable reparameterization by default, return mu for inference" May 5, 2026
Copilot AI requested a review from kellrott May 5, 2026 17:44

github-actions Bot commented May 5, 2026

☂️ Python Coverage

current status: ❌

Overall Coverage

| Lines | Covered | Coverage | Threshold | Status |
|-------|---------|----------|-----------|--------|
| 2602  | 1255    | 48%      | 30%       | 🟢     |

New Files

No new covered files...

Modified Files

| File | Coverage | Status |
|------|----------|--------|
| src/embkit/commands/model.py | 35% | 🔴 |
| src/embkit/models/vae/base_vae.py | 67% | 🔴 |
| src/embkit/models/vae/encoder.py | 94% | 🟢 |
| src/embkit/models/vae/vae.py | 37% | 🔴 |
| TOTAL | 58% | 🔴 |

updated for commit: 970fbc2 by action🐍


I didn't see this on the coverage report, but the tests themselves are working.

kbcoulter left a comment

The new tests do not seem to appear in the coverage report (missing rather than excluded), but everything runs as expected and the math looks correct.

Development

Successfully merging this pull request may close these issues.

Potential Issues with sampling
