|
| 1 | +"""Tests for TrajectoryGroup copy and deepcopy functionality.""" |
| 2 | + |
| 3 | +import copy |
| 4 | + |
| 5 | +import pytest |
| 6 | +from openai.types.chat import ChatCompletionMessage |
| 7 | +from openai.types.chat.chat_completion import Choice |
| 8 | + |
| 9 | +from art.trajectories import PydanticException, Trajectory, TrajectoryGroup |
| 10 | + |
| 11 | + |
| 12 | +@pytest.fixture |
| 13 | +def sample_trajectory(): |
| 14 | + """Create a sample trajectory for testing.""" |
| 15 | + return Trajectory( |
| 16 | + messages_and_choices=[ |
| 17 | + {"role": "user", "content": "Hello"}, |
| 18 | + Choice( |
| 19 | + finish_reason="stop", |
| 20 | + index=0, |
| 21 | + logprobs=None, |
| 22 | + message=ChatCompletionMessage( |
| 23 | + role="assistant", |
| 24 | + content="Hi there!", |
| 25 | + refusal=None, |
| 26 | + ), |
| 27 | + ), |
| 28 | + ], |
| 29 | + tools=None, |
| 30 | + reward=1.0, |
| 31 | + metrics={"accuracy": 0.95}, |
| 32 | + metadata={"test": "value"}, |
| 33 | + ) |
| 34 | + |
| 35 | + |
| 36 | +@pytest.fixture |
| 37 | +def sample_trajectory_group(sample_trajectory): |
| 38 | + """Create a sample trajectory group for testing.""" |
| 39 | + trajectory2 = Trajectory( |
| 40 | + messages_and_choices=[ |
| 41 | + {"role": "user", "content": "How are you?"}, |
| 42 | + Choice( |
| 43 | + finish_reason="stop", |
| 44 | + index=0, |
| 45 | + logprobs=None, |
| 46 | + message=ChatCompletionMessage( |
| 47 | + role="assistant", |
| 48 | + content="I'm doing well!", |
| 49 | + refusal=None, |
| 50 | + ), |
| 51 | + ), |
| 52 | + ], |
| 53 | + tools=None, |
| 54 | + reward=0.8, |
| 55 | + ) |
| 56 | + return TrajectoryGroup( |
| 57 | + trajectories=[sample_trajectory, trajectory2], |
| 58 | + exceptions=[], |
| 59 | + ) |
| 60 | + |
| 61 | + |
| 62 | +def test_shallow_copy(sample_trajectory_group): |
| 63 | + """Test that shallow copy works correctly.""" |
| 64 | + copied = copy.copy(sample_trajectory_group) |
| 65 | + |
| 66 | + # Should be a different object |
| 67 | + assert copied is not sample_trajectory_group |
| 68 | + |
| 69 | + # Trajectories should be a new list (shallow copy of list) |
| 70 | + assert copied.trajectories is not sample_trajectory_group.trajectories |
| 71 | + |
| 72 | + # But the trajectory objects themselves should be the same (shallow copy) |
| 73 | + assert copied.trajectories[0] is sample_trajectory_group.trajectories[0] |
| 74 | + assert copied.trajectories[1] is sample_trajectory_group.trajectories[1] |
| 75 | + |
| 76 | + # Exceptions should be a new list with same contents |
| 77 | + assert copied.exceptions is not sample_trajectory_group.exceptions |
| 78 | + assert copied.exceptions == sample_trajectory_group.exceptions |
| 79 | + |
| 80 | + |
| 81 | +def test_deep_copy(sample_trajectory_group): |
| 82 | + """Test that deep copy works correctly.""" |
| 83 | + copied = copy.deepcopy(sample_trajectory_group) |
| 84 | + |
| 85 | + # Should be a different object |
| 86 | + assert copied is not sample_trajectory_group |
| 87 | + |
| 88 | + # Should have different trajectories list (deep copy) |
| 89 | + assert copied.trajectories is not sample_trajectory_group.trajectories |
| 90 | + |
| 91 | + # Trajectories themselves should be different objects |
| 92 | + assert copied.trajectories[0] is not sample_trajectory_group.trajectories[0] |
| 93 | + assert copied.trajectories[1] is not sample_trajectory_group.trajectories[1] |
| 94 | + |
| 95 | + # But should have same content |
| 96 | + assert len(copied.trajectories) == len(sample_trajectory_group.trajectories) |
| 97 | + assert ( |
| 98 | + copied.trajectories[0].reward == sample_trajectory_group.trajectories[0].reward |
| 99 | + ) |
| 100 | + assert ( |
| 101 | + copied.trajectories[1].reward == sample_trajectory_group.trajectories[1].reward |
| 102 | + ) |
| 103 | + |
| 104 | + # Exceptions should also be deep copied |
| 105 | + assert copied.exceptions is not sample_trajectory_group.exceptions |
| 106 | + |
| 107 | + |
| 108 | +def test_deep_copy_with_exceptions(): |
| 109 | + """Test that deep copy works with exceptions.""" |
| 110 | + group = TrajectoryGroup( |
| 111 | + trajectories=[ |
| 112 | + Trajectory( |
| 113 | + messages_and_choices=[{"role": "user", "content": "test"}], |
| 114 | + tools=None, |
| 115 | + reward=1.0, |
| 116 | + ) |
| 117 | + ], |
| 118 | + exceptions=[ValueError("test error")], |
| 119 | + ) |
| 120 | + |
| 121 | + copied = copy.deepcopy(group) |
| 122 | + |
| 123 | + # Should be different objects |
| 124 | + assert copied is not group |
| 125 | + assert copied.exceptions is not group.exceptions |
| 126 | + |
| 127 | + # Should have same exception content |
| 128 | + assert len(copied.exceptions) == len(group.exceptions) |
| 129 | + assert copied.exceptions[0].message == group.exceptions[0].message |
| 130 | + |
| 131 | + |
| 132 | +def test_deep_copy_circular_reference(): |
| 133 | + """Test that deep copy handles circular references correctly.""" |
| 134 | + group = TrajectoryGroup( |
| 135 | + trajectories=[ |
| 136 | + Trajectory( |
| 137 | + messages_and_choices=[{"role": "user", "content": "test"}], |
| 138 | + tools=None, |
| 139 | + reward=1.0, |
| 140 | + ) |
| 141 | + ], |
| 142 | + exceptions=[], |
| 143 | + ) |
| 144 | + |
| 145 | + # Create a memo dict with a circular reference |
| 146 | + memo = {} |
| 147 | + copied = copy.deepcopy(group, memo) |
| 148 | + |
| 149 | + # Should be in memo |
| 150 | + assert id(group) in memo |
| 151 | + assert memo[id(group)] is copied |
| 152 | + |
| 153 | + # Copying again with same memo should return the same object |
| 154 | + copied2 = copy.deepcopy(group, memo) |
| 155 | + assert copied2 is copied |
| 156 | + |
| 157 | + |
| 158 | +def test_deep_copy_preserves_metadata(sample_trajectory_group): |
| 159 | + """Test that deep copy preserves trajectory metadata.""" |
| 160 | + copied = copy.deepcopy(sample_trajectory_group) |
| 161 | + |
| 162 | + # Check that metadata is preserved |
| 163 | + assert ( |
| 164 | + copied.trajectories[0].metrics |
| 165 | + == sample_trajectory_group.trajectories[0].metrics |
| 166 | + ) |
| 167 | + assert ( |
| 168 | + copied.trajectories[0].metadata |
| 169 | + == sample_trajectory_group.trajectories[0].metadata |
| 170 | + ) |
| 171 | + |
| 172 | + # But should be different dict objects |
| 173 | + assert ( |
| 174 | + copied.trajectories[0].metrics |
| 175 | + is not sample_trajectory_group.trajectories[0].metrics |
| 176 | + ) |
| 177 | + assert ( |
| 178 | + copied.trajectories[0].metadata |
| 179 | + is not sample_trajectory_group.trajectories[0].metadata |
| 180 | + ) |
| 181 | + |
| 182 | + |
| 183 | +def test_copy_empty_group(): |
| 184 | + """Test copying an empty trajectory group.""" |
| 185 | + empty_group = TrajectoryGroup(trajectories=[], exceptions=[]) |
| 186 | + |
| 187 | + shallow = copy.copy(empty_group) |
| 188 | + assert shallow is not empty_group |
| 189 | + assert len(shallow.trajectories) == 0 |
| 190 | + |
| 191 | + deep = copy.deepcopy(empty_group) |
| 192 | + assert deep is not empty_group |
| 193 | + assert len(deep.trajectories) == 0 |
0 commit comments