Commit 557e704

fix: VLM tensor issue

1 parent 711169f · commit 557e704

File tree

3 files changed: +14 -8 lines

src/art/serverless/backend.py

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,7 @@
 from openai._types import NOT_GIVEN
 from tqdm import auto as tqdm

-from art.client import Client, ExperimentalTrainingConfig
+from art.serverless.client import Client, ExperimentalTrainingConfig
 from art.utils.deploy_model import LoRADeploymentJob, LoRADeploymentProvider

 from .. import dev
@@ -128,8 +128,10 @@ async def _train_model(
     model_id=model.id,
     trajectory_groups=trajectory_groups,
     experimental_config=ExperimentalTrainingConfig(
+        advantage_balance=dev_config.get("advantage_balance"),
         learning_rate=config.learning_rate,
         precalculate_logprobs=dev_config.get("precalculate_logprobs"),
+        scale_rewards=dev_config.get("scale_rewards"),
     ),
 )
 after: str | None = None
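
Note: the two new keys are read with dict.get, so they default to None when absent from dev_config, which lines up with the total=False TypedDict in the next file. A minimal sketch of that flow (the dev_config contents here are hypothetical, not from this commit):

    # Hypothetical dev_config: only one of the optional keys is set.
    dev_config = {"precalculate_logprobs": True}

    experimental_config = dict(
        advantage_balance=dev_config.get("advantage_balance"),  # None (absent)
        precalculate_logprobs=dev_config.get("precalculate_logprobs"),  # True
        scale_rewards=dev_config.get("scale_rewards"),  # None (absent)
    )
    print(experimental_config)
    # {'advantage_balance': None, 'precalculate_logprobs': True, 'scale_rewards': None}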
src/art/serverless/client.py (filename missing from the page; inferred from the import changes)

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,7 @@
 from openai.pagination import AsyncCursorPage
 from typing_extensions import override

-from .trajectories import TrajectoryGroup
+from ..trajectories import TrajectoryGroup

 ResponseT = TypeVar("ResponseT")

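The import moving from .trajectories to ..trajectories is consistent with this client module now sitting one package level deeper (presumably under art/serverless/), which also matches backend.py above switching to importing from art.serverless.client.
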
@@ -50,8 +50,10 @@ class DeleteCheckpointsResponse(BaseModel):


 class ExperimentalTrainingConfig(TypedDict, total=False):
+    advantage_balance: float | None
     learning_rate: float | None
     precalculate_logprobs: bool | None
+    scale_rewards: bool | None


 class TrainingJob(BaseModel):
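
Because the TypedDict is declared with total=False, all four keys are optional for callers. A standalone sketch of that behavior (not the library's actual module):

    from __future__ import annotations

    from typing_extensions import TypedDict


    class ExperimentalTrainingConfig(TypedDict, total=False):
        advantage_balance: float | None
        learning_rate: float | None
        precalculate_logprobs: bool | None
        scale_rewards: bool | None


    # A partial config type-checks; omitted keys are simply absent at runtime.
    cfg: ExperimentalTrainingConfig = {"learning_rate": 1e-5, "scale_rewards": True}
    print(cfg.get("advantage_balance"))  # None: key not present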

src/art/unsloth/train.py

Lines changed: 8 additions & 6 deletions
@@ -69,11 +69,11 @@ def compute_loss(
     # if param_group.get("weight_decay"):
     #     param_group["weight_decay"] = config.weight_decay

-    if inputs["pixel_values"][0] is not None:
+    if inputs.get("pixel_values") and inputs["pixel_values"][0] is not None:
         inputs["pixel_values"] = inputs["pixel_values"][0]  # type: ignore
     else:
         del inputs["pixel_values"]  # type: ignore
-    if inputs["image_grid_thw"][0] is not None:
+    if inputs.get("image_grid_thw") and inputs["image_grid_thw"][0] is not None:
         inputs["image_grid_thw"] = inputs["image_grid_thw"][0]  # type: ignore
     else:
         del inputs["image_grid_thw"]  # type: ignore
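
This is the core of the fix named in the commit message: the added inputs.get(...) guard handles batches with no vision inputs, where pixel_values is None or empty rather than an indexable list, while the old condition indexed it unconditionally. A minimal standalone sketch of the pattern (the dict contents are placeholders, not the trainer's real batch format):

    def normalize_vision_inputs(inputs: dict) -> dict:
        # Old form, `if inputs["pixel_values"][0] is not None:`, raised a
        # TypeError when a text-only batch carried pixel_values=None.
        if inputs.get("pixel_values") and inputs["pixel_values"][0] is not None:
            inputs["pixel_values"] = inputs["pixel_values"][0]  # unwrap batch dim
        else:
            del inputs["pixel_values"]  # drop the placeholder entirely
        return inputs


    print(normalize_vision_inputs({"pixel_values": [["img-tensor"]]}))
    print(normalize_vision_inputs({"pixel_values": None}))  # text-only batch -> {}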
@@ -114,9 +114,9 @@ def compute_loss(
     next_input_ids = shift_tensor(inputs["tokens"], 0)
     chunk_size = _config.get("logprob_calculation_chunk_size", 1024)
     # Assert that sequence length is evenly divisible by the chunk size
-    assert seq_len % chunk_size == 0, (
-        f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
-    )
+    assert (
+        seq_len % chunk_size == 0
+    ), f"Sequence length ({seq_len}) must be evenly divisible by chunk size ({chunk_size})"
     os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "1"
     forward_kwargs = {}
     if "pixel_values" in inputs:
@@ -371,7 +371,9 @@ def _calculate_logprobs(
     chunk_logits = torch.matmul(chunk_hs, lm_head_t)  # [B, chunk_size, V]
     chunk_selected_logits = torch.gather(
         chunk_logits, dim=-1, index=chunk_input_ids.unsqueeze(-1)
-    ).squeeze(-1)  # [B, chunk_size]
+    ).squeeze(
+        -1
+    )  # [B, chunk_size]
     chunk_logsumexp = torch.logsumexp(chunk_logits, dim=-1)  # [B, chunk_size]
     log_probs[:, i : i + chunk_size] = chunk_selected_logits - chunk_logsumexp

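The squeeze change is pure formatting, but the math in this hunk is the standard memory-saving trick: the per-position log-probability is the selected logit minus logsumexp over the vocabulary, computed chunk by chunk instead of materializing a full log-softmax. A toy sketch with made-up shapes:

    import torch

    B, T, V = 2, 4, 8  # batch, chunk size, vocab (toy values)
    chunk_logits = torch.randn(B, T, V)
    chunk_input_ids = torch.randint(0, V, (B, T))

    selected = torch.gather(
        chunk_logits, dim=-1, index=chunk_input_ids.unsqueeze(-1)
    ).squeeze(-1)  # [B, T]
    log_probs = selected - torch.logsumexp(chunk_logits, dim=-1)  # [B, T]

    # Equivalent to a full log_softmax followed by the same gather:
    expected = (
        torch.log_softmax(chunk_logits, dim=-1)
        .gather(-1, chunk_input_ids.unsqueeze(-1))
        .squeeze(-1)
    )
    assert torch.allclose(log_probs, expected, atol=1e-5)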