Skip to content

Commit f639854

Browse files
authored
Update shuffle
Differential Revision: D80646907 Pull Request resolved: #857
1 parent a4336c2 commit f639854

File tree

2 files changed

+104
-8
lines changed

2 files changed

+104
-8
lines changed

src/spdl/source/utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,23 +206,26 @@ def __init__(
206206
) -> None:
207207
self.src = src
208208
self._epoch = epoch
209-
self._shuffle_last = shuffle_last
209+
self._shuffle_first: bool = not shuffle_last
210210

211211
def _shuffle(self) -> None:
212212
t0 = time.monotonic()
213213
self.src.shuffle(seed=self._epoch)
214214
if (elapsed := time.monotonic() - t0) > 3:
215215
_LG.warning("Shuffling took %.2f sec.", elapsed)
216-
217-
def __iter__(self) -> Iterator[T]:
218-
if not self._shuffle_last:
219-
self._shuffle()
220-
221-
yield from self.src
222216
self._epoch += 1
223217

224-
if self._shuffle_last:
218+
def __iter__(self) -> Iterator[T]:
219+
if self._shuffle_first:
225220
self._shuffle()
221+
yield from self.src
222+
else:
223+
try:
224+
yield from self.src
225+
finally:
226+
# in case the iteration is stopped in the middle.
227+
# shuffle is called when the iterator is deleted.
228+
self._shuffle()
226229

227230
def __len__(self) -> int:
228231
if isinstance(self.src, Sized):
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from collections.abc import Iterator
10+
11+
from spdl.source.utils import embed_shuffle
12+
13+
14+
class IterableWithShuffle_:
15+
def __init__(self, n: int) -> None:
16+
self.vals = list(range(n))
17+
self._seed: int | None = None
18+
19+
def __iter__(self) -> Iterator[int]:
20+
yield from self.vals
21+
22+
def shuffle(self, seed: int) -> None:
23+
# rotate
24+
self._seed = seed
25+
self.vals = self.vals[1:] + self.vals[:1]
26+
27+
28+
def test_embed_shuffle():
29+
"""Iterable created by embed_shuffle calls shuffle automatically"""
30+
31+
foo = IterableWithShuffle_(3)
32+
assert foo._seed is None
33+
iterable = embed_shuffle(foo)
34+
assert list(iterable) == [1, 2, 0]
35+
assert foo._seed == 0
36+
assert list(iterable) == [2, 0, 1]
37+
assert foo._seed == 1
38+
assert list(iterable) == [0, 1, 2]
39+
assert foo._seed == 2
40+
41+
42+
def test_embed_shuffle_halt():
43+
"""The value is shuffled with different seed even after an iteration is halted."""
44+
45+
foo = IterableWithShuffle_(5)
46+
iterable = embed_shuffle(foo)
47+
48+
iterator = iter(iterable)
49+
assert foo._seed is None
50+
assert next(iterator) == 1
51+
assert foo._seed == 0
52+
assert next(iterator) == 2
53+
del iterator
54+
55+
iterator = iter(iterable)
56+
assert next(iterator) == 2
57+
assert foo._seed == 1
58+
assert next(iterator) == 3
59+
del iterator
60+
61+
62+
def test_embed_shuffle_shuffle_after():
63+
"""Iterable created by embed_shuffle calls shuffle automatically after iteration"""
64+
65+
foo = IterableWithShuffle_(3)
66+
iterable = embed_shuffle(foo, shuffle_last=True)
67+
assert foo._seed is None
68+
assert list(iterable) == [0, 1, 2]
69+
assert foo._seed == 0
70+
assert list(iterable) == [1, 2, 0]
71+
assert foo._seed == 1
72+
assert list(iterable) == [2, 0, 1]
73+
assert foo._seed == 2
74+
75+
76+
def test_embed_shuffle_shuffle_after_halt():
77+
"""The value is shuffled with different seed even after an iteration is halted."""
78+
79+
foo = IterableWithShuffle_(5)
80+
iterable = embed_shuffle(foo, shuffle_last=True)
81+
82+
iterator = iter(iterable)
83+
assert next(iterator) == 0
84+
assert next(iterator) == 1
85+
assert foo._seed is None
86+
del iterator
87+
assert foo._seed == 0
88+
89+
iterator = iter(iterable)
90+
assert next(iterator) == 1
91+
assert next(iterator) == 2
92+
del iterator
93+
assert foo._seed == 1

0 commit comments

Comments
 (0)