Skip to content

Commit 7735f26

Browse files
committed
improve class and variable naming
Signed-off-by: Andrew Sy Kim <andrewsy@google.com>
1 parent 71a0ecc commit 7735f26

2 files changed

Lines changed: 17 additions & 26 deletions

File tree

skyrl/backends/jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class JaxBackendConfig(BaseModel, extra="forbid"):
113113
default=None,
114114
description="Total number of processes in the multi-node cluster",
115115
)
116+
# RayJaxBackend configuration
116117
ray_actor_options: dict = Field(
117118
default_factory=dict,
118119
description="Options to pass to Ray actors (e.g., resources, num_cpus)",

skyrl/backends/ray_jax.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,12 @@
77
from skyrl.tinker import types
88
from skyrl.utils.log import logger
99

10-
def get_free_port() -> int:
11-
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
12-
s.bind(("", 0))
13-
return s.getsockname()[1]
14-
1510

1611
@ray.remote
17-
class RayJaxActor:
18-
"""Ray Actor wrapper for JaxBackendImpl.
12+
class RayJaxBackendImpl:
13+
"""RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl.
1914
20-
Each actor runs JaxBackendImpl and communicates with other actors
21-
via JAX distributed (NCCL) for data parallel operations.
15+
Each actor calls jax.distributed.initialize() and holds an instance of JaxBackendImpl.
2216
"""
2317
def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
2418
self.base_model = base_model
@@ -28,7 +22,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
2822

2923
if process_id == 0:
3024
self.node_ip = ray.util.get_node_ip_address()
31-
self.port = get_free_port()
25+
self.port = 7777
3226
self.coordinator_address = f"{self.node_ip}:{self.port}"
3327
else:
3428
self.coordinator_address = None
@@ -38,7 +32,7 @@ def get_coordinator_address(self) -> str:
3832

3933
def setup(self, coordinator_address: str | None = None):
4034
"""Initializes JAX distributed and creates JaxBackendImpl."""
41-
import jax # Import here to avoid issues or ensure it's loaded in the actor
35+
import jax
4236

4337
addr = coordinator_address or self.coordinator_address
4438
logger.info(f"Worker {self.process_id} initializing JAX distributed with coordinator {addr}")
@@ -51,10 +45,6 @@ def setup(self, coordinator_address: str | None = None):
5145
self.backend = JaxBackendImpl(self.base_model, self.config, self.process_id)
5246
logger.info(f"Worker {self.process_id} JaxBackendImpl initialized.")
5347

54-
# =========================================================================
55-
# Proxied Backend Methods
56-
# =========================================================================
57-
5848
def create_model(self, model_id: str, lora_config: types.LoraConfig) -> None:
5949
self.backend.create_model(model_id, lora_config)
6050

@@ -90,9 +80,9 @@ def get_metrics(self) -> types.EngineMetrics:
9080

9181

9282
class RayJaxBackend(AbstractBackend):
93-
"""Proxy Backend that orchestrates Ray actors for multi-node JAX execution.
83+
"""RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution.
9484
95-
This class runs in the driver program (Tinker Engine process) and proxies
85+
This class runs in the driver program (along with Tinker API / Engine) and proxies
9686
commands to all JAX workers running as Ray actors.
9787
"""
9888
def __init__(self, base_model: str, config: JaxBackendConfig):
@@ -104,9 +94,9 @@ def __init__(self, base_model: str, config: JaxBackendConfig):
10494

10595
logger.info(f"Initializing RayJaxBackend with num_processes={num_processes}")
10696

107-
# Instantiate a Ray placement group
97+
# Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig
10898
from ray.util.placement_group import placement_group
109-
logger.info("Instantiating Ray placement group for JAX workers...")
99+
logger.info("Creating Ray placement group for JAX backend")
110100
bundles = self.config.ray_pg_bundles
111101
if not bundles:
112102
bundles = [{"CPU": 1}] * num_processes
@@ -115,36 +105,36 @@ def __init__(self, base_model: str, config: JaxBackendConfig):
115105

116106
self.workers = []
117107

118-
# Create worker 0 (coordinator)
108+
# node0 (coordinator)
109+
logger.info("Scheduling Ray actor for node0 (JAX coordinator)")
119110
w0_options = self.config.ray_actor_options.copy()
120111
w0_options.update({
121112
"placement_group": self.pg,
122113
"placement_group_bundle_index": 0,
123114
})
124-
w0 = RayJaxActor.options(**w0_options).remote(self.base_model, self.config, 0)
115+
w0 = RayJaxBackendImpl.options(**w0_options).remote(self.base_model, self.config, 0)
125116
self.workers.append(w0)
126117

127-
# Retrieve dynamically allocated coordinator address from actor 0
128118
coordinator_address = ray.get(w0.get_coordinator_address.remote())
129119

130-
# Create other workers
120+
# Create remaining node1 - nodeN for multi-node training.
121+
logger.info("Scheduling remaining Ray actors (JAX workers)")
131122
for i in range(1, num_processes):
132123
wi_options = self.config.ray_actor_options.copy()
133124
wi_options.update({
134125
"placement_group": self.pg,
135126
"placement_group_bundle_index": i,
136127
})
137-
w = RayJaxActor.options(**wi_options).remote(self.base_model, self.config, i)
128+
w = RayJaxBackendImpl.options(**wi_options).remote(self.base_model, self.config, i)
138129
self.workers.append(w)
139130

140-
# Trigger setup on all workers
141131
# This will block until JAX distributed is initialized on all workers
142132
setup_refs = [w0.setup.remote()]
143133
for w in self.workers[1:]:
144134
setup_refs.append(w.setup.remote(coordinator_address))
145135

146136
ray.get(setup_refs)
147-
logger.info("RayJaxBackend is fully initialized and distributed cluster is ready.")
137+
logger.info("RayJaxBackend is fully initialized and distributed JAX cluster is ready.")
148138

149139
@property
150140
def metrics(self) -> types.EngineMetrics:

0 commit comments

Comments
 (0)