From 95e31d0ca4cb8ff2a180b9603bd17ab85a95ca57 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 20 Mar 2026 20:32:56 +0000 Subject: [PATCH 01/11] x Signed-off-by: ahao-anyscale --- pyproject.toml | 5 +- .../workers/megatron/megatron_worker.py | 53 ++++++++++++++++++- .../skyrl_train/gpu/gpu_ci/test_lora.py | 40 ++++++++++---- .../gpu/gpu_ci/test_megatron_worker.py | 1 - 4 files changed, 86 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4012edbc74..6b4a834dd1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -209,6 +209,7 @@ override-dependencies = [ "transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'", "megatron-core==0.16.0; sys_platform == 'linux'", "ml_dtypes>=0.5.0; sys_platform == 'linux'", + "transformers>=4.56.1,<5; sys_platform == 'linux'", ] [tool.uv.extra-build-dependencies] @@ -251,8 +252,8 @@ torchvision = [ { index = "pytorch-cu128", marker = "sys_platform == 'linux'" }, { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" }, ] -# pin megatron bridge commit to fix for MoE + LoRA merging. Update this when an official release is cut -megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "02b5fccab5e5b21856d36c2e357839e0123b4b8f", marker = "sys_platform == 'linux'"} +# pin megatron bridge commit for LoRA adapter export support. Update this when an official release is cut +megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "f78c65f9", marker = "sys_platform == 'linux'"} harbor = { git = "https://github.com/laude-institute/harbor", rev = "8c040e1bb010201fd3c75bee3dede2407b9f57cd" } [tool.black] diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 636c518154..7c8ff42759 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -35,7 +35,11 @@ TrainingOutputBatch, ) from skyrl.backends.skyrl_train.utils.profiler import Profiler -from skyrl.backends.skyrl_train.weight_sync import WeightChunk, WeightExtractor +from skyrl.backends.skyrl_train.weight_sync import ( + LoraLoadRequest, + WeightChunk, + WeightExtractor, +) from skyrl.backends.skyrl_train.workers.megatron.megatron_model_wrapper import ( MegatronModelWrapper, ) @@ -802,6 +806,48 @@ async def init_weight_sync_state(self, inference_engine_client, inference_engine training_dtype=torch.bfloat16 if self.cfg.bf16 else torch.float32, ) + async def _save_lora_adapters_and_sync(self, lora_sync_path, inference_engine_client): + """Export LoRA adapter weights via Megatron-Bridge and tell the inference engine to load them. + + All ranks participate in the collective export (TP/PP/EP gathering is + handled internally by the bridge). Only rank 0 writes to disk and + sends the ``LoraLoadRequest``. + """ + import json + + from megatron.bridge.models.conversion.peft_bridge import ( + build_adapter_config_dict, + infer_target_modules_from_adapter_weights, + ) + from safetensors.torch import save_file + + adapter_state = {} + for name, tensor in self.bridge.export_adapter_weights(self.actor_module, cpu=True, show_progress=False): + adapter_state[f"base_model.model.{name}"] = tensor.clone().float() + + if torch.distributed.get_rank() == 0: + os.makedirs(lora_sync_path, exist_ok=True) + + target_modules = infer_target_modules_from_adapter_weights(adapter_state.keys()) + base_model_name_or_path = str( + getattr(self.bridge.hf_pretrained, "model_name_or_path", "") + or getattr(self.bridge.hf_pretrained, "name_or_path", "") + ) + adapter_config = build_adapter_config_dict( + self.lora_cls, + target_modules=target_modules, + base_model_name_or_path=base_model_name_or_path, + ) + + save_file(adapter_state, os.path.join(lora_sync_path, "adapter_model.safetensors")) + with open(os.path.join(lora_sync_path, "adapter_config.json"), "w", encoding="utf-8") as f: + json.dump(adapter_config, f, ensure_ascii=False, indent=4) + + lora_request = LoraLoadRequest(lora_path=lora_sync_path) + await inference_engine_client.update_named_weights(lora_request) + + torch.distributed.barrier() + async def broadcast_to_inference_engines( self, inference_engine_client: "InferenceEngineInterface", inference_engine_cfg: "InferenceEngineConfig" ): @@ -814,6 +860,11 @@ async def broadcast_to_inference_engines( torch.cuda.empty_cache() + if self._is_lora: + lora_sync_path = self.cfg.policy.model.lora.lora_sync_path + await self._save_lora_adapters_and_sync(lora_sync_path, inference_engine_client) + return + # Extract and send weights using the sender created at init time weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype) await self._weight_transfer_sender.send_chunks( diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py index b2df7308a6..1beb6112be 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py @@ -1,6 +1,9 @@ """ -# Run tests (requires fsdp extra): -uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py +# Run FSDP tests: +uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -k "fsdp" + +# Run Megatron tests: +uv run --isolated --extra dev --extra megatron pytest tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -k "megatron" """ import asyncio @@ -22,17 +25,30 @@ MODEL = "Qwen/Qwen2.5-0.5B-Instruct" -def get_test_actor_config(enable_lora: bool = False) -> SkyRLTrainConfig: +def get_test_actor_config( + strategy: str = "fsdp", + enable_lora: bool = False, + colocate_all: bool = False, + weight_sync_backend: str = "nccl", + tp_size: int = 2, +) -> SkyRLTrainConfig: """Get base config with test-specific overrides.""" cfg = SkyRLTrainConfig() cfg.trainer.policy.model.path = MODEL cfg.trainer.critic.model.path = "" + cfg.trainer.strategy = strategy + cfg.trainer.placement.colocate_all = colocate_all cfg.trainer.placement.policy_num_gpus_per_node = 2 cfg.generator.inference_engine.async_engine = True cfg.generator.inference_engine.num_engines = 1 cfg.generator.inference_engine.run_engines_locally = True + cfg.generator.inference_engine.weight_sync_backend = weight_sync_backend + cfg.generator.inference_engine.tensor_parallel_size = tp_size + + if strategy == "megatron": + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 - # LoRA configuration if enable_lora: cfg.trainer.policy.model.lora = SkyRLLoraConfig( rank=32, @@ -51,23 +67,29 @@ def get_test_actor_config(enable_lora: bool = False) -> SkyRLTrainConfig: pytest.param(True, "nccl", "fsdp", 2), pytest.param(False, "nccl", "fsdp2", 2), pytest.param(True, "nccl", "fsdp2", 2), + pytest.param(False, "nccl", "megatron", 2, marks=pytest.mark.megatron), + pytest.param(True, "nccl", "megatron", 2, marks=pytest.mark.megatron), ], ids=[ "no_colocate_nccl_fsdp", "colocate_nccl_fsdp", "no_colocate_nccl_fsdp2", "colocate_nccl_fsdp2", + "no_colocate_nccl_megatron", + "colocate_nccl_megatron", ], ) def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_backend, strategy, tp_size): """ Tests initalizing the policy actor group and inference engine, syncing weights, and performing generation. """ - cfg = get_test_actor_config(enable_lora=True) - cfg.trainer.placement.colocate_all = colocate_all - cfg.generator.inference_engine.weight_sync_backend = weight_sync_backend - cfg.trainer.strategy = strategy - cfg.generator.inference_engine.tensor_parallel_size = tp_size + cfg = get_test_actor_config( + strategy=strategy, + enable_lora=True, + colocate_all=colocate_all, + weight_sync_backend=weight_sync_backend, + tp_size=tp_size, + ) # If colocate is True, this will load the engine, sleep, and wake up the engine with InferenceEngineState.create( diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py index 94726a4f4d..508cbdcced 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_megatron_worker.py @@ -122,7 +122,6 @@ def get_test_training_batch(batch_size=4) -> TrainingInputBatch: [ pytest.param(True, 4, 2, 2, 1, None, False, marks=_skip_new_inference, id="colocate_all"), pytest.param(False, 2, 2, 1, 1, None, False, id="non_colocated"), - pytest.param(True, 4, 2, 2, 1, None, True, marks=_skip_new_inference, id="colocate_all_lora"), ], ) @pytest.mark.megatron From 60bb1bc75e0fba5b341e42d2d49568b1c05f9401 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 20 Mar 2026 20:41:44 +0000 Subject: [PATCH 02/11] x Signed-off-by: ahao-anyscale --- .../workers/megatron/megatron_worker.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 7c8ff42759..268c465f08 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -863,14 +863,13 @@ async def broadcast_to_inference_engines( if self._is_lora: lora_sync_path = self.cfg.policy.model.lora.lora_sync_path await self._save_lora_adapters_and_sync(lora_sync_path, inference_engine_client) - return - - # Extract and send weights using the sender created at init time - weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype) - await self._weight_transfer_sender.send_chunks( - self.weight_extractor.extract_weights(generator_dtype), - weight_metadata=weight_metadata, - ) + else: + # Extract and send weights using the sender created at init time + weight_metadata = self.weight_extractor.get_weight_metadata(generator_dtype) + await self._weight_transfer_sender.send_chunks( + self.weight_extractor.extract_weights(generator_dtype), + weight_metadata=weight_metadata, + ) if cache_reset_task is not None: await cache_reset_task From ad0e4e1443a6fda500ac2552c95390ce96506946 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 20 Mar 2026 20:54:31 +0000 Subject: [PATCH 03/11] x Signed-off-by: ahao-anyscale --- .../skyrl_train/workers/megatron/megatron_worker.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 268c465f08..78e61bce43 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -843,8 +843,15 @@ async def _save_lora_adapters_and_sync(self, lora_sync_path, inference_engine_cl with open(os.path.join(lora_sync_path, "adapter_config.json"), "w", encoding="utf-8") as f: json.dump(adapter_config, f, ensure_ascii=False, indent=4) - lora_request = LoraLoadRequest(lora_path=lora_sync_path) - await inference_engine_client.update_named_weights(lora_request) + from skyrl.backends.skyrl_train.inference_servers.remote_inference_client import ( + RemoteInferenceClient, + ) + + if isinstance(inference_engine_client, RemoteInferenceClient): + await inference_engine_client.update_lora_from_disk(lora_sync_path) + else: + lora_request = LoraLoadRequest(lora_path=lora_sync_path) + await inference_engine_client.update_named_weights(lora_request) torch.distributed.barrier() From 25142068d884539dedf6a90f8c2ab04a2fad2ec5 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Fri, 20 Mar 2026 22:52:34 +0000 Subject: [PATCH 04/11] shutdown router fix Signed-off-by: ahao-anyscale --- .../skyrl_train/inference_servers/router.py | 38 ++++++++++++++++++- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/router.py b/skyrl/backends/skyrl_train/inference_servers/router.py index 441e598306..710741d92b 100644 --- a/skyrl/backends/skyrl_train/inference_servers/router.py +++ b/skyrl/backends/skyrl_train/inference_servers/router.py @@ -242,9 +242,43 @@ def _wait_until_healthy( raise RuntimeError(f"Router failed to start within {timeout}s") def shutdown(self) -> None: - """Shutdown the router gracefully.""" + """Shutdown the router and ensure the port is released.""" logger.info("Shutting down router...") + if self._server: self._server.should_exit = True + if self._server_thread: - self._server_thread.join(timeout=5) + self._server_thread.join(timeout=10) + + if self._server_thread.is_alive(): + logger.warning("Router thread did not exit gracefully, forcing server socket closure") + self._force_close_server_sockets() + + if self._client: + try: + asyncio.get_event_loop().run_until_complete(self._client.aclose()) + except Exception: + pass + self._client = None + + self._server = None + self._server_thread = None + self._app = None + + def _force_close_server_sockets(self) -> None: + """Force-close the underlying server sockets to release the port.""" + if self._server and hasattr(self._server, "servers"): + for server in self._server.servers: + server.close() + import socket + + try: + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.settimeout(1) + result = s.connect_ex((self._host if self._host != "0.0.0.0" else "127.0.0.1", self._port)) + s.close() + if result == 0: + logger.warning(f"Port {self._port} still in use after forced close") + except Exception: + pass From b7aabb16b9b6874854eafd26699c86bb40156228 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 23 Mar 2026 07:56:37 +0000 Subject: [PATCH 05/11] x Signed-off-by: ahao-anyscale --- ci/gpu_ci_run_skyrl_train.sh | 2 +- ci/gpu_ci_run_skyrl_train_megatron.sh | 3 ++ .../workers/megatron/megatron_worker.py | 2 +- skyrl/train/config/config.py | 1 + .../skyrl_train/gpu/gpu_ci/test_lora.py | 37 +++++++++++++------ tests/backends/skyrl_train/gpu/utils.py | 5 +++ 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/ci/gpu_ci_run_skyrl_train.sh b/ci/gpu_ci_run_skyrl_train.sh index 1ef5183b19..5ff22ebbd5 100755 --- a/ci/gpu_ci_run_skyrl_train.sh +++ b/ci/gpu_ci_run_skyrl_train.sh @@ -35,4 +35,4 @@ uv run --directory . --isolated --extra dev --extra fsdp pytest -s tests/backend _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py -_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py +_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -k "fsdp" diff --git a/ci/gpu_ci_run_skyrl_train_megatron.sh b/ci/gpu_ci_run_skyrl_train_megatron.sh index 788cbf28b8..be7b41c3e0 100755 --- a/ci/gpu_ci_run_skyrl_train_megatron.sh +++ b/ci/gpu_ci_run_skyrl_train_megatron.sh @@ -7,3 +7,6 @@ uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k # Run all megatron tests uv run --directory . --isolated --extra dev --extra megatron pytest -s tests/backends/skyrl_train/gpu/gpu_ci -m "megatron" +# Run megatron LoRA tests with new inference layer +_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra megatron pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -m "megatron" + diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 78e61bce43..ce2f075bd3 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -867,7 +867,7 @@ async def broadcast_to_inference_engines( torch.cuda.empty_cache() - if self._is_lora: + if self._is_lora and not self.cfg.policy.megatron_config.lora_config.merge_lora: lora_sync_path = self.cfg.policy.model.lora.lora_sync_path await self._save_lora_adapters_and_sync(lora_sync_path, inference_engine_client) else: diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 310f5da196..12ef011d7f 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -128,6 +128,7 @@ class MegatronTorchProfilerConfig(BaseConfig): @dataclass class MegatronLoraConfig(BaseConfig): lora_type: str = "lora" + merge_lora: bool = True DEFAULT_MEGATRON_OPTIMIZER_KWARGS = { diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py index 1beb6112be..067f41111f 100644 --- a/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py @@ -31,6 +31,7 @@ def get_test_actor_config( colocate_all: bool = False, weight_sync_backend: str = "nccl", tp_size: int = 2, + merge_lora: bool = True, ) -> SkyRLTrainConfig: """Get base config with test-specific overrides.""" cfg = SkyRLTrainConfig() @@ -48,6 +49,7 @@ def get_test_actor_config( if strategy == "megatron": cfg.trainer.policy.megatron_config.tensor_model_parallel_size = 2 cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = 1 + cfg.trainer.policy.megatron_config.lora_config.merge_lora = merge_lora if enable_lora: cfg.trainer.policy.model.lora = SkyRLLoraConfig( @@ -61,25 +63,29 @@ def get_test_actor_config( @pytest.mark.parametrize( - ("colocate_all", "weight_sync_backend", "strategy", "tp_size"), + ("colocate_all", "weight_sync_backend", "strategy", "tp_size", "merge_lora"), [ - pytest.param(False, "nccl", "fsdp", 2), - pytest.param(True, "nccl", "fsdp", 2), - pytest.param(False, "nccl", "fsdp2", 2), - pytest.param(True, "nccl", "fsdp2", 2), - pytest.param(False, "nccl", "megatron", 2, marks=pytest.mark.megatron), - pytest.param(True, "nccl", "megatron", 2, marks=pytest.mark.megatron), + pytest.param(False, "nccl", "fsdp", 2, True), + pytest.param(True, "nccl", "fsdp", 2, True), + pytest.param(False, "nccl", "fsdp2", 2, True), + pytest.param(True, "nccl", "fsdp2", 2, True), + pytest.param(False, "nccl", "megatron", 2, True, marks=pytest.mark.megatron), + pytest.param(True, "nccl", "megatron", 2, True, marks=pytest.mark.megatron), + pytest.param(False, "nccl", "megatron", 2, False, marks=pytest.mark.megatron), + pytest.param(True, "nccl", "megatron", 2, False, marks=pytest.mark.megatron), ], ids=[ "no_colocate_nccl_fsdp", "colocate_nccl_fsdp", "no_colocate_nccl_fsdp2", "colocate_nccl_fsdp2", - "no_colocate_nccl_megatron", - "colocate_nccl_megatron", + "no_colocate_nccl_megatron_merged", + "colocate_nccl_megatron_merged", + "no_colocate_nccl_megatron_adapter", + "colocate_nccl_megatron_adapter", ], ) -def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_backend, strategy, tp_size): +def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_backend, strategy, tp_size, merge_lora): """ Tests initalizing the policy actor group and inference engine, syncing weights, and performing generation. """ @@ -89,8 +95,15 @@ def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_ba colocate_all=colocate_all, weight_sync_backend=weight_sync_backend, tp_size=tp_size, + merge_lora=merge_lora, ) + # Only enable LoRA on the vLLM side when adapters are loaded separately. + # When merge_lora=True the bridge merges LoRA into the full weights, so + # vLLM receives plain weights and must NOT have enable_lora (which wraps + # modules and changes named_parameters(), breaking load_weights). + needs_vllm_lora = not (strategy == "megatron" and merge_lora) + # If colocate is True, this will load the engine, sleep, and wake up the engine with InferenceEngineState.create( cfg=cfg, @@ -99,8 +112,8 @@ def test_policy_local_engines_e2e(ray_init_fixture, colocate_all, weight_sync_ba async_engine=cfg.generator.inference_engine.async_engine, tp_size=cfg.generator.inference_engine.tensor_parallel_size, colocate_all=cfg.trainer.placement.colocate_all, - sleep_level=1, # since we explicitly sync weights - enable_lora=True, # Enable LoRA for this test + sleep_level=1 if needs_vllm_lora else 2, + enable_lora=needs_vllm_lora, ) as engines: client, pg = engines.client, engines.pg diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py index bd62dea1ef..3d8573e544 100644 --- a/tests/backends/skyrl_train/gpu/utils.py +++ b/tests/backends/skyrl_train/gpu/utils.py @@ -519,6 +519,11 @@ def create( cli_args.enable_lora = True if active_lora_name is None: active_lora_name = "skyrl-lora" + else: + # Override build_vllm_cli_args which auto-enables LoRA based + # on lora rank in the config. For merged weight sync the + # inference engine must NOT have LoRA wrapping enabled. + cli_args.enable_lora = False server_group = ServerGroup( cli_args=cli_args, num_servers=ie_cfg.num_engines * ie_cfg.data_parallel_size, From 611dc6869b6ac0ac9a6daf4c336f8fc700e60870 Mon Sep 17 00:00:00 2001 From: Aaron Hao Date: Mon, 23 Mar 2026 11:05:16 -0700 Subject: [PATCH 06/11] Update ci/gpu_ci_run_skyrl_train.sh Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com> --- ci/gpu_ci_run_skyrl_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/gpu_ci_run_skyrl_train.sh b/ci/gpu_ci_run_skyrl_train.sh index 5ff22ebbd5..9b69c5e3bb 100755 --- a/ci/gpu_ci_run_skyrl_train.sh +++ b/ci/gpu_ci_run_skyrl_train.sh @@ -35,4 +35,4 @@ uv run --directory . --isolated --extra dev --extra fsdp pytest -s tests/backend _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py -_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -k "fsdp" +_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -m "not megatron" From 0bf728312b247b9833b1d01574f142d8f2f8ec7e Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 23 Mar 2026 19:32:45 +0000 Subject: [PATCH 07/11] x Signed-off-by: ahao-anyscale Made-with: Cursor --- .../skyrl_train/inference_servers/router.py | 17 ++++++++++------- tests/backends/skyrl_train/gpu/utils.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/router.py b/skyrl/backends/skyrl_train/inference_servers/router.py index 710741d92b..5a2be4e310 100644 --- a/skyrl/backends/skyrl_train/inference_servers/router.py +++ b/skyrl/backends/skyrl_train/inference_servers/router.py @@ -15,6 +15,7 @@ import logging import threading import time +from contextlib import asynccontextmanager from typing import List, Optional import httpx @@ -106,8 +107,17 @@ def _get_server_for_request(self, request: Request) -> str: def _build_app(self) -> FastAPI: """Build the FastAPI app with proxy routes.""" + + @asynccontextmanager + async def lifespan(app: FastAPI): + yield + if self._client: + await self._client.aclose() + self._client = None + app = FastAPI( title="SkyRL Inference Router", + lifespan=lifespan, docs_url=None, redoc_url=None, openapi_url=None, @@ -255,13 +265,6 @@ def shutdown(self) -> None: logger.warning("Router thread did not exit gracefully, forcing server socket closure") self._force_close_server_sockets() - if self._client: - try: - asyncio.get_event_loop().run_until_complete(self._client.aclose()) - except Exception: - pass - self._client = None - self._server = None self._server_thread = None self._app = None diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py index 3d8573e544..364394c105 100644 --- a/tests/backends/skyrl_train/gpu/utils.py +++ b/tests/backends/skyrl_train/gpu/utils.py @@ -160,7 +160,7 @@ def init_worker_with_type( cfg = get_test_actor_config() if shared_pg is not None: - pg = shared_pg + pg = ResolvedPlacementGroup(shared_pg) num_gpus_per_actor = 0.2 else: bundles = [{"GPU": num_gpus_per_node, "CPU": num_gpus_per_node} for _ in range(num_nodes)] @@ -577,7 +577,7 @@ def create( ) if sleep: asyncio.run(client.wake_up()) - return cls(client=client, pg=shared_pg, router=router, server_group=server_group) + return cls(client=client, pg=raw_pg if shared_pg else None, router=router, server_group=server_group) def init_remote_inference_servers( From 79ba72c660dee1e303bffae147b3fe864f61edf8 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 23 Mar 2026 20:47:55 +0000 Subject: [PATCH 08/11] x Signed-off-by: ahao-anyscale --- .../skyrl_train/inference_servers/utils.py | 22 +++++++++++++++++-- tests/backends/skyrl_train/gpu/utils.py | 11 ++-------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/skyrl/backends/skyrl_train/inference_servers/utils.py b/skyrl/backends/skyrl_train/inference_servers/utils.py index 62dd996ff2..5b0179abbd 100644 --- a/skyrl/backends/skyrl_train/inference_servers/utils.py +++ b/skyrl/backends/skyrl_train/inference_servers/utils.py @@ -4,6 +4,20 @@ from skyrl.train.config import SkyRLTrainConfig, get_config_as_dict +def _uses_lora_weight_sync(cfg: SkyRLTrainConfig) -> bool: + """Return True when the trainer syncs LoRA adapters (not merged weights). + + FSDP always syncs LoRA adapters when ``lora.rank > 0``. + Megatron merges LoRA into the base weights by default + (``merge_lora=True``), so the inference engine should not enable LoRA. + """ + if cfg.trainer.policy.model.lora.rank <= 0: + return False + if cfg.trainer.strategy == "megatron": + return not cfg.trainer.policy.megatron_config.lora_config.merge_lora + return True + + # TODO: Add a test for validation def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace: """Build CLI args for vLLM server from config.""" @@ -48,12 +62,16 @@ def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace: for key, value in overrides.items(): setattr(args, key, value) - # Add LoRA params if enabled - if cfg.trainer.policy.model.lora.rank > 0: + # Enable LoRA on the inference engine only when the trainer will sync + # LoRA adapters (not merged weights). Megatron merges by default + # (merge_lora=True), so the inference engine must NOT have LoRA wrapping. + if _uses_lora_weight_sync(cfg): args.enable_lora = True args.max_lora_rank = cfg.trainer.policy.model.lora.rank args.max_loras = 1 args.fully_sharded_loras = ie_cfg.fully_sharded_loras + else: + args.enable_lora = False # Add any extra engine_init_kwargs engine_kwargs = get_config_as_dict(ie_cfg.engine_init_kwargs) diff --git a/tests/backends/skyrl_train/gpu/utils.py b/tests/backends/skyrl_train/gpu/utils.py index 364394c105..437e645cca 100644 --- a/tests/backends/skyrl_train/gpu/utils.py +++ b/tests/backends/skyrl_train/gpu/utils.py @@ -515,15 +515,8 @@ def create( # NOTE: In the case of the new inference backend, server is up by default, so we don't need # any special handling for sleep cli_args = build_vllm_cli_args(cfg) - if enable_lora: - cli_args.enable_lora = True - if active_lora_name is None: - active_lora_name = "skyrl-lora" - else: - # Override build_vllm_cli_args which auto-enables LoRA based - # on lora rank in the config. For merged weight sync the - # inference engine must NOT have LoRA wrapping enabled. - cli_args.enable_lora = False + if cli_args.enable_lora and active_lora_name is None: + active_lora_name = "skyrl-lora" server_group = ServerGroup( cli_args=cli_args, num_servers=ie_cfg.num_engines * ie_cfg.data_parallel_size, From 8e860af70b110593b31647f73ac7f5c32ff63e95 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Mon, 23 Mar 2026 22:25:40 +0000 Subject: [PATCH 09/11] x Signed-off-by: ahao-anyscale --- .../skyrl_train/inference_servers/utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/skyrl/backends/skyrl_train/inference_servers/utils.py b/skyrl/backends/skyrl_train/inference_servers/utils.py index 5b0179abbd..85ff7950a3 100644 --- a/skyrl/backends/skyrl_train/inference_servers/utils.py +++ b/skyrl/backends/skyrl_train/inference_servers/utils.py @@ -1,8 +1,11 @@ +import logging from argparse import Namespace from skyrl.backends.skyrl_train.weight_sync import get_transfer_strategy from skyrl.train.config import SkyRLTrainConfig, get_config_as_dict +logger = logging.getLogger(__name__) + def _uses_lora_weight_sync(cfg: SkyRLTrainConfig) -> bool: """Return True when the trainer syncs LoRA adapters (not merged weights). @@ -70,6 +73,17 @@ def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace: args.max_lora_rank = cfg.trainer.policy.model.lora.rank args.max_loras = 1 args.fully_sharded_loras = ie_cfg.fully_sharded_loras + + if not cfg.trainer.placement.colocate_all: + lora_path = cfg.trainer.policy.model.lora.lora_sync_path + logger.warning( + "LoRA weight sync is enabled but training and inference are not " + "colocated (placement.colocate_all=false). The trainer saves LoRA " + "adapters to disk for the inference engine to load, so both must " + "share a filesystem. Set trainer.policy.model.lora.lora_sync_path " + "to a shared mount (current value: %s).", + lora_path, + ) else: args.enable_lora = False From dd56c1f8110663e0b9ea8814ad3c70409b66106d Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 2 Apr 2026 17:50:10 +0000 Subject: [PATCH 10/11] x Signed-off-by: ahao-anyscale --- ci/gpu_ci_run_skyrl_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/gpu_ci_run_skyrl_train.sh b/ci/gpu_ci_run_skyrl_train.sh index dfc0430d22..b1793d5506 100755 --- a/ci/gpu_ci_run_skyrl_train.sh +++ b/ci/gpu_ci_run_skyrl_train.sh @@ -35,4 +35,4 @@ uv run --directory . --isolated --extra dev --extra fsdp pytest -s tests/backend _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp --extra vllm-router pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp --extra vllm-router pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp --extra vllm-router pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py -_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp --extra vllm-router pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py +_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp --extra vllm-router pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -m "not megatron" From 3eb0c089d71cca79ec7c23f0acd3ea62113e9116 Mon Sep 17 00:00:00 2001 From: ahao-anyscale Date: Thu, 2 Apr 2026 19:16:52 +0000 Subject: [PATCH 11/11] x Signed-off-by: ahao-anyscale --- ci/gpu_ci_run_skyrl_train.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/gpu_ci_run_skyrl_train.sh b/ci/gpu_ci_run_skyrl_train.sh index cd5476497d..207af96015 100755 --- a/ci/gpu_ci_run_skyrl_train.sh +++ b/ci/gpu_ci_run_skyrl_train.sh @@ -35,5 +35,5 @@ uv run --directory . --isolated --extra dev --extra fsdp pytest -s tests/backend _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_policy_local_engines_e2e.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py -_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py +_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py -m "not megatron" _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_expert_parallel_inference.py