@@ -69,11 +69,11 @@ def compute_loss(
6969 # if param_group.get("weight_decay"):
7070 # param_group["weight_decay"] = config.weight_decay
7171
72- if inputs ["pixel_values" ][0 ] is not None :
72+ if inputs . get ( "pixel_values" ) and inputs ["pixel_values" ][0 ] is not None :
7373 inputs ["pixel_values" ] = inputs ["pixel_values" ][0 ] # type: ignore
7474 else :
7575 del inputs ["pixel_values" ] # type: ignore
76- if inputs ["image_grid_thw" ][0 ] is not None :
76+ if inputs . get ( "image_grid_thw" ) and inputs ["image_grid_thw" ][0 ] is not None :
7777 inputs ["image_grid_thw" ] = inputs ["image_grid_thw" ][0 ] # type: ignore
7878 else :
7979 del inputs ["image_grid_thw" ] # type: ignore
@@ -114,9 +114,9 @@ def compute_loss(
114114 next_input_ids = shift_tensor (inputs ["tokens" ], 0 )
115115 chunk_size = _config .get ("logprob_calculation_chunk_size" , 1024 )
116116 # Assert that sequence length is evenly divisible by the chunk size
117- assert seq_len % chunk_size == 0 , (
118- f"Sequence length ( { seq_len } ) must be evenly divisible by chunk size ( { chunk_size } )"
119- )
117+ assert (
118+ seq_len % chunk_size == 0
119+ ), f"Sequence length ( { seq_len } ) must be evenly divisible by chunk size ( { chunk_size } )"
120120 os .environ ["UNSLOTH_RETURN_HIDDEN_STATES" ] = "1"
121121 forward_kwargs = {}
122122 if "pixel_values" in inputs :
@@ -371,7 +371,9 @@ def _calculate_logprobs(
371371 chunk_logits = torch .matmul (chunk_hs , lm_head_t ) # [B, chunk_size, V]
372372 chunk_selected_logits = torch .gather (
373373 chunk_logits , dim = - 1 , index = chunk_input_ids .unsqueeze (- 1 )
374- ).squeeze (- 1 ) # [B, chunk_size]
374+ ).squeeze (
375+ - 1
376+ ) # [B, chunk_size]
375377 chunk_logsumexp = torch .logsumexp (chunk_logits , dim = - 1 ) # [B, chunk_size]
376378 log_probs [:, i : i + chunk_size ] = chunk_selected_logits - chunk_logsumexp
377379
0 commit comments