File tree Expand file tree Collapse file tree 3 files changed +11
-0
lines changed
Expand file tree Collapse file tree 3 files changed +11
-0
lines changed Original file line number Diff line number Diff line change @@ -18,6 +18,7 @@ class TrainConfig(TypedDict, total=False):
1818 max_negative_advantage_importance_sampling_weight : float
1919 num_trajectories_learning_rate_multiplier_power : float
2020 plot_tensors : bool
21+ ppo : bool
2122 precalculate_logprobs : bool
2223 scale_learning_rate_by_reward_std_dev : bool
2324 scale_rewards : bool
Original file line number Diff line number Diff line change @@ -129,7 +129,12 @@ async def _train_model(
129129 trajectory_groups = trajectory_groups ,
130130 experimental_config = ExperimentalTrainingConfig (
131131 advantage_balance = dev_config .get ("advantage_balance" ),
132+ epsilon = dev_config .get ("epsilon" ),
133+ epsilon_high = dev_config .get ("epsilon_high" ),
134+ importance_sampling_level = dev_config .get ("importance_sampling_level" ),
132135 learning_rate = config .learning_rate ,
136+ max_negative_advantage_importance_sampling_weight = dev_config .get ("max_negative_advantage_importance_sampling_weight" ),
137+ ppo = dev_config .get ("ppo" ),
133138 precalculate_logprobs = dev_config .get ("precalculate_logprobs" ),
134139 scale_rewards = dev_config .get ("scale_rewards" ),
135140 ),
Original file line number Diff line number Diff line change @@ -51,7 +51,12 @@ class DeleteCheckpointsResponse(BaseModel):
5151
5252class ExperimentalTrainingConfig (TypedDict , total = False ):
5353 advantage_balance : float | None
54+ epsilon : float | None
55+ epsilon_high : float | None
56+ importance_sampling_level : Literal ["token" , "sequence" ] | None
5457 learning_rate : float | None
58+ max_negative_advantage_importance_sampling_weight : float | None
59+ ppo : bool | None
5560 precalculate_logprobs : bool | None
5661 scale_rewards : bool | None
5762
You can’t perform that action at this time.
0 commit comments