
Commit dd383f5

JRMeyer and claude committed
fix: preserve _internal_config during Pydantic deserialization
The _internal_config field was being lost when TrainableModel was deserialized from JSON (e.g., when sent from the client to the SkyPilot backend), because Pydantic ignores fields starting with an underscore during model_validate().

Added a model_validator(mode="wrap") that extracts _internal_config from the input data before validation and sets it on the model after it is created. This fixes the "Cannot request more than 0 logprobs" error when using _internal_config.engine_args with remote backends.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 2c6e4bf commit dd383f5
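
For context, a minimal, self-contained sketch of the mechanism (the Demo class and its fields are illustrative, not the project's actual classes): model_validate() drops underscore-prefixed keys because they map to Pydantic private attributes rather than fields, and a mode="wrap" model validator can pull the key out of the raw input and reattach it after the instance is built.

```python
from typing import Any

from pydantic import BaseModel, model_validator


class Demo(BaseModel):
    name: str
    # Leading-underscore attributes are Pydantic *private* attributes, not
    # fields, so model_validate() silently drops an "_internal_config" key.
    _internal_config: Any = None

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal_config(cls, data: Any, handler: Any) -> "Demo":
        # Pull the private key out of the raw input before validation runs,
        # let Pydantic build the instance, then reattach the value while
        # bypassing BaseModel.__setattr__ (mirroring the commit's approach).
        internal_config = None
        if isinstance(data, dict) and "_internal_config" in data:
            internal_config = data.pop("_internal_config")
        model = handler(data)
        if internal_config is not None:
            object.__setattr__(model, "_internal_config", internal_config)
        return model


payload = {
    "name": "m",
    "_internal_config": {"engine_args": {"gpu_memory_utilization": 0.9}},
}
restored = Demo.model_validate(payload)
print(restored._internal_config)  # without the validator this would be None
```
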


2 files changed (+27 −3 lines)


src/art/model.py

Lines changed: 24 additions & 2 deletions
```diff
@@ -1,8 +1,17 @@
-from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Iterable,
+    Optional,
+    TypeVar,
+    cast,
+    overload,
+)
 
 import httpx
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 from typing_extensions import Never
 
 from . import dev
@@ -279,6 +288,19 @@ def __init__(
         # Bypass BaseModel __setattr__ to allow setting private attr
         object.__setattr__(self, "_internal_config", _internal_config)
 
+    @model_validator(mode="wrap")
+    @classmethod
+    def _preserve_internal_config(
+        cls, data: Any, handler: Any
+    ) -> "TrainableModel[ModelConfig]":
+        internal_config = None
+        if isinstance(data, dict) and "_internal_config" in data:
+            internal_config = data.pop("_internal_config")
+        model = handler(data)
+        if internal_config is not None:
+            object.__setattr__(model, "_internal_config", internal_config)
+        return model
+
     @overload
     def __new__(
         cls,
```
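
With the validator in place, a client-to-backend round trip through JSON should keep the private config. A usage sketch under stated assumptions: the TrainableModel constructor arguments, the art.dev.InternalModelConfig / art.dev.EngineArgs names, and the manual merge of the private key into the payload are illustrative and not shown in this diff.

```python
import art

# Constructor arguments and dev-config names here are illustrative.
model = art.TrainableModel(
    name="agent-001",
    project="demo",
    base_model="Qwen/Qwen2.5-7B-Instruct",
    _internal_config=art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(max_logprobs=20),
    ),
)

# Simulate the client -> SkyPilot backend hop: dump to plain data, then
# re-validate. model_dump() does not include private attributes, so the
# key is merged in by hand here.
payload = {**model.model_dump(), "_internal_config": model._internal_config}
restored = art.TrainableModel.model_validate(payload)

# Before this commit, restored._internal_config came back as None, the engine
# args were lost on the remote backend, and logprob requests failed with
# "Cannot request more than 0 logprobs"; now the config survives the round trip.
assert restored._internal_config is not None
```
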

src/art/serverless/backend.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -133,7 +133,9 @@ async def _train_model(
             epsilon_high=dev_config.get("epsilon_high"),
             importance_sampling_level=dev_config.get("importance_sampling_level"),
             learning_rate=config.learning_rate,
-            max_negative_advantage_importance_sampling_weight=dev_config.get("max_negative_advantage_importance_sampling_weight"),
+            max_negative_advantage_importance_sampling_weight=dev_config.get(
+                "max_negative_advantage_importance_sampling_weight"
+            ),
             ppo=dev_config.get("ppo"),
             precalculate_logprobs=dev_config.get("precalculate_logprobs"),
             scale_rewards=dev_config.get("scale_rewards"),
```
