Skip to content

Commit 8c905c8

Browse files
committed
feat: Add kimi k2 tau experimental config support
1 parent 6273a6f commit 8c905c8

File tree

4 files changed: 5 additions (+5) and 0 deletions (-0).

src/art/dev/train.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ class TrainConfig(TypedDict, total=False):
16 16   importance_sampling_level: Literal[
17 17   "token", "sequence", "average", "geometric_average"
18 18   ]
   19 + kimi_k2_tau: float | None
19 20   logprob_calculation_chunk_size: int
20 21   max_negative_advantage_importance_sampling_weight: float
21 22   num_trajectories_learning_rate_multiplier_power: float

src/art/loss.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,8 @@ def loss_fn(
75 75   prob_ratio = torch.clamp(
76 76   prob_ratio, max=max_negative_advantage_importance_sampling_weight
77 77   )
   78 + if tau := experimental_config.get("kimi_k2_tau", None):
   79 + advantages -= tau * logprob_diff.detach()
78 80   if experimental_config.get("ppo", True):
79 81   policy_loss = -torch.min(
80 82   prob_ratio * advantages,

src/art/serverless/backend.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ async def _train_model(
158 158   epsilon=dev_config.get("epsilon"),
159 159   epsilon_high=dev_config.get("epsilon_high"),
160 160   importance_sampling_level=dev_config.get("importance_sampling_level"),
    161 + kimi_k2_tau=dev_config.get("kimi_k2_tau"),
161 162   learning_rate=config.learning_rate,
162 163   max_negative_advantage_importance_sampling_weight=dev_config.get(
163 164   "max_negative_advantage_importance_sampling_weight"

src/art/serverless/client.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@ class ExperimentalTrainingConfig(TypedDict, total=False):
56 56   importance_sampling_level: (
57 57   Literal["token", "sequence", "average", "geometric_average"] | None
58 58   )
   59 + kimi_k2_tau: float | None
59 60   learning_rate: float | None
60 61   max_negative_advantage_importance_sampling_weight: float | None
61 62   ppo: bool | None

0 commit comments

Comments
 (0)