@@ -111,7 +111,7 @@ def parse_args():
111111 )
112112 ap .add_argument (
113113 "--log-prob-collapse-option" ,
114- choices = ["sum" , "mean" ],
114+ choices = ["sum" , "mean" , "per_token" ],
115115 default = "mean" ,
116116 help = "How to collapse the log probabilities across the sequence dimension." ,
117117 )
@@ -160,7 +160,7 @@ def __init__(
160160 self ,
161161 * args ,
162162 output_log_prob_seqs : bool = False ,
163- log_prob_collapse_option : Literal ["sum" , "mean" ] = "mean" ,
163+ log_prob_collapse_option : Literal ["sum" , "mean" , "per_token" ] = "mean" ,
164164 ** kwargs ,
165165 ):
166166 """Initialize the predictor with our needs around computing log probabilities."""
@@ -195,10 +195,14 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor:
195195 2 , # along the vocab dimension...
196196 input_ids .unsqueeze (- 1 ), # using the token ids to index.
197197 ).squeeze (- 1 )
198- log_prob_seqs = torch .sum (logprobs * batch ["loss_mask" ][:, 1 :].float (), dim = - 1 )
199- if self .log_prob_collapse_option == "mean" :
200- log_prob_seqs = log_prob_seqs / (batch ["loss_mask" ][:, 1 :].float ().sum (dim = - 1 ) + 1e-8 )
201- return {"log_probs_seqs" : log_prob_seqs .cpu (), "seq_idx" : batch ["seq_idx" ].cpu ()}
198+ log_prob_per_token = logprobs * batch ["loss_mask" ][:, 1 :].float ()
199+ if self .log_prob_collapse_option == "per_token" :
200+ return {"log_probs_seqs" : log_prob_per_token .cpu (), "seq_idx" : batch ["seq_idx" ].cpu ()}
201+ else :
202+ log_prob_seqs = torch .sum (log_prob_per_token , dim = 1 )
203+ if self .log_prob_collapse_option == "mean" :
204+ log_prob_seqs = log_prob_seqs / (batch ["loss_mask" ][:, 1 :].float ().sum (dim = - 1 ) + 1e-8 )
205+ return {"log_probs_seqs" : log_prob_seqs .cpu (), "seq_idx" : batch ["seq_idx" ].cpu ()}
202206 else :
203207 # If the user wants to match back to logits, then they will need to do the offsetting logic themselves.
204208 return {
@@ -504,7 +508,7 @@ def __init__(
504508 config ,
505509 tokenizer = None ,
506510 output_log_prob_seqs : bool = False ,
507- log_prob_collapse_option : Literal ["sum" , "mean" ] = "mean" ,
511+ log_prob_collapse_option : Literal ["sum" , "mean" , "per_token" ] = "mean" ,
508512 ):
509513 """Initialize the MambaPredictor, which wraps the mamba model for prediction handling model parallelism.
510514
0 commit comments