1- import socket
21import ray
32from cloudpathlib import AnyPath
43
76from skyrl .tinker import types
87from skyrl .utils .log import logger
98
10- def get_free_port () -> int :
11- with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as s :
12- s .bind (("" , 0 ))
13- return s .getsockname ()[1 ]
14-
159
1610@ray .remote
17- class RayJaxActor :
18- """Ray Actor wrapper for JaxBackendImpl.
19-
20- Each actor runs JaxBackendImpl and communicates with other actors
21- via JAX distributed (NCCL) for data parallel operations.
11+ class RayJaxBackendImpl :
12+ """RayJaxBackendImpl is a Ray wrapper for JaxBackendImpl.
13+
14+ Each actor calls jax.distributed.initialize() and holds an instance of JaxBackendImpl.
2215 """
2316 def __init__ (self , base_model : str , config : JaxBackendConfig , process_id : int ):
2417 self .base_model = base_model
@@ -28,7 +21,7 @@ def __init__(self, base_model: str, config: JaxBackendConfig, process_id: int):
2821
2922 if process_id == 0 :
3023 self .node_ip = ray .util .get_node_ip_address ()
31- self .port = get_free_port ()
24+ self .port = 7777
3225 self .coordinator_address = f"{ self .node_ip } :{ self .port } "
3326 else :
3427 self .coordinator_address = None
@@ -38,7 +31,7 @@ def get_coordinator_address(self) -> str:
3831
3932 def setup (self , coordinator_address : str | None = None ):
4033 """Initializes JAX distributed and creates JaxBackendImpl."""
41- import jax # Import here to avoid issues or ensure it's loaded in the actor
34+ import jax
4235
4336 addr = coordinator_address or self .coordinator_address
4437 logger .info (f"Worker { self .process_id } initializing JAX distributed with coordinator { addr } " )
@@ -51,10 +44,6 @@ def setup(self, coordinator_address: str | None = None):
5144 self .backend = JaxBackendImpl (self .base_model , self .config , self .process_id )
5245 logger .info (f"Worker { self .process_id } JaxBackendImpl initialized." )
5346
54- # =========================================================================
55- # Proxied Backend Methods
56- # =========================================================================
57-
5847 def create_model (self , model_id : str , lora_config : types .LoraConfig ) -> None :
5948 self .backend .create_model (model_id , lora_config )
6049
@@ -90,61 +79,61 @@ def get_metrics(self) -> types.EngineMetrics:
9079
9180
9281class RayJaxBackend (AbstractBackend ):
93- """Proxy Backend that orchestrates Ray actors for multi-node JAX execution.
94-
95- This class runs in the driver program (Tinker Engine process ) and proxies
82+ """RayJaxBackend is a proxy Backend that orchestrates Ray actors for multi-node JAX execution.
83+
84+ This class runs in the driver program (along with Tinker API / Engine ) and proxies
9685 commands to all JAX workers running as Ray actors.
9786 """
9887 def __init__ (self , base_model : str , config : JaxBackendConfig ):
9988 self .base_model = base_model
10089 self .config = config .model_copy ()
101-
90+
10291 num_processes = self .config .num_processes or 1
10392 self .config .num_processes = num_processes
104-
93+
10594 logger .info (f"Initializing RayJaxBackend with num_processes={ num_processes } " )
10695
107- # Instantiate a Ray placement group
96+ # Initialize a Ray placement group based on ray_pg_bundles in JaxBackendConfig
10897 from ray .util .placement_group import placement_group
109- logger .info ("Instantiating Ray placement group for JAX workers... " )
98+ logger .info ("Creating Ray placement group for JAX backend " )
11099 bundles = self .config .ray_pg_bundles
111100 if not bundles :
112101 bundles = [{"CPU" : 1 }] * num_processes
113102 self .pg = placement_group (bundles , strategy = "SPREAD" )
114103 ray .get (self .pg .ready ())
115104
116105 self .workers = []
117-
118- # Create worker 0 (coordinator)
106+
107+ # node0 (coordinator)
108+ logger .info ("Scheduling Ray actor for node0 (JAX coordinator)" )
119109 w0_options = self .config .ray_actor_options .copy ()
120110 w0_options .update ({
121111 "placement_group" : self .pg ,
122112 "placement_group_bundle_index" : 0 ,
123113 })
124- w0 = RayJaxActor .options (** w0_options ).remote (self .base_model , self .config , 0 )
114+ w0 = RayJaxBackendImpl .options (** w0_options ).remote (self .base_model , self .config , 0 )
125115 self .workers .append (w0 )
126116
127- # Retrieve dynamically allocated coordinator address from actor 0
128117 coordinator_address = ray .get (w0 .get_coordinator_address .remote ())
129-
130- # Create other workers
118+
119+ # Create remaining node1 - nodeN for multi-node training.
120+ logger .info ("Scheduling remaining Ray actors (JAX workers)" )
131121 for i in range (1 , num_processes ):
132122 wi_options = self .config .ray_actor_options .copy ()
133123 wi_options .update ({
134124 "placement_group" : self .pg ,
135125 "placement_group_bundle_index" : i ,
136126 })
137- w = RayJaxActor .options (** wi_options ).remote (self .base_model , self .config , i )
127+ w = RayJaxBackendImpl .options (** wi_options ).remote (self .base_model , self .config , i )
138128 self .workers .append (w )
139129
140- # Trigger setup on all workers
141130 # This will block until JAX distributed is initialized on all workers
142131 setup_refs = [w0 .setup .remote ()]
143132 for w in self .workers [1 :]:
144133 setup_refs .append (w .setup .remote (coordinator_address ))
145134
146135 ray .get (setup_refs )
147- logger .info ("RayJaxBackend is fully initialized and distributed cluster is ready." )
136+ logger .info ("RayJaxBackend is fully initialized and distributed JAX cluster is ready." )
148137
149138 @property
150139 def metrics (self ) -> types .EngineMetrics :
0 commit comments