Skip to content

Commit 04857dc

Browse files
committed
feat: extract logprobs from dict messages for GRPO importance sampling
1 parent e456b20 commit 04857dc

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/art/preprocessing/tokenize.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,15 @@ def tokenize_trajectory(
221221
add_special_tokens=False,
222222
)
223223
token_ids[start:end] = content_token_ids
224-
logprobs[start:end] = [float("nan")] * len(content_token_ids)
224+
dict_logprobs = message.get("logprobs")
225+
if dict_logprobs is None:
226+
logprobs[start:end] = [float("nan")] * len(content_token_ids)
227+
elif "content" in dict_logprobs and dict_logprobs["content"]:
228+
logprobs[start:end] = [lp["logprob"] for lp in dict_logprobs["content"]]
229+
else:
230+
raise ValueError(
231+
f"Message has 'logprobs' key but content is missing or empty: {dict_logprobs}"
232+
)
225233
assistant_mask[start:end] = [1] * len(content_token_ids)
226234
else:
227235
choice = message

0 commit comments

Comments
 (0)