77from skyrl .tinker import types
88from skyrl .utils .log import logger
99
10- def get_free_port () -> int :
11- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
12- s .bind (("" , 0 ))
13- return s .getsockname ()[1 ]
14-
1510
1611@ray .remote
17- class RayJaxActor :
18- """Ray Actor wrapper for JaxBackendImpl.
12+ class RayJaxBackendImpl :
13+ """RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl.
1914
20- Each actor runs JaxBackendImpl and communicates with other actors
21- via JAX distributed (NCCL) for data parallel operations.
15+ Each actor calls jax.distributed.initialize() and holds an instance of JaxBackendImpl.
2216 """
2317 def __init__ (self , base_model : str , config : JaxBackendConfig , process_id : int ):
2418 self .base_model = base_model
@@ -28,7 +22,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
2822
2923 if process_id == 0 :
3024 self .node_ip = ray .util .get_node_ip_address ()
31- self .port = get_free_port ()
25+ self .port = 7777
3226 self .coordinator_address = f"{ self .node_ip } :{ self .port } "
3327 else :
3428 self .coordinator_address = None
@@ -38,7 +32,7 @@ def get_coordinator_address(self) -> str:
3832
3933 def setup (self , coordinator_address : str | None = None ):
4034 """Initializes JAX distributed and creates JaxBackendImpl."""
41- import jax # Import here to avoid issues or ensure it's loaded in the actor
35+ import jax
4236
4337 addr = coordinator_address or self .coordinator_address
4438 logger .info (f"Worker { self .process_id } initializing JAX distributed with coordinator { addr } " )
@@ -51,10 +45,6 @@ def setup(self, coordinator_address: str | None = None):
5145 self .backend = JaxBackendImpl (self .base_model , self .config , self .process_id )
5246 logger .info (f"Worker { self .process_id } JaxBackendImpl initialized." )
5347
54- # =========================================================================
55- # Proxied Backend Methods
56- # =========================================================================
57-
5848 def create_model (self , model_id : str , lora_config : types .LoraConfig ) -> None :
5949 self .backend .create_model (model_id , lora_config )
6050
@@ -90,9 +80,9 @@ def get_metrics(self) -> types.EngineMetrics:
9080
9181
9282class RayJaxBackend (AbstractBackend ):
93- """Proxy Backend that orchestrates Ray actors for multi-node JAX execution.
83+ """RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution.
9484
95- This class runs in the driver program (Tinker Engine process ) and proxies
85+ This class runs in the driver program (along with Tinker API / Engine ) and proxies
9686 commands to all JAX workers running as Ray actors.
9787 """
9888 def __init__ (self , base_model : str , config : JaxBackendConfig ):
@@ -104,9 +94,9 @@ def __init__(self, base_model: str, config: JaxBackendConfig):
10494
10595 logger .info (f"Initializing RayJaxBackend with num_processes={ num_processes } " )
10696
107- # Instantiate a Ray placement group
97+ # Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig
10898 from ray .util .placement_group import placement_group
109- logger .info ("Instantiating Ray placement group for JAX workers... " )
99+ logger .info ("Creating Ray placement group for JAX backend " )
110100 bundles = self .config .ray_pg_bundles
111101 if not bundles :
112102 bundles = [{"CPU" : 1 }] * num_processes
@@ -115,36 +105,36 @@ def __init__(self, base_model: str, config: JaxBackendConfig):
115105
116106 self .workers = []
117107
118- # Create worker 0 (coordinator)
108+ # node0 (coordinator)
109+ logger .info ("Scheduling Ray actor for node0 (JAX coordinator)" )
119110 w0_options = self .config .ray_actor_options .copy ()
120111 w0_options .update ({
121112 "placement_group" : self .pg ,
122113 "placement_group_bundle_index" : 0 ,
123114 })
124- w0 = RayJaxActor .options (** w0_options ).remote (self .base_model , self .config , 0 )
115+ w0 = RayJaxBackendImpl .options (** w0_options ).remote (self .base_model , self .config , 0 )
125116 self .workers .append (w0 )
126117
127- # Retrieve dynamically allocated coordinator address from actor 0
128118 coordinator_address = ray .get (w0 .get_coordinator_address .remote ())
129119
130- # Create other workers
120+ # Create remaining node1 - nodeN for multi-node training.
121+ logger .info ("Scheduling remaining Ray actors (JAX workers)" )
131122 for i in range (1 , num_processes ):
132123 wi_options = self .config .ray_actor_options .copy ()
133124 wi_options .update ({
134125 "placement_group" : self .pg ,
135126 "placement_group_bundle_index" : i ,
136127 })
137- w = RayJaxActor .options (** wi_options ).remote (self .base_model , self .config , i )
128+ w = RayJaxBackendImpl .options (** wi_options ).remote (self .base_model , self .config , i )
138129 self .workers .append (w )
139130
140- # Trigger setup on all workers
141131 # This will block until JAX distributed is initialized on all workers
142132 setup_refs = [w0 .setup .remote ()]
143133 for w in self .workers [1 :]:
144134 setup_refs .append (w .setup .remote (coordinator_address ))
145135
146136 ray .get (setup_refs )
147- logger .info ("RayJaxBackend is fully initialized and distributed cluster is ready." )
137+ logger .info ("RayJaxBackend is fully initialized and distributed JAX cluster is ready." )
148138
149139 @property
150140 def metrics (self ) -> types .EngineMetrics :
0 commit comments