Skip to content

Commit a0ae38a

Browse files
committed
chore: Add support for more experimental config with the ServerlessBackend
1 parent a9657c4 commit a0ae38a

File tree

3 files changed

+11
-0
lines changed

3 files changed

+11
-0
lines changed

src/art/dev/train.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False):
1818
max_negative_advantage_importance_sampling_weight: float
1919
num_trajectories_learning_rate_multiplier_power: float
2020
plot_tensors: bool
21+
ppo: bool
2122
precalculate_logprobs: bool
2223
scale_learning_rate_by_reward_std_dev: bool
2324
scale_rewards: bool

src/art/serverless/backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,12 @@ async def _train_model(
129129
trajectory_groups=trajectory_groups,
130130
experimental_config=ExperimentalTrainingConfig(
131131
advantage_balance=dev_config.get("advantage_balance"),
132+
epsilon=dev_config.get("epsilon"),
133+
epsilon_high=dev_config.get("epsilon_high"),
134+
importance_sampling_level=dev_config.get("importance_sampling_level"),
132135
learning_rate=config.learning_rate,
136+
max_negative_advantage_importance_sampling_weight=dev_config.get("max_negative_advantage_importance_sampling_weight"),
137+
ppo=dev_config.get("ppo"),
133138
precalculate_logprobs=dev_config.get("precalculate_logprobs"),
134139
scale_rewards=dev_config.get("scale_rewards"),
135140
),

src/art/serverless/client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,12 @@ class DeleteCheckpointsResponse(BaseModel):
5151

5252
class ExperimentalTrainingConfig(TypedDict, total=False):
5353
advantage_balance: float | None
54+
epsilon: float | None
55+
epsilon_high: float | None
56+
importance_sampling_level: Literal["token", "sequence"] | None
5457
learning_rate: float | None
58+
max_negative_advantage_importance_sampling_weight: float | None
59+
ppo: bool | None
5560
precalculate_logprobs: bool | None
5661
scale_rewards: bool | None
5762

0 commit comments

Comments
 (0)