Skip to content

Commit f09742e

Browse files
committed
improve class and variable naming
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
1 parent 71a0ecc commit f09742e

2 files changed

Lines changed: 23 additions & 33 deletions

File tree

skyrl/backends/jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
113113
default=None,
114114
description="Total number of processes in the multi-node cluster",
115115
)
116+
# RayJaxBackend configuration
116117
ray_actor_options: dict = Field(
117118
default_factory=dict,
118119
description="Options to pass to Ray actors (e.g., resources, num_cpus)",

skyrl/backends/ray_jax.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import socket
21
import ray
32
from cloudpathlib import AnyPath
43

@@ -7,18 +6,12 @@
76
from skyrl.tinker import types
87
from skyrl.utils.log import logger
98

10-
def get_free_port() -> int:
11-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
12-
s.bind(("", 0))
13-
return s.getsockname()[1]
14-
159

1610
@ray.remote
17-
class RayJaxActor:
18-
"""Ray Actor wrapper for JaxBackendImpl.
19-
20-
Each actor runs JaxBackendImpl and communicates with other actors
21-
via JAX distributed (NCCL) for data parallel operations.
11+
class RayJaxBackendImpl:
12+
"""RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl.
13+
14+
Each actor calls jax.distributed.initialize() and holds an instance of JaxBackendImpl.
2215
"""
2316
def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
2417
self.base_model = base_model
@@ -28,7 +21,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
2821

2922
if process_id == 0:
3023
self.node_ip = ray.util.get_node_ip_address()
31-
self.port = get_free_port()
24+
self.port = 7777
3225
self.coordinator_address = f"{self.node_ip}:{self.port}"
3326
else:
3427
self.coordinator_address = None
@@ -38,7 +31,7 @@ def get_coordinator_address(self) -> str:
3831

3932
def setup(self, coordinator_address: str | None = None):
4033
"""Initializes JAX distributed and creates JaxBackendImpl."""
41-
import jax # Import here to avoid issues or ensure it's loaded in the actor
34+
import jax
4235

4336
addr = coordinator_address or self.coordinator_address
4437
logger.info(f"Worker {self.process_id} initializing JAX distributed with coordinator {addr}")
@@ -51,10 +44,6 @@ def setup(self, coordinator_address: str | None = None):
5144
self.backend = JaxBackendImpl(self.base_model, self.config, self.process_id)
5245
logger.info(f"Worker {self.process_id} JaxBackendImpl initialized.")
5346

54-
# =========================================================================
55-
# Proxied Backend Methods
56-
# =========================================================================
57-
5847
def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
5948
self.backend.create_model(model_id, lora_config)
6049

@@ -90,61 +79,61 @@ def get_metrics(self) -> types.EngineMetrics:
9079

9180

9281
class RayJaxBackend(AbstractBackend):
93-
"""Proxy Backend that orchestrates Ray actors for multi-node JAX execution.
94-
95-
This class runs in the driver program (Tinker Engine process) and proxies
82+
"""RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution.
83+
84+
This class runs in the driver program (along with Tinker API / Engine) and proxies
9685
commands to all JAX workers running as Ray actors.
9786
"""
9887
def __init__(self, base_model: str, config: JaxBackendConfig):
9988
self.base_model = base_model
10089
self.config = config.model_copy()
101-
90+
10291
num_processes = self.config.num_processes or 1
10392
self.config.num_processes = num_processes
104-
93+
10594
logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}")
10695

107-
# Instantiate a Ray placement group
96+
# Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig
10897
from ray.util.placement_group import placement_group
109-
logger.info("Instantiating Ray placement group for JAX workers...")
98+
logger.info("Creating Ray placement group for JAX backend")
11099
bundles = self.config.ray_pg_bundles
111100
if not bundles:
112101
bundles = [{"CPU": 1}] * num_processes
113102
self.pg = placement_group(bundles, strategy="SPREAD")
114103
ray.get(self.pg.ready())
115104

116105
self.workers = []
117-
118-
# Create worker 0 (coordinator)
106+
107+
# node0 (coordinator)
108+
logger.info("Scheduling Ray actor for node0 (JAX coordinator)")
119109
w0_options = self.config.ray_actor_options.copy()
120110
w0_options.update({
121111
"placement_group": self.pg,
122112
"placement_group_bundle_index": 0,
123113
})
124-
w0 = RayJaxActor.options(**w0_options).remote(self.base_model, self.config, 0)
114+
w0 = RayJaxBackendImpl.options(**w0_options).remote(self.base_model, self.config, 0)
125115
self.workers.append(w0)
126116

127-
# Retrieve dynamically allocated coordinator address from actor 0
128117
coordinator_address = ray.get(w0.get_coordinator_address.remote())
129-
130-
# Create other workers
118+
119+
# Create remaining node1 - nodeN for multi-node training.
120+
logger.info("Scheduling remaining Ray actors (JAX workers)")
131121
for i in range(1, num_processes):
132122
wi_options = self.config.ray_actor_options.copy()
133123
wi_options.update({
134124
"placement_group": self.pg,
135125
"placement_group_bundle_index": i,
136126
})
137-
w = RayJaxActor.options(**wi_options).remote(self.base_model, self.config, i)
127+
w = RayJaxBackendImpl.options(**wi_options).remote(self.base_model, self.config, i)
138128
self.workers.append(w)
139129

140-
# Trigger setup on all workers
141130
# This will block until JAX distributed is initialized on all workers
142131
setup_refs = [w0.setup.remote()]
143132
for w in self.workers[1:]:
144133
setup_refs.append(w.setup.remote(coordinator_address))
145134

146135
ray.get(setup_refs)
147-
logger.info("RayJaxBackend is fully initialized and distributed cluster is ready.")
136+
logger.info("RayJaxBackend is fully initialized and distributed JAX cluster is ready.")
148137

149138
@property
150139
def metrics(self) -> types.EngineMetrics:

0 commit comments

Comments
 (0)