Skip to content

Commit 2af67d5

Browse files
authored
Add strip_logprobs utility function (#455)
* Add `strip_logprobs` function * Justify strip_logprobs * Fix test_strip_logprobs
1 parent 9144cbb commit 2af67d5

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed

src/art/utils/strip_logprobs.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import copy
2+
import logging
3+
import sys
4+
from typing import Any
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def strip_logprobs(obj: Any) -> Any:
    """Return a deep copy of ``obj`` with every ``logprobs`` entry stripped.

    The purpose is to reduce data storage size. The input is deep-copied
    first, so the caller's data is never modified by the stripping pass.

    How each kind of value is handled:

    * ``dict``: entries whose key is ``"logprobs"`` are dropped; remaining
      values are processed recursively.
    * ``list`` / ``tuple``: every element is processed recursively. Tuples
      stay tuples and namedtuples keep their own type.
    * Instances with a ``__dict__``: an attribute named ``logprobs`` is set
      to ``None`` (it is *not* deleted); every other attribute is processed
      recursively.
    * Classes and all other values are returned as they are.

    Args:
        obj: Any nested data structure.

    Returns:
        A stripped deep copy of ``obj``. If ``copy.deepcopy`` raises, a
        warning is logged and the original object is returned unchanged
        (still containing its ``logprobs``).
    """
    try:
        copied_obj = copy.deepcopy(obj)
    except Exception as e:
        # Deliberately broad: this is a best-effort size optimisation, so an
        # uncopyable object must not break the caller.
        logger.warning(
            "Failed to deepcopy object in strip_logprobs: %s. "
            "Returning original object unchanged.",
            e,
        )
        return obj

    return _strip_logprobs(copied_obj)


def _strip_logprobs(obj: Any) -> Any:
    """Strip ``logprobs`` from ``obj``; rebuilds containers, mutates instances.

    Dicts, lists and tuples are rebuilt as new containers. Instances with a
    ``__dict__`` are modified in place, which is why ``strip_logprobs`` only
    ever passes a deep copy to this helper.

    NOTE: there is no cycle detection here; a self-referential structure
    recurses until ``RecursionError`` is raised.
    """
    if isinstance(obj, dict):
        return {k: _strip_logprobs(v) for k, v in obj.items() if k != "logprobs"}

    if isinstance(obj, list):
        return [_strip_logprobs(v) for v in obj]

    if isinstance(obj, tuple):
        items = [_strip_logprobs(v) for v in obj]
        # namedtuples expose a ``_make`` classmethod; use it so the result
        # keeps its field names instead of degrading to a plain tuple.
        make = getattr(type(obj), "_make", None)
        return make(items) if callable(make) else tuple(items)

    if isinstance(obj, type):
        # deepcopy hands classes back as the very same object, so touching
        # their attributes would modify the caller's real class.
        return obj

    if hasattr(obj, "__dict__"):
        # Iterate over a snapshot: a custom __setattr__ may add attributes,
        # which would otherwise change the dict's size mid-iteration.
        for k, v in list(obj.__dict__.items()):
            if k == "logprobs":
                setattr(obj, k, None)
            else:
                setattr(obj, k, _strip_logprobs(v))
        return obj

    return obj

tests/unit/test_strip_logprobs.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""Tests for strip_logprobs utility function."""
2+
3+
import copy
4+
import logging
5+
from unittest.mock import MagicMock
6+
7+
import pytest
8+
9+
from art.utils.strip_logprobs import strip_logprobs
10+
11+
12+
class TestStripLogprobs:
    """Test suite for the ``strip_logprobs`` utility.

    Covers dict, list, tuple and plain-object inputs, checks that the
    caller's data is never mutated, and checks the fallback path taken when
    ``copy.deepcopy`` raises.
    """

    def test_strip_dict_with_logprobs(self):
        """'logprobs' keys are dropped from a dict and its nested dicts."""
        input_dict = {
            "data": "value",
            "logprobs": [0.1, 0.2, 0.3],
            "nested": {"key": "val", "logprobs": {"nested_log": 0.5}},
        }
        expected = {"data": "value", "nested": {"key": "val"}}

        result = strip_logprobs(input_dict)

        assert result == expected
        assert input_dict["logprobs"] == [0.1, 0.2, 0.3]  # Original unchanged

    def test_strip_nested_dict(self):
        """'logprobs' keys are dropped at every nesting level."""
        input_dict = {
            "level1": {
                "level2": {
                    "level3": {"data": 1, "logprobs": "remove_me"},
                    "logprobs": [1, 2, 3],
                }
            },
            "logprobs": None,
        }
        expected = {"level1": {"level2": {"level3": {"data": 1}}}}

        result = strip_logprobs(input_dict)

        assert result == expected

    def test_strip_list_with_logprobs(self):
        """Dicts inside a list are stripped; the list itself is preserved."""
        input_list = [
            {"item": 1, "logprobs": 0.1},
            {"item": 2, "logprobs": 0.2},
            {"item": 3},
        ]
        expected = [{"item": 1}, {"item": 2}, {"item": 3}]

        result = strip_logprobs(input_list)

        assert result == expected

    def test_strip_tuple_with_logprobs(self):
        """Dicts inside a tuple are stripped and the result is still a tuple."""
        input_tuple = (
            {"item": 1, "logprobs": 0.1},
            {"item": 2},
            {"nested": {"logprobs": "remove"}},
        )
        expected = ({"item": 1}, {"item": 2}, {"nested": {}})

        result = strip_logprobs(input_tuple)

        assert result == expected
        assert isinstance(result, tuple)

    def test_strip_object_with_logprobs(self):
        """Objects with __dict__ get logprobs set to None on a copy only."""

        class SampleObject:
            def __init__(self):
                self.data = "value"
                self.logprobs = [0.1, 0.2]
                self.nested = {"key": "val", "logprobs": "remove"}

        obj = SampleObject()
        result = strip_logprobs(obj)

        assert result.data == "value"
        assert result.logprobs is None  # Set to None for objects
        assert result.nested == {"key": "val"}

        # The object branch mutates in place, so make sure it worked on a
        # copy and left the caller's instance alone.
        assert result is not obj
        assert obj.logprobs == [0.1, 0.2]
        assert obj.nested == {"key": "val", "logprobs": "remove"}

    def test_strip_mixed_nested_structure(self):
        """Lists, tuples and dicts nested inside each other are all handled."""
        input_data = {
            "list": [
                {"logprobs": 1},
                [{"nested_list": True, "logprobs": 2}],
            ],
            "tuple": ({"logprobs": 3}, {"keep": "me"}),
            "dict": {"nested": {"logprobs": 4, "data": "keep"}},
        }
        expected = {
            "list": [{}, [{"nested_list": True}]],
            "tuple": ({}, {"keep": "me"}),
            "dict": {"nested": {"data": "keep"}},
        }

        result = strip_logprobs(input_data)

        assert result == expected

    def test_strip_empty_structures(self):
        """Empty containers come back as equal empty containers."""
        assert strip_logprobs({}) == {}
        assert strip_logprobs([]) == []
        assert strip_logprobs(()) == ()

    def test_strip_none_and_primitives(self):
        """None and primitive values pass through unchanged."""
        assert strip_logprobs(None) is None
        assert strip_logprobs(42) == 42
        assert strip_logprobs("string") == "string"
        assert strip_logprobs(3.14) == 3.14
        assert strip_logprobs(True) is True

    def test_no_logprobs_unchanged(self):
        """A structure without logprobs comes back equal to the input."""
        input_dict = {
            "data": "value",
            "nested": {"key": "val"},
            "list": [1, 2, 3],
        }

        result = strip_logprobs(input_dict)

        assert result == input_dict

    def test_deepcopy_behavior(self):
        """Mutating the result must not leak into the caller's data."""
        nested_list = [1, 2, 3]
        input_dict = {
            "data": nested_list,
            "logprobs": "remove",
        }

        result = strip_logprobs(input_dict)

        result["data"].append(4)
        assert nested_list == [1, 2, 3]  # Original unchanged
        assert result["data"] == [1, 2, 3, 4]

    def test_deepcopy_failure_returns_original(self, caplog):
        """A deepcopy failure returns the original object and logs a warning."""

        class UnCopyableObject:
            def __init__(self):
                self.data = "value"
                self.logprobs = "should_remain"

            def __deepcopy__(self, memo):
                raise RuntimeError("Cannot deepcopy this object")

        obj = UnCopyableObject()

        with caplog.at_level(logging.WARNING):
            result = strip_logprobs(obj)

        # Should return the original object unchanged
        assert result is obj
        assert result.logprobs == "should_remain"

        # Check that warning was logged
        assert len(caplog.records) == 1
        assert "Failed to deepcopy object in strip_logprobs" in caplog.text
        assert "Cannot deepcopy this object" in caplog.text
        assert "Returning original object unchanged" in caplog.text

    def test_deepcopy_failure_with_recursion_error(self, caplog):
        """A RecursionError raised by deepcopy takes the same fallback path."""

        class RecursiveObject:
            def __init__(self):
                self.data = "value"
                self.logprobs = [1, 2, 3]

            def __deepcopy__(self, memo):
                raise RecursionError("maximum recursion depth exceeded")

        obj = RecursiveObject()

        with caplog.at_level(logging.WARNING):
            result = strip_logprobs(obj)

        # Should return the original object unchanged
        assert result is obj
        assert result.logprobs == [1, 2, 3]

        # Check that warning was logged (same checks as the RuntimeError case)
        assert len(caplog.records) == 1
        assert "Failed to deepcopy object in strip_logprobs" in caplog.text
        assert "maximum recursion depth exceeded" in caplog.text
        assert "Returning original object unchanged" in caplog.text

0 commit comments

Comments
 (0)