Skip to content

Commit 70f332e

Browse files
committed
feat: Expand importance_sampling_level options and improve type hints in tokenization
1 parent a0ae38a commit 70f332e

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

src/art/dev/train.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,9 @@ class TrainConfig(TypedDict, total=False):
1313
epsilon_high: (
1414
float | None
1515
) # asymmetric clip upper bound. Defaults to epsilon when None
16-
importance_sampling_level: Literal["token", "sequence"]
16+
importance_sampling_level: Literal[
17+
"token", "sequence", "average", "harmonic_average"
18+
]
1719
logprob_calculation_chunk_size: int
1820
max_negative_advantage_importance_sampling_weight: float
1921
num_trajectories_learning_rate_multiplier_power: float

src/art/preprocessing/tokenize.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ def tokenize_trajectory(
163163
str,
164164
tokenizer.apply_chat_template(
165165
cast(list[dict], messages),
166-
tools=tools,
166+
tools=tools, # type: ignore
167167
continue_final_message=True,
168168
tokenize=False,
169169
),
@@ -172,7 +172,7 @@ def tokenize_trajectory(
172172
list[int],
173173
tokenizer.apply_chat_template(
174174
cast(list[dict], messages),
175-
tools=tools,
175+
tools=tools, # type: ignore
176176
continue_final_message=True,
177177
),
178178
)
@@ -198,7 +198,7 @@ def tokenize_trajectory(
198198
for message_or_choice in messages_and_choices
199199
],
200200
),
201-
tools=tools,
201+
tools=tools, # type: ignore
202202
continue_final_message=True,
203203
),
204204
)

src/art/serverless/backend.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,7 +133,9 @@ async def _train_model(
133133
epsilon_high=dev_config.get("epsilon_high"),
134134
importance_sampling_level=dev_config.get("importance_sampling_level"),
135135
learning_rate=config.learning_rate,
136-
max_negative_advantage_importance_sampling_weight=dev_config.get("max_negative_advantage_importance_sampling_weight"),
136+
max_negative_advantage_importance_sampling_weight=dev_config.get(
137+
"max_negative_advantage_importance_sampling_weight"
138+
),
137139
ppo=dev_config.get("ppo"),
138140
precalculate_logprobs=dev_config.get("precalculate_logprobs"),
139141
scale_rewards=dev_config.get("scale_rewards"),

src/art/serverless/client.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,9 @@ class ExperimentalTrainingConfig(TypedDict, total=False):
5353
advantage_balance: float | None
5454
epsilon: float | None
5555
epsilon_high: float | None
56-
importance_sampling_level: Literal["token", "sequence"] | None
56+
importance_sampling_level: (
57+
Literal["token", "sequence", "average", "harmonic_average"] | None
58+
)
5759
learning_rate: float | None
5860
max_negative_advantage_importance_sampling_weight: float | None
5961
ppo: bool | None

0 commit comments

Comments (0)