1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
__pycache__/
.art/
.art-backup/
.env
.venv/
grpo_trainer_lora_model/
2 changes: 2 additions & 0 deletions .skyignore
@@ -1,5 +1,7 @@
__pycache__/
.art/
.art-backup/
*.safetensors
# .env
.venv/
grpo_trainer_lora_model/
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -9,6 +9,8 @@ dependencies = [
"typer>=0.15.2",
"litellm==1.74.1",
"weave>=0.51.51",
"uvicorn[standard]",
"fastapi",
]

[project.optional-dependencies]
18 changes: 18 additions & 0 deletions src/art/dev/openai_server.py
@@ -1,9 +1,17 @@
import warnings
from typing import Literal

from typing_extensions import TypedDict

from .engine import EngineArgs

ENGINE_INIT_ONLY_ARGS = {
"max_logprobs",
"gpu_memory_utilization",
"tensor_parallel_size",
"max_model_len",
}


def get_openai_server_config(
model_name: str,
@@ -35,6 +43,16 @@ def get_openai_server_config(
generation_config="vllm",
)
engine_args.update(config.get("engine_args", {}))
user_engine_args = config.get("engine_args", {})
ignored_args = set(user_engine_args.keys()) & ENGINE_INIT_ONLY_ARGS
if ignored_args:
warnings.warn(
f"OpenAIServerConfig.engine_args contains {ignored_args} which will be "
f"ignored. The vLLM engine is initialized by Unsloth before this config "
f"is applied. Use TrainableModel._internal_config.engine_args instead.",
UserWarning,
stacklevel=2,
)
return OpenAIServerConfig(
log_file=log_file, server_args=server_args, engine_args=engine_args
)
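
A minimal standalone sketch of the warning path added above (the ENGINE_INIT_ONLY_ARGS set and message wording come from this diff; the config literal is hypothetical):

import warnings

ENGINE_INIT_ONLY_ARGS = {
    "max_logprobs",
    "gpu_memory_utilization",
    "tensor_parallel_size",
    "max_model_len",
}

# Hypothetical user config: one init-only arg, one arg that would still apply.
config = {"engine_args": {"gpu_memory_utilization": 0.9, "enforce_eager": True}}
ignored_args = set(config["engine_args"]) & ENGINE_INIT_ONLY_ARGS
if ignored_args:
    warnings.warn(
        f"OpenAIServerConfig.engine_args contains {ignored_args} which will be ignored.",
        UserWarning,
    )
# -> UserWarning: OpenAIServerConfig.engine_args contains {'gpu_memory_utilization'} ...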
20 changes: 20 additions & 0 deletions src/art/local/backend.py
@@ -412,6 +412,26 @@ async def _train_model(
dev_config: dev.TrainConfig,
verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
print("[DEBUG _train_model] Received trajectory_groups")
for tg_idx, tg in enumerate(trajectory_groups):
rewards = [t.reward for t in tg.trajectories]
print(f"[DEBUG _train_model] tg={tg_idx} rewards={rewards}")
for traj_idx, traj in enumerate(tg.trajectories):
for msg_idx, msg in enumerate(traj.messages_and_choices):
if isinstance(msg, dict) and msg.get("role") == "assistant":
print(f"[DEBUG _train_model] tg={tg_idx} traj={traj_idx} msg={msg_idx}")
print(f"[DEBUG _train_model] Assistant msg keys: {list(msg.keys())}")
print(f"[DEBUG _train_model] has logprobs: {'logprobs' in msg}")
if 'logprobs' in msg:
lp = msg['logprobs']
print(f"[DEBUG _train_model] logprobs type: {type(lp)}, truthy: {bool(lp)}")
if isinstance(lp, dict):
print(f"[DEBUG _train_model] logprobs keys: {list(lp.keys())}")
if 'values' in lp:
print(f"[DEBUG _train_model] logprobs['values'] len: {len(lp['values'])}")
print(f"[DEBUG _train_model] token_ids present: {'token_ids' in msg and msg.get('token_ids') is not None}")
if 'token_ids' in msg and msg.get('token_ids') is not None:
print(f"[DEBUG _train_model] token_ids len: {len(msg['token_ids'])}")
if verbose:
print("Starting _train_model")
service = await self._get_service(model)
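
For reference, the assistant-message shape these debug prints probe, inferred from the keys they check (an assumption, not a confirmed schema; either field may be absent depending on how rollouts were recorded):

assistant_msg = {
    "role": "assistant",
    "content": "Hello!",
    "token_ids": [9906, 0],                  # sampled token ids (optional)
    "logprobs": {"values": [-0.01, -0.02]},  # per-token logprobs (optional)
}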
21 changes: 19 additions & 2 deletions src/art/loss.py
@@ -16,6 +16,9 @@ class Loss(BaseModel):
mean_kl: torch.Tensor
mean_entropy: torch.Tensor | None
probs_corr: torch.Tensor
frac_old_logprobs_valid: float
mean_importance_ratio: torch.Tensor
clip_fraction: torch.Tensor


def loss_fn(
@@ -32,6 +35,9 @@ def loss_fn(
)
weights = shift_tensor(inputs["weights"], 0.0)
old_logprobs_mask = ~torch.isnan(old_logprobs)
frac_old_logprobs_valid = (
old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
probs_corr = torch.corrcoef(
torch.stack(
[
@@ -77,15 +83,23 @@
)
if tau := experimental_config.get("kimi_k2_tau", None):
advantages -= tau * logprob_diff.detach()
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
if experimental_config.get("ppo", True):
policy_loss = -torch.min(
prob_ratio * advantages,
torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
clipped_ratio * advantages,
)
else:
# Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
policy_loss = -(
torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
clipped_ratio.detach()
* advantages
* new_logprobs
)
@@ -123,6 +137,9 @@ def loss_fn(
mean_kl=mean_kl,
mean_entropy=mean_entropy,
probs_corr=probs_corr,
frac_old_logprobs_valid=frac_old_logprobs_valid,
mean_importance_ratio=mean_importance_ratio,
clip_fraction=clip_fraction,
)


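A toy illustration of the two diagnostics added above, run on dummy ratios (the epsilon values and mask are assumptions for the example):

import torch

epsilon = epsilon_high = 0.2
prob_ratio = torch.tensor([0.5, 1.0, 1.1, 1.6])      # new/old policy ratios
assistant_mask = torch.tensor([1.0, 1.0, 1.0, 1.0])  # trainable-token mask

is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
print(clip_fraction.item(), mean_importance_ratio.item())  # ~0.5 and ~1.05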
26 changes: 24 additions & 2 deletions src/art/model.py
@@ -1,8 +1,17 @@
from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Optional,
TypeVar,
cast,
overload,
)

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from typing_extensions import Never

from . import dev
@@ -279,6 +288,19 @@ def __init__(
# Bypass BaseModel __setattr__ to allow setting private attr
object.__setattr__(self, "_internal_config", _internal_config)

@model_validator(mode="wrap")
@classmethod
def _preserve_internal_config(
cls, data: Any, handler: Any
) -> "TrainableModel[ModelConfig]":
internal_config = None
if isinstance(data, dict) and "_internal_config" in data:
internal_config = data.pop("_internal_config")
model = handler(data)
if internal_config is not None:
object.__setattr__(model, "_internal_config", internal_config)
return model

@overload
def __new__(
cls,
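
A simplified sketch of why the wrap validator matters: pydantic revalidation (model_validate, model_copy, etc.) builds a fresh instance, which would otherwise drop the private attribute. The class below is illustrative, not the real TrainableModel:

from typing import Any

from pydantic import BaseModel, PrivateAttr, model_validator

class Sketch(BaseModel):
    name: str
    _internal_config: Any = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler: Any) -> "Sketch":
        internal = data.pop("_internal_config", None) if isinstance(data, dict) else None
        model = handler(data)
        if internal is not None:
            # Same bypass as above: avoid BaseModel.__setattr__ for private attrs.
            object.__setattr__(model, "_internal_config", internal)
        return model

m = Sketch.model_validate({"name": "m", "_internal_config": {"engine_args": {}}})
assert m._internal_config == {"engine_args": {}}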
115 changes: 88 additions & 27 deletions src/art/preprocessing/tokenize.py
@@ -52,23 +52,28 @@ def tokenize_trajectory_groups(
shuffle_group_trajectories: bool = True,
image_processor: BaseImageProcessor | None = None,
) -> Generator["TokenizedResult", None, None]:
for group in trajectory_groups:
print(f"[TOKENIZE_GROUPS] Starting with {len(trajectory_groups)} groups")
for group_idx, group in enumerate(trajectory_groups):
if not group:
continue
print(f"[TOKENIZE_GROUPS] Group {group_idx}: {len(group)} trajectories")
results: list[TokenizedResult] = []
# Calculate GRPO group mean and standard deviation
reward_mean = sum(trajectory.reward for trajectory in group) / len(group)
reward_std = math.sqrt(
sum((trajectory.reward - reward_mean) ** 2 for trajectory in group)
/ len(group)
)
for trajectory in group:
print(f"[TOKENIZE_GROUPS] Group {group_idx}: rewards={[t.reward for t in group]}, mean={reward_mean}, std={reward_std}")
for traj_idx, trajectory in enumerate(group):
# Calculate GRPO advantage for this trajectory
advantage = trajectory.reward - reward_mean
if scale_rewards:
advantage /= reward_std + 1e-6
print(f"[TOKENIZE_GROUPS] Group {group_idx} Traj {traj_idx}: raw_adv={trajectory.reward - reward_mean}, scaled_adv={advantage}")
# Skip trajectories with no advantage
if advantage == 0:
print(f"[TOKENIZE_GROUPS] Group {group_idx} Traj {traj_idx}: SKIPPED (advantage=0)")
continue
trajectory_results: list[TokenizedResult] = []
for history in [
@@ -138,19 +143,26 @@ def tokenize_trajectory(
"""
# Find the index of the last assistant message
last_assistant_index = -1
print(f"[TOKENIZE FIRST LOOP] Checking {len(history.messages_and_choices)} messages")
for i, message in enumerate(history.messages_and_choices):
if isinstance(message, dict):
print(f"[TOKENIZE FIRST LOOP] msg {i}: dict, role={message.get('role')}, has_logprobs={bool(message.get('logprobs'))}")
else:
print(f"[TOKENIZE FIRST LOOP] msg {i}: Choice obj, has_logprobs={bool(message.logprobs if hasattr(message, 'logprobs') else None)}")
if (
isinstance(message, dict)
and message["role"] == "assistant"
and allow_training_without_logprobs
and (message.get("logprobs") or allow_training_without_logprobs)
):
last_assistant_index = i
elif not isinstance(message, dict) and (
message.logprobs or allow_training_without_logprobs
):
last_assistant_index = i
print(f"[TOKENIZE FIRST LOOP] last_assistant_index={last_assistant_index}")
# If there are no trainable assistant messages, return None
if last_assistant_index == -1:
print("[TOKENIZE FIRST LOOP] -> Returning None (no trainable messages)")
return None
messages_and_choices = history.messages_and_choices[: last_assistant_index + 1]
messages = get_messages(messages_and_choices)
@@ -159,23 +171,28 @@
if history.tools is not None
else None
)
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
try:
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
except ValueError as e:
if "continue_final_message" in str(e):
return None
raise
sentinal_token_id = max(
set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
)
@@ -216,13 +233,57 @@ def tokenize_trajectory(
if isinstance(message, dict):
content = message.get("content")
assert isinstance(content, str)
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
msg_token_ids = message.get("token_ids")
dict_logprobs = message.get("logprobs")
print(f"[TOKENIZE DEBUG] Processing assistant dict message:")
print(f" message keys: {list(message.keys())}")
print(f" msg_token_ids is not None: {msg_token_ids is not None}")
print(f" dict_logprobs truthy: {bool(dict_logprobs)}")
print(f" dict_logprobs value: {repr(dict_logprobs)[:200] if dict_logprobs else repr(dict_logprobs)}")
if dict_logprobs:
print(f" dict_logprobs type: {type(dict_logprobs).__name__}")
print(f" dict_logprobs keys: {list(dict_logprobs.keys()) if isinstance(dict_logprobs, dict) else 'N/A'}")
print(f" 'values' in dict_logprobs: {'values' in dict_logprobs if isinstance(dict_logprobs, dict) else 'N/A'}")
if (
msg_token_ids is not None
and dict_logprobs
and "values" in dict_logprobs
):
print(f" -> Using provided token_ids ({len(msg_token_ids)}) and logprobs.values ({len(dict_logprobs['values'])})")
token_ids[start:end] = msg_token_ids
logprobs[start:end] = dict_logprobs["values"]
assistant_mask[start:end] = [1] * len(msg_token_ids)
elif (
dict_logprobs
and "content" in dict_logprobs
and dict_logprobs["content"]
):
token_logprobs = dict_logprobs["content"]
try:
token_ids[start:end] = [
int(lp["token"].split(":")[1]) for lp in token_logprobs
]
except (IndexError, ValueError, KeyError):
token_ids[start:end] = [
token_id if token_id is not None else tokenizer.eos_token_id
for token_id in tokenizer.convert_tokens_to_ids(
[
lp.get("token") or tokenizer.eos_token
for lp in token_logprobs
]
)
]
logprobs[start:end] = [lp["logprob"] for lp in token_logprobs]
assistant_mask[start:end] = [1] * len(token_logprobs)
else:
print(f" -> FALLBACK: re-tokenizing content, logprobs will be NaN")
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
else:
choice = message
assert choice.logprobs or allow_training_without_logprobs, (
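
A standalone sketch of the token-id recovery in the new logprobs branch above. The "token_id:<int>" string format is an assumption (it matches vLLM-style output when tokens are returned as token ids); the except branch stands in for the vocab-lookup fallback:

token_logprobs = [
    {"token": "token_id:1234", "logprob": -0.12},
    {"token": "token_id:5678", "logprob": -0.34},
]
try:
    token_ids = [int(lp["token"].split(":")[1]) for lp in token_logprobs]
except (IndexError, ValueError, KeyError):
    token_ids = None  # would fall back to tokenizer.convert_tokens_to_ids(...)
logprob_values = [lp["logprob"] for lp in token_logprobs]
print(token_ids, logprob_values)  # -> [1234, 5678] [-0.12, -0.34]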
4 changes: 3 additions & 1 deletion src/art/rewards/ruler.py
Expand Up @@ -20,6 +20,7 @@
from rich import print

import art
from art.utils.strip_logprobs import strip_logprobs


class TrajectoryScore(BaseModel):
@@ -287,9 +288,10 @@ async def ruler_score_group(
new_trajectories.append(new_traj)

# Extract message lists and preserve original rewards for comparison
# Strip logprobs to avoid sending huge token probability data to the judge
message_lists: list[list[ChatCompletionMessageParam]] = []
for traj in new_trajectories:
message_lists.append(traj.messages())
message_lists.append(strip_logprobs(traj.messages()))
traj.metrics["independent_reward"] = traj.reward

try:
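
strip_logprobs itself is not shown in this diff; a plausible minimal implementation (an assumption, not the repo's code) would drop the per-token logprob payload from each message before it reaches the judge:

def strip_logprobs(messages: list[dict]) -> list[dict]:
    # "logprobs" as the key to drop is an assumed detail.
    return [{k: v for k, v in m.items() if k != "logprobs"} for m in messages]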
5 changes: 4 additions & 1 deletion src/art/skypilot/backend.py
Expand Up @@ -104,7 +104,10 @@ async def initialize_cluster(
)
print("Art server task already running, using it…")
else:
art_server_task = sky.Task(name="art_server", run="uv run art")
art_server_task = sky.Task(
name="art_server",
run="source $HOME/.local/bin/env && uv sync --extra backend && uv run art",
)

clusters = await to_thread_typed(
lambda: sky.stream_and_get(