Skip to content

Commit 519efe6

Browse files
committed
chore: Fix formatting
1 parent 557e704 commit 519efe6

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

src/art/unsloth/train.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def compute_loss(
114114
next_input_ids = shift_tensor(inputs["tokens"], 0)
115115
chunk_size = _config.get("logprob_calculation_chunk_size", 1024)
116116
# Assert that sequence length is evenly divisible by the chunk size
117-
assert (
118-
seq_len % chunk_size == 0
119-
), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
117+
assert seq_len % chunk_size == 0, (
118+
f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
119+
)
120120
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
121121
forward_kwargs = {}
122122
if "pixel_values" in inputs:
@@ -371,9 +371,7 @@ def _calculate_logprobs(
371371
chunk_logits = torch.matmul(chunk_hs, lm_head_t) # [B, chunk_size, V]
372372
chunk_selected_logits = torch.gather(
373373
chunk_logits, dim=-1, index=chunk_input_ids.unsqueeze(-1)
374-
).squeeze(
375-
-1
376-
) # [B, chunk_size]
374+
).squeeze(-1) # [B, chunk_size]
377375
chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1) # [B, chunk_size]
378376
log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp
379377

0 commit comments

Comments
 (0)