Skip to content

Commit ab57ee9

Browse files
committed
fix: torchtune precalculate_logprobs
1 parent 538a662 commit ab57ee9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/art/torchtune/recipe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def explain_block_diffs(
860860
del hidden_states, logits
861861

862862
if return_new_logprobs:
863-
return torch.nn.functional.pad(new_logprobs[:, :-1], (1, 0), value=0.0)
863+
return torch.nn.functional.pad(new_logprobs[:-1], (1, 0), value=0.0)
864864

865865
old_logprobs = torch.where(
866866
torch.isnan(old_logprobs),

0 commit comments

Comments
 (0)