diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 00000000..d582505a --- /dev/null +++ b/codecov.yml @@ -0,0 +1,8 @@ +coverage: + status: + project: + default: + target: auto + threshold: 1% + patch: + default: off diff --git a/examples/rl/replay_buffer/README.md b/examples/rl/replay_buffer/README.md new file mode 100644 index 00000000..6fbb38bf --- /dev/null +++ b/examples/rl/replay_buffer/README.md @@ -0,0 +1,132 @@ +# Distributed Replay Buffer with patch_object + +`ReplayBuffer` wraps an `ObjectRef` - collectors call `buffer.push()` which uses `patch_object` for efficient writes, while the buffer service handles state queries and sampling. + +## Key Pattern: Shared Object + patch_object + +```python +with Runner("replay-buffer") as rr: + # Create ReplayBuffer (creates ObjectRef internally via runner) + buffer = ReplayBuffer(rr) + + # Wrap as service for state/sample operations + buffer_svc = rr.service(buffer, autoscale=False) + collector = rr.service(Collector(env_name), autoscale=True) + + # Pass the SAME buffer object to collectors (pickled with its ObjectRef) + collect_futures = [collector.collect(buffer, num_steps) for _ in range(num_collections)] + rr.get(collect_futures) + + # Query state and sample from service + stats = buffer_svc.state().get() + batch = buffer_svc.sample(batch_size).get() +``` + +## Why This Pattern? + +| Approach | Data Flow | Network Hops | +|----------|-----------|--------------| +| Service only | Collector → Service → Buffer | 2 per worker | +| **patch_object** | Collector → ObjectRef (cache) | 1 per worker | + +- `ReplayBuffer` holds an `ObjectRef` pointing to shared data in Flame cache +- When pickled as parameter, collectors get the same `ObjectRef` +- `buffer.push()` calls `patch_object` - writes directly to cache +- `get_object` with deserializer consolidates patches when reading + +## Usage + +### Distributed Mode + +```shell +docker compose exec -it flame-console /bin/bash +cd /opt/examples/rl/replay_buffer +uv run main.py +``` + +### Local Mode + +```shell +uv run main.py --local +``` + +### Options + +| Flag | Description | Default | +|------|-------------|---------| +| `--env` | Gymnasium environment | CartPole-v1 | +| `--local` | Run without Flame cluster | Off | +| `--iterations` | Collection iterations | 10 | +| `--collections` | Collections per iteration | 4 | +| `--steps-per-collection` | Steps per collection task | 100 | +| `--batch-size` | Sample batch size | 32 | + +## Example Output + +``` +============================================================ +Distributed Replay Buffer (patch_object) +============================================================ + +Configuration: + Environment: CartPole-v1 + Collections per iteration: 4 + Steps per collection: 100 + Iterations: 10 + Batch size: 32 + +Starting distributed collection... +Iteration 0 | Buffer: 400 | Total added: 400 | Avg Reward: 22.5 + | Sampled batch of 32 transitions +Iteration 1 | Buffer: 800 | Total added: 800 | Avg Reward: 21.8 + | Sampled batch of 32 transitions +... + +============================================================ +Collection Complete! + Total time: 2.45s + Total transitions: 4000 + Throughput: 1632.7 transitions/sec +============================================================ +``` + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Main (Learner) │ +│ │ +│ buffer_svc.state() buffer_svc.sample(batch) │ +│ │ │ │ +└───────┼─────────────────────────────────┼───────────────────┘ + │ │ + ▼ ▼ +┌─────────────────────────────────────────────────────────────┐ +│ ReplayBuffer Service │ +│ │ +│ Wraps ObjectRef - handles push/state/sample │ +│ │ +│ push(transitions) - patch_object to ObjectRef │ +│ state() - get buffer stats (size, total_added) │ +│ sample(batch_size) - random sample from buffer │ +│ │ +└──────────────────────────┬──────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ ObjectRef (Flame Cache) │ +│ │ +│ {"transitions": [...], "total_added": N} │ +│ │ +└──────────────────────────┬──────────────────────────────────┘ + ▲ + ┌──────────────────┼──────────────────┐ + │ │ │ + │ buffer.push() │ buffer.push() │ buffer.push() + │ │ │ +┌───────┴─────┐ ┌───────┴─────┐ ┌───────┴─────┐ +│ Collector 1 │ │ Collector 2 │ │ Collector N │ +│ (env) │ │ (env) │ │ (env) │ +└─────────────┘ └─────────────┘ └─────────────┘ + Flame Cluster +``` diff --git a/examples/rl/replay_buffer/collector.py b/examples/rl/replay_buffer/collector.py new file mode 100644 index 00000000..b7f1b4d1 --- /dev/null +++ b/examples/rl/replay_buffer/collector.py @@ -0,0 +1,47 @@ +from replay_buffer import ReplayBuffer + + +class Collector: + def __init__(self, env_name: str): + import gymnasium as gym + + self.env = gym.make(env_name) + self.state, _ = self.env.reset() + self.episode_reward = 0.0 + self.episode_count = 0 + self.total_reward = 0.0 + + def collect(self, buffer: ReplayBuffer, num_steps: int) -> dict: + transitions = [] + + for _ in range(num_steps): + action = self.env.action_space.sample() + next_state, reward, terminated, truncated, _ = self.env.step(action) + done = terminated or truncated + self.episode_reward += reward + + transitions.append( + { + "state": self.state.tolist(), + "action": int(action), + "reward": float(reward), + "next_state": next_state.tolist(), + "done": done, + } + ) + + if done: + self.state, _ = self.env.reset() + self.total_reward += self.episode_reward + self.episode_count += 1 + self.episode_reward = 0.0 + else: + self.state = next_state + + buffer.push(transitions) + + return { + "num_transitions": len(transitions), + "episode_count": self.episode_count, + "avg_reward": self.total_reward / max(1, self.episode_count), + } diff --git a/examples/rl/replay_buffer/main.py b/examples/rl/replay_buffer/main.py new file mode 100644 index 00000000..86a24a71 --- /dev/null +++ b/examples/rl/replay_buffer/main.py @@ -0,0 +1,211 @@ +""" +Distributed Replay Buffer using patch_object for efficient data movement. + +ReplayBuffer wraps an ObjectRef - collectors patch transitions directly to it, +avoiding data transfer through the service. The service handles merging and sampling. + +Use --local flag for local mode without a Flame cluster. +""" + +from collector import Collector +from replay_buffer import ReplayBuffer + + +def run_distributed( + env_name: str = "CartPole-v1", + num_iterations: int = 50, + num_collections: int = 20, + steps_per_collection: int = 500, + batch_size: int = 64, +): + import time + + from flamepy.runner import Runner + + print("=" * 60) + print("Distributed Replay Buffer (patch_object)") + print("=" * 60) + print("\nConfiguration:") + print(f" Environment: {env_name}") + print(f" Collections per iteration: {num_collections}") + print(f" Steps per collection: {steps_per_collection}") + print(f" Iterations: {num_iterations}") + print(f" Batch size: {batch_size}") + print("\nStarting distributed collection...") + + start_time = time.time() + + with Runner(f"replay-buffer-{env_name.lower()}") as rr: + buffer = ReplayBuffer(rr) + buffer_svc = rr.service(buffer, autoscale=False, warmup=1) + collector = rr.service(Collector(env_name), autoscale=True) + + for iteration in range(num_iterations): + collect_futures = [ + collector.collect(buffer, steps_per_collection) + for _ in range(num_collections) + ] + collect_results = rr.get(collect_futures) + + if iteration % 5 == 4: + buffer_svc.merge().wait() + + stats = buffer_svc.state().get() + total_size = stats["size"] + total_added = stats["total_added"] + total_episodes = sum(r["episode_count"] for r in collect_results) + avg_reward = sum(r["avg_reward"] * r["episode_count"] for r in collect_results) / max(1, total_episodes) + + print( + f"Iteration {iteration:2d} | " + f"Buffer: {total_size:6d} | " + f"Total added: {total_added:6d} | " + f"Avg Reward: {avg_reward:7.1f}" + ) + + if total_size >= batch_size: + batch = buffer_svc.sample(batch_size).get() + print(f" | Sampled batch of {len(batch)} transitions") + + elapsed = time.time() - start_time + print("\n" + "=" * 60) + print("Collection Complete!") + print(f" Total time: {elapsed:.2f}s") + print(f" Total transitions: {total_added}") + print(f" Throughput: {total_added / elapsed:.1f} transitions/sec") + print("=" * 60) + + +def run_local( + env_name: str = "CartPole-v1", + num_iterations: int = 50, + steps_per_iteration: int = 2000, + batch_size: int = 64, +): + import random + import time + + import gymnasium as gym + from collections import deque + + print("=" * 60) + print("Local Replay Buffer") + print("=" * 60) + print("\nConfiguration:") + print(f" Environment: {env_name}") + print(f" Steps per iteration: {steps_per_iteration}") + print(f" Iterations: {num_iterations}") + print(f" Batch size: {batch_size}") + print("\nStarting local collection...") + + start_time = time.time() + env = gym.make(env_name) + buffer: deque = deque(maxlen=100000) + state, _ = env.reset() + episode_reward = 0.0 + episode_count = 0 + total_reward = 0.0 + total_added = 0 + + for iteration in range(num_iterations): + for _ in range(steps_per_iteration): + action = env.action_space.sample() + next_state, reward, terminated, truncated, _ = env.step(action) + done = terminated or truncated + episode_reward += reward + + buffer.append( + { + "state": state.tolist(), + "action": int(action), + "reward": float(reward), + "next_state": next_state.tolist(), + "done": done, + } + ) + total_added += 1 + + if done: + state, _ = env.reset() + total_reward += episode_reward + episode_count += 1 + episode_reward = 0.0 + else: + state = next_state + + avg_reward = total_reward / max(1, episode_count) + print( + f"Iteration {iteration:2d} | " + f"Buffer: {len(buffer):6d} | " + f"Total added: {total_added:6d} | " + f"Avg Reward: {avg_reward:7.1f}" + ) + + if len(buffer) >= batch_size: + batch = random.sample(list(buffer), batch_size) + print(f" | Sampled batch of {len(batch)} transitions") + + env.close() + + elapsed = time.time() - start_time + print("\n" + "=" * 60) + print("Collection Complete!") + print(f" Total time: {elapsed:.2f}s") + print(f" Total transitions: {total_added}") + print(f" Throughput: {total_added / elapsed:.1f} transitions/sec") + print("=" * 60) + + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Distributed Replay Buffer Service") + parser.add_argument( + "--env", + type=str, + default="CartPole-v1", + help="Gymnasium environment (default: CartPole-v1)", + ) + parser.add_argument( + "--local", action="store_true", help="Run local mode (no Flame cluster)" + ) + parser.add_argument( + "--iterations", type=int, default=50, help="Number of collection iterations" + ) + parser.add_argument( + "--collections", + type=int, + default=20, + help="Number of collections per iteration", + ) + parser.add_argument( + "--steps-per-collection", + type=int, + default=500, + help="Steps per collection task", + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size for sampling" + ) + + args = parser.parse_args() + + if args.local: + run_local( + env_name=args.env, + num_iterations=args.iterations, + steps_per_iteration=args.collections * args.steps_per_collection, + batch_size=args.batch_size, + ) + else: + run_distributed( + env_name=args.env, + num_iterations=args.iterations, + num_collections=args.collections, + steps_per_collection=args.steps_per_collection, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/rl/replay_buffer/pyproject.toml b/examples/rl/replay_buffer/pyproject.toml new file mode 100644 index 00000000..48a98905 --- /dev/null +++ b/examples/rl/replay_buffer/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "replay-buffer" +version = "0.1.0" +description = "Distributed Replay Buffer Example using Flame's patch_object API" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "numpy", + "gymnasium", +] + +[dependency-groups] +dev = ["flamepy"] + +[tool.setuptools] +py-modules = ["main", "replay_buffer", "collector"] diff --git a/examples/rl/replay_buffer/replay_buffer.py b/examples/rl/replay_buffer/replay_buffer.py new file mode 100644 index 00000000..ebecb46b --- /dev/null +++ b/examples/rl/replay_buffer/replay_buffer.py @@ -0,0 +1,49 @@ +import random +from typing import Any, List, TYPE_CHECKING + +if TYPE_CHECKING: + from flamepy.runner import Runner + + +class ReplayBuffer: + def __init__(self, rr: "Runner"): + from flamepy.core import get_object, patch_object, update_object + + self.buffer_ref = rr.put_object({"transitions": [], "total_added": 0}) + self._get_object = get_object + self._update_object = update_object + self._patch_object = patch_object + + def _deserializer(self, base: dict, deltas: List) -> dict: + transitions = list(base.get("transitions", [])) + for delta in deltas: + transitions.extend(delta) + return { + "transitions": transitions, + "total_added": base.get("total_added", 0) + sum(len(d) for d in deltas), + } + + def _fetch(self) -> dict: + return self._get_object(self.buffer_ref, deserializer=self._deserializer) + + def push(self, transitions: List[dict]) -> None: + self._patch_object(self.buffer_ref, transitions) + + def merge(self) -> None: + data = self._fetch() + self.buffer_ref = self._update_object(self.buffer_ref, data) + + def get(self) -> dict: + return self._fetch() + + def state(self) -> dict[str, Any]: + data = self._fetch() + return { + "size": len(data.get("transitions", [])), + "total_added": data.get("total_added", 0), + } + + def sample(self, batch_size: int) -> List[dict]: + data = self._fetch() + items = data.get("transitions", []) + return random.sample(items, min(batch_size, len(items))) diff --git a/object_cache/src/cache.rs b/object_cache/src/cache.rs index ffcad615..2da63afa 100644 --- a/object_cache/src/cache.rs +++ b/object_cache/src/cache.rs @@ -1285,7 +1285,11 @@ pub async fn run(cache_config: &FlameCache) -> Result<(), FlameError> { } builder - .add_service(FlightServiceServer::new(server)) + .add_service( + FlightServiceServer::new(server) + .max_decoding_message_size(usize::MAX) + .max_encoding_message_size(usize::MAX), + ) .serve(addr) .await .map_err(|e| FlameError::Internal(format!("Server error: {}", e)))?; diff --git a/sdk/python/src/flamepy/core/cache.py b/sdk/python/src/flamepy/core/cache.py index b0c00f46..39a66cb0 100644 --- a/sdk/python/src/flamepy/core/cache.py +++ b/sdk/python/src/flamepy/core/cache.py @@ -265,16 +265,22 @@ def _normalize_endpoint(endpoint: str) -> str: return endpoint +GRPC_OPTIONS = [ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), +] + + def _create_flight_client(location: str, tls_config: Optional[FlameClientTls] = None) -> flight.FlightClient: if location.startswith("grpc+tls://"): if tls_config and tls_config.ca_file: with open(tls_config.ca_file, "rb") as f: root_certs = f.read() - return flight.FlightClient(location, tls_root_certs=root_certs) + return flight.FlightClient(location, tls_root_certs=root_certs, generic_options=GRPC_OPTIONS) else: - return flight.FlightClient(location) + return flight.FlightClient(location, generic_options=GRPC_OPTIONS) else: - return flight.FlightClient(location) + return flight.FlightClient(location, generic_options=GRPC_OPTIONS) def _remove_stale_client(location: str) -> None: