Improve replay buffer incremental benchmark#447
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements incremental object retrieval for the replay-buffer example, optimizing performance by applying only newly seen patch batches to a service-local materialized buffer. Changes include updates to the design documentation, CLI flags for benchmark evaluation, and the introduction of a stateful deserializer in the ReplayBuffer class. Review feedback correctly identifies that the new stateful implementation lacks synchronization, which could lead to race conditions in concurrent environments; adding a lock to protect the internal state during materialization is recommended.
| self._materialized_base = None | ||
| self._materialized_data: dict[str, Any] | None = None | ||
| self._materialized_patch_count = 0 |
There was a problem hiding this comment.
The ReplayBuffer is now stateful to support incremental materialization, but it lacks synchronization. If this buffer is used within a service that handles concurrent requests (which is common in Flame services), multiple threads could execute _deserializer simultaneously, leading to race conditions and data corruption in the materialized buffer. I recommend adding a lock to protect the internal state.
| self._materialized_base = None | |
| self._materialized_data: dict[str, Any] | None = None | |
| self._materialized_patch_count = 0 | |
| import threading | |
| self._lock = threading.Lock() | |
| self._materialized_base = None | |
| self._materialized_data: dict[str, Any] | None = None | |
| self._materialized_patch_count = 0 |
| def _deserializer(self, base: dict, deltas: List) -> dict: | ||
| transitions = list(base.get("transitions", [])) | ||
| for delta in deltas: | ||
| transitions.extend(delta) | ||
| return { | ||
| "transitions": transitions, | ||
| "total_added": base.get("total_added", 0) + sum(len(d) for d in deltas), | ||
| } | ||
| if ( | ||
| self._materialized_data is None | ||
| or self._materialized_base is not base | ||
| or self._materialized_patch_count > len(deltas) | ||
| ): | ||
| self._materialized_base = base | ||
| self._materialized_data = { | ||
| "transitions": list(base.get("transitions", [])), | ||
| "total_added": base.get("total_added", 0), | ||
| } | ||
| self._materialized_patch_count = 0 | ||
|
|
||
| for delta in deltas[self._materialized_patch_count :]: | ||
| self._materialized_data["transitions"].extend(delta) | ||
| self._materialized_data["total_added"] += len(delta) | ||
|
|
||
| self._materialized_patch_count = len(deltas) | ||
| return self._materialized_data |
There was a problem hiding this comment.
To ensure thread safety when updating the materialized buffer state, wrap the logic in a lock. This prevents multiple threads from concurrently extending the transitions list or updating the patch count, which could otherwise lead to duplicate data or inconsistent state.
| def _deserializer(self, base: dict, deltas: List) -> dict: | |
| transitions = list(base.get("transitions", [])) | |
| for delta in deltas: | |
| transitions.extend(delta) | |
| return { | |
| "transitions": transitions, | |
| "total_added": base.get("total_added", 0) + sum(len(d) for d in deltas), | |
| } | |
| if ( | |
| self._materialized_data is None | |
| or self._materialized_base is not base | |
| or self._materialized_patch_count > len(deltas) | |
| ): | |
| self._materialized_base = base | |
| self._materialized_data = { | |
| "transitions": list(base.get("transitions", [])), | |
| "total_added": base.get("total_added", 0), | |
| } | |
| self._materialized_patch_count = 0 | |
| for delta in deltas[self._materialized_patch_count :]: | |
| self._materialized_data["transitions"].extend(delta) | |
| self._materialized_data["total_added"] += len(delta) | |
| self._materialized_patch_count = len(deltas) | |
| return self._materialized_data | |
| def _deserializer(self, base: dict, deltas: List) -> dict: | |
| with self._lock: | |
| if ( | |
| self._materialized_data is None | |
| or self._materialized_base is not base | |
| or self._materialized_patch_count > len(deltas) | |
| ): | |
| self._materialized_base = base | |
| self._materialized_data = { | |
| "transitions": list(base.get("transitions", [])), | |
| "total_added": base.get("total_added", 0), | |
| } | |
| self._materialized_patch_count = 0 | |
| for delta in deltas[self._materialized_patch_count :]: | |
| self._materialized_data["transitions"].extend(delta) | |
| self._materialized_data["total_added"] += len(delta) | |
| self._materialized_patch_count = len(deltas) | |
| return self._materialized_data |
980ee6e to
563991f
Compare
563991f to
6e2ce65
Compare
Summary
Validation
cd examples/rl/replay_buffer && PYTHONPATH=../../../sdk/python/src ../../../sdk/python/.venv/bin/python -m pytest test_replay_buffer.py -qcd examples/rl/replay_buffer && uv run -n ruff check main.py replay_buffer.py test_replay_buffer.pycd examples/rl/replay_buffer && uv run -n ruff format --check main.py replay_buffer.py test_replay_buffer.pycd examples/rl/replay_buffer && uv run -n main.py --helpgit diff --checkNotes
tasks/remains local task-tracking and is not included in this PR.