Skip to content

Commit 5ca160d

Browse files
committed
enable per output token likelihood prediction for evo2
Signed-off-by: Yang Zhang <[email protected]>
1 parent 77fdb9a commit 5ca160d

File tree

1 file changed

+11
-7
lines changed
  • sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

1 file changed

+11
-7
lines changed

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def parse_args():
111111
)
112112
ap.add_argument(
113113
"--log-prob-collapse-option",
114-
choices=["sum", "mean"],
114+
choices=["sum", "mean", "per_token"],
115115
default="mean",
116116
help="How to collapse the log probabilities across the sequence dimension.",
117117
)
@@ -160,7 +160,7 @@ def __init__(
160160
self,
161161
*args,
162162
output_log_prob_seqs: bool = False,
163-
log_prob_collapse_option: Literal["sum", "mean"] = "mean",
163+
log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean",
164164
**kwargs,
165165
):
166166
"""Initialize the predictor with our needs around computing log probabilities."""
@@ -195,10 +195,14 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor:
195195
2, # along the vocab dimension...
196196
input_ids.unsqueeze(-1), # using the token ids to index.
197197
).squeeze(-1)
198-
log_prob_seqs = torch.sum(logprobs * batch["loss_mask"][:, 1:].float(), dim=-1)
199-
if self.log_prob_collapse_option == "mean":
200-
log_prob_seqs = log_prob_seqs / (batch["loss_mask"][:, 1:].float().sum(dim=-1) + 1e-8)
201-
return {"log_probs_seqs": log_prob_seqs.cpu(), "seq_idx": batch["seq_idx"].cpu()}
198+
log_prob_per_token = logprobs * batch["loss_mask"][:, 1:].float()
199+
if self.log_prob_collapse_option == "per_token":
200+
return {"log_probs_seqs": log_prob_per_token.cpu(), "seq_idx": batch["seq_idx"].cpu()}
201+
else:
202+
log_prob_seqs = torch.sum(log_prob_per_token, dim=1)
203+
if self.log_prob_collapse_option == "mean":
204+
log_prob_seqs = log_prob_seqs / (batch["loss_mask"][:, 1:].float().sum(dim=-1) + 1e-8)
205+
return {"log_probs_seqs": log_prob_seqs.cpu(), "seq_idx": batch["seq_idx"].cpu()}
202206
else:
203207
# If the user wants to match back to logits, then they will need to do the offsetting logic themselves.
204208
return {
@@ -504,7 +508,7 @@ def __init__(
504508
config,
505509
tokenizer=None,
506510
output_log_prob_seqs: bool = False,
507-
log_prob_collapse_option: Literal["sum", "mean"] = "mean",
511+
log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean",
508512
):
509513
"""Initialize the MambaPredictor, which wraps the mamba model for prediction handling model parallelism.
510514

0 commit comments

Comments (0)