File tree Expand file tree Collapse file tree 1 file changed +4
-6
lines changed
Expand file tree Collapse file tree 1 file changed +4
-6
lines changed Original file line number Diff line number Diff line change @@ -114,9 +114,9 @@ def compute_loss(
114114 next_input_ids = shift_tensor (inputs ["tokens" ], 0 )
115115 chunk_size = _config .get ("logprob_calculation_chunk_size" , 1024 )
116116 # Assert that sequence length is evenly divisible by the chunk size
117- assert (
118- seq_len % chunk_size == 0
119- ), f"Sequence length ( { seq_len } ) must be evenly divisible by chunk size ( { chunk_size } )"
117+ assert seq_len % chunk_size == 0 , (
118+ f"Sequence length ( { seq_len } ) must be evenly divisible by chunk size ( { chunk_size } )"
119+ )
120120 os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] = "1"
121121 forward_kwargs = {}
122122 if "pixel_values" in inputs :
@@ -371,9 +371,7 @@ def _calculate_logprobs(
371371 chunk_logits = torch .matmul (chunk_hs , lm_head_t ) # [B, chunk_size, V]
372372 chunk_selected_logits = torch .gather (
373373 chunk_logits , dim = - 1 , index = chunk_input_ids .unsqueeze (- 1 )
374- ).squeeze (
375- - 1
376- ) # [B, chunk_size]
374+ ).squeeze (- 1 ) # [B, chunk_size]
377375 chunk_logsumexp = torch .logsumexp (chunk_logits , dim = - 1 ) # [B, chunk_size]
378376 log_probs [:, i : i + chunk_size ] = chunk_selected_logits - chunk_logsumexp
379377
You can’t perform that action at this time.
0 commit comments