Skip to content

Commit f739866

Browse files
committed
fix types
1 parent 5ca7815 commit f739866

File tree

2 files changed

+198
-1
lines changed

2 files changed

+198
-1
lines changed

src/art/trajectories.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,14 @@ def __copy__(self):
187187
new_instance.exceptions = self.exceptions[:]
188188
return new_instance
189189

190-
def __deepcopy__(self, memo):
190+
def __deepcopy__(self, memo: dict[int, Any] | None = None):
191191
"""Support for copy.deepcopy()"""
192192
import copy
193193

194+
# Initialize memo if not provided
195+
if memo is None:
196+
memo = {}
197+
194198
# Check memo to handle circular references
195199
if id(self) in memo:
196200
return memo[id(self)]

tests/unit/test_trajectory_copy.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""Tests for TrajectoryGroup copy and deepcopy functionality."""
2+
3+
import copy
4+
5+
import pytest
6+
from openai.types.chat import ChatCompletionMessage
7+
from openai.types.chat.chat_completion import Choice
8+
9+
from art.trajectories import PydanticException, Trajectory, TrajectoryGroup
10+
11+
12+
@pytest.fixture
13+
def sample_trajectory():
14+
"""Create a sample trajectory for testing."""
15+
return Trajectory(
16+
messages_and_choices=[
17+
{"role": "user", "content": "Hello"},
18+
Choice(
19+
finish_reason="stop",
20+
index=0,
21+
logprobs=None,
22+
message=ChatCompletionMessage(
23+
role="assistant",
24+
content="Hi there!",
25+
refusal=None,
26+
),
27+
),
28+
],
29+
tools=None,
30+
reward=1.0,
31+
metrics={"accuracy": 0.95},
32+
metadata={"test": "value"},
33+
)
34+
35+
36+
@pytest.fixture
37+
def sample_trajectory_group(sample_trajectory):
38+
"""Create a sample trajectory group for testing."""
39+
trajectory2 = Trajectory(
40+
messages_and_choices=[
41+
{"role": "user", "content": "How are you?"},
42+
Choice(
43+
finish_reason="stop",
44+
index=0,
45+
logprobs=None,
46+
message=ChatCompletionMessage(
47+
role="assistant",
48+
content="I'm doing well!",
49+
refusal=None,
50+
),
51+
),
52+
],
53+
tools=None,
54+
reward=0.8,
55+
)
56+
return TrajectoryGroup(
57+
trajectories=[sample_trajectory, trajectory2],
58+
exceptions=[],
59+
)
60+
61+
62+
def test_shallow_copy(sample_trajectory_group):
63+
"""Test that shallow copy works correctly."""
64+
copied = copy.copy(sample_trajectory_group)
65+
66+
# Should be a different object
67+
assert copied is not sample_trajectory_group
68+
69+
# Trajectories should be a new list (shallow copy of list)
70+
assert copied.trajectories is not sample_trajectory_group.trajectories
71+
72+
# But the trajectory objects themselves should be the same (shallow copy)
73+
assert copied.trajectories[0] is sample_trajectory_group.trajectories[0]
74+
assert copied.trajectories[1] is sample_trajectory_group.trajectories[1]
75+
76+
# Exceptions should be a new list with same contents
77+
assert copied.exceptions is not sample_trajectory_group.exceptions
78+
assert copied.exceptions == sample_trajectory_group.exceptions
79+
80+
81+
def test_deep_copy(sample_trajectory_group):
82+
"""Test that deep copy works correctly."""
83+
copied = copy.deepcopy(sample_trajectory_group)
84+
85+
# Should be a different object
86+
assert copied is not sample_trajectory_group
87+
88+
# Should have different trajectories list (deep copy)
89+
assert copied.trajectories is not sample_trajectory_group.trajectories
90+
91+
# Trajectories themselves should be different objects
92+
assert copied.trajectories[0] is not sample_trajectory_group.trajectories[0]
93+
assert copied.trajectories[1] is not sample_trajectory_group.trajectories[1]
94+
95+
# But should have same content
96+
assert len(copied.trajectories) == len(sample_trajectory_group.trajectories)
97+
assert (
98+
copied.trajectories[0].reward == sample_trajectory_group.trajectories[0].reward
99+
)
100+
assert (
101+
copied.trajectories[1].reward == sample_trajectory_group.trajectories[1].reward
102+
)
103+
104+
# Exceptions should also be deep copied
105+
assert copied.exceptions is not sample_trajectory_group.exceptions
106+
107+
108+
def test_deep_copy_with_exceptions():
109+
"""Test that deep copy works with exceptions."""
110+
group = TrajectoryGroup(
111+
trajectories=[
112+
Trajectory(
113+
messages_and_choices=[{"role": "user", "content": "test"}],
114+
tools=None,
115+
reward=1.0,
116+
)
117+
],
118+
exceptions=[ValueError("test error")],
119+
)
120+
121+
copied = copy.deepcopy(group)
122+
123+
# Should be different objects
124+
assert copied is not group
125+
assert copied.exceptions is not group.exceptions
126+
127+
# Should have same exception content
128+
assert len(copied.exceptions) == len(group.exceptions)
129+
assert copied.exceptions[0].message == group.exceptions[0].message
130+
131+
132+
def test_deep_copy_circular_reference():
133+
"""Test that deep copy handles circular references correctly."""
134+
group = TrajectoryGroup(
135+
trajectories=[
136+
Trajectory(
137+
messages_and_choices=[{"role": "user", "content": "test"}],
138+
tools=None,
139+
reward=1.0,
140+
)
141+
],
142+
exceptions=[],
143+
)
144+
145+
# Create a memo dict with a circular reference
146+
memo = {}
147+
copied = copy.deepcopy(group, memo)
148+
149+
# Should be in memo
150+
assert id(group) in memo
151+
assert memo[id(group)] is copied
152+
153+
# Copying again with same memo should return the same object
154+
copied2 = copy.deepcopy(group, memo)
155+
assert copied2 is copied
156+
157+
158+
def test_deep_copy_preserves_metadata(sample_trajectory_group):
159+
"""Test that deep copy preserves trajectory metadata."""
160+
copied = copy.deepcopy(sample_trajectory_group)
161+
162+
# Check that metadata is preserved
163+
assert (
164+
copied.trajectories[0].metrics
165+
== sample_trajectory_group.trajectories[0].metrics
166+
)
167+
assert (
168+
copied.trajectories[0].metadata
169+
== sample_trajectory_group.trajectories[0].metadata
170+
)
171+
172+
# But should be different dict objects
173+
assert (
174+
copied.trajectories[0].metrics
175+
is not sample_trajectory_group.trajectories[0].metrics
176+
)
177+
assert (
178+
copied.trajectories[0].metadata
179+
is not sample_trajectory_group.trajectories[0].metadata
180+
)
181+
182+
183+
def test_copy_empty_group():
184+
"""Test copying an empty trajectory group."""
185+
empty_group = TrajectoryGroup(trajectories=[], exceptions=[])
186+
187+
shallow = copy.copy(empty_group)
188+
assert shallow is not empty_group
189+
assert len(shallow.trajectories) == 0
190+
191+
deep = copy.deepcopy(empty_group)
192+
assert deep is not empty_group
193+
assert len(deep.trajectories) == 0

0 commit comments

Comments
 (0)