Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/gpu_ci_run_skyrl_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ _SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_engine_generation.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_lora.py
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_expert_parallel_inference.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you timed how long this takes on L4s?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but I don't remember it taking more than 10 minutes. The model is pretty small.

1 change: 1 addition & 0 deletions skyrl/backends/skyrl_train/inference_servers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace:
enable_prefix_caching=ie_cfg.enable_prefix_caching,
enforce_eager=ie_cfg.enforce_eager,
max_num_batched_tokens=ie_cfg.max_num_batched_tokens,
enable_expert_parallel=ie_cfg.expert_parallel_size > 1,
max_num_seqs=ie_cfg.max_num_seqs,
enable_sleep_mode=cfg.trainer.placement.colocate_all,
weight_transfer_config=WeightTransferConfig(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,31 @@
"""
Tests for expert parallel (EP).

uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/test_expert_parallel_inference.py

uv run --isolated --extra dev --extra fsdp pytest tests/backends/skyrl_train/gpu/gpu_ci/test_expert_parallel_inference.py
"""

import asyncio
from typing import Optional

import pytest
import ray
from ray.util.placement_group import PlacementGroup, placement_group
from transformers import AutoTokenizer

from skyrl.backends.skyrl_train.inference_engines.base import InferenceEngineInput
from skyrl.backends.skyrl_train.inference_engines.inference_engine_client import (
InferenceEngineClient,
)
from skyrl.backends.skyrl_train.inference_engines.ray_wrapped_inference_engine import (
create_ray_wrapped_inference_engines,
)
from skyrl.backends.skyrl_train.inference_engines.utils import (
get_sampling_params_for_backend,
)
from skyrl.train.config import SkyRLTrainConfig
from skyrl.train.utils import get_ray_pg_ready_with_timeout, initialize_ray
from skyrl.utils.tok import get_tokenizer
from tests.backends.skyrl_train.gpu.utils import (
InferenceEngineState,
_ensure_chat_template,
are_responses_similar,
get_available_gpus,
get_test_actor_config,
get_test_prompts,
init_worker_with_type,
run_inference,
)

MODEL = "Qwen/Qwen1.5-MoE-A2.7B-Chat"
MODEL = "huihui-ai/Huihui-MoE-0.8B-2E"
NUM_GPUS = 4 # Should be divisible by 2


Expand Down Expand Up @@ -78,8 +70,8 @@ def _get_test_cfg() -> SkyRLTrainConfig:
return cfg


async def _run_single_generation(client: InferenceEngineClient, prompts, sampling_params):
tasks = [client.generate(InferenceEngineInput(prompts=[p], sampling_params=sampling_params)) for p in prompts]
async def _run_single_generation(client, prompts, sampling_params, tokenizer):
tasks = [run_inference(client, [p], sampling_params, tokenizer=tokenizer) for p in prompts]
results = await asyncio.gather(*tasks)
responses, reasons = [], []
for r in results:
Expand All @@ -88,34 +80,6 @@ async def _run_single_generation(client: InferenceEngineClient, prompts, samplin
return responses, reasons


def init_ray_inference_engines(
backend: str, tp_size: int, shared_pg: Optional[PlacementGroup], config: SkyRLTrainConfig
) -> InferenceEngineClient:
"""Initialize ray-wrapped inference engines for the specified backend"""
tokenizer = AutoTokenizer.from_pretrained(MODEL)
engine = create_ray_wrapped_inference_engines(
num_inference_engines=1,
tensor_parallel_size=tp_size,
expert_parallel_size=config.generator.inference_engine.expert_parallel_size,
model_dtype="bfloat16",
pretrain=MODEL,
seed=42,
vllm_v1_disable_multiproc=True,
enable_prefix_caching=True,
enforce_eager=True,
shared_pg=shared_pg,
gpu_memory_utilization=0.8,
inference_engine_enable_sleep=False,
async_engine=True,
max_num_batched_tokens=8192,
max_num_seqs=1024,
tokenizer=tokenizer,
backend=backend,
)
client = InferenceEngineClient(engine, tokenizer, config)
return client


def test_ep_generation():
"""
Ensure vLLM generation with expert parallel enabled (EP=2) runs without errors.
Expand All @@ -129,73 +93,54 @@ def test_ep_generation():
cfg.generator.sampling_params.temperature = 0.0
cfg.generator.sampling_params.top_p = 1.0
cfg.generator.sampling_params.top_k = -1
initialize_ray(cfg)

client = init_ray_inference_engines(
backend=cfg.generator.inference_engine.backend,
tp_size=cfg.generator.inference_engine.tensor_parallel_size,
shared_pg=None,
config=cfg,
)

prompts = get_test_prompts(MODEL, num_samples=4)
sampling_params = get_sampling_params_for_backend(
cfg.generator.inference_engine.backend, cfg.generator.sampling_params
)
with InferenceEngineState.create(cfg, sleep_level=1) as state:
tokenizer = get_tokenizer(MODEL)
_ensure_chat_template(tokenizer)
state.client.tokenizer = tokenizer
prompts = get_test_prompts(MODEL, num_samples=4)
sampling_params = get_sampling_params_for_backend(
cfg.generator.inference_engine.backend, cfg.generator.sampling_params
)

responses, reasons = asyncio.run(_run_single_generation(client, prompts, sampling_params))
assert len(responses) == len(prompts)
assert len(reasons) == len(prompts)
responses, reasons = asyncio.run(_run_single_generation(state.client, prompts, sampling_params, tokenizer))
assert len(responses) == len(prompts)
assert len(reasons) == len(prompts)
finally:
ray.shutdown()


def test_ep_weight_sync():
def test_ep_weight_sync(ray_init_fixture):
"""
Ensure generation works after syncing weights from training policy worker.
"""
_check_gpus(num_gpus=NUM_GPUS)

pg = None
try:
cfg = _get_test_cfg()
cfg.trainer.placement.colocate_all = True
# Deterministic sampling for robust comparisons
cfg.generator.sampling_params.temperature = 0.0
cfg.generator.sampling_params.top_p = 1.0
cfg.generator.sampling_params.top_k = -1

initialize_ray(cfg)

# Create a shared PG with NUM_GPUS bundles (one GPU and one CPU each) for the colocated inference engine and training
pg = placement_group([{"GPU": 1, "CPU": 1}] * NUM_GPUS, strategy="PACK")
get_ray_pg_ready_with_timeout(pg, timeout=60)

# Spin up the inference engine with EP enabled, colocated on the shared PG
client = init_ray_inference_engines(
backend=cfg.generator.inference_engine.backend,
tp_size=cfg.generator.inference_engine.tensor_parallel_size,
shared_pg=pg,
config=cfg,
)
asyncio.run(client.wake_up())
cfg = _get_test_cfg()
cfg.trainer.placement.colocate_all = True
# Deterministic sampling for robust comparisons
cfg.generator.sampling_params.temperature = 0.0
cfg.generator.sampling_params.top_p = 1.0
cfg.generator.sampling_params.top_k = -1

with InferenceEngineState.create(cfg, colocate_all=True) as state:
# Generate before weight sync
tokenizer = get_tokenizer(MODEL)
_ensure_chat_template(tokenizer)
state.client.tokenizer = tokenizer
prompts = get_test_prompts(MODEL, num_samples=4)
sampling_params = get_sampling_params_for_backend(
cfg.generator.inference_engine.backend, cfg.generator.sampling_params
)
out_before = asyncio.run(
client.generate(InferenceEngineInput(prompts=prompts, sampling_params=sampling_params))
)
out_before = asyncio.run(run_inference(state.client, prompts, sampling_params, tokenizer=tokenizer))
assert len(out_before["responses"]) == len(prompts)

asyncio.run(client.sleep())
asyncio.run(state.client.sleep())

# Initialize policy worker
# Initialize policy worker on the same placement group
policy = init_worker_with_type(
"policy",
shared_pg=pg,
shared_pg=state.pg,
colocate_all=True,
num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node,
cfg=cfg,
Expand All @@ -204,21 +149,27 @@ def test_ep_weight_sync():
# Sync weights to inference engines
ray.get(
policy.async_run_ray_method(
"pass_through", "init_weight_sync_state", client, cfg.generator.inference_engine
"pass_through",
"init_weight_sync_state",
state.client,
cfg.generator.inference_engine,
)
)
asyncio.run(client.wake_up(tags=["weights"]))
asyncio.run(state.client.wake_up(tags=["weights"]))
ray.get(
policy.async_run_ray_method(
"pass_through", "broadcast_to_inference_engines", client, cfg.generator.inference_engine
"pass_through",
"broadcast_to_inference_engines",
state.client,
cfg.generator.inference_engine,
)
)
policy.offload_to_cpu()
asyncio.run(client.wake_up(tags=["kv_cache"]))
asyncio.run(client.reset_prefix_cache())
asyncio.run(state.client.wake_up(tags=["kv_cache"]))
asyncio.run(state.client.reset_prefix_cache())

# Generate after weight sync
out_after = asyncio.run(client.generate(InferenceEngineInput(prompts=prompts, sampling_params=sampling_params)))
out_after = asyncio.run(run_inference(state.client, prompts, sampling_params, tokenizer=tokenizer))
assert len(out_after["responses"]) == len(prompts)
assert len(out_after["stop_reasons"]) == len(prompts)

Expand All @@ -228,10 +179,3 @@ def test_ep_weight_sync():
print(
f"Response changed significantly after weight sync: before={out_before['responses'][i][:200]} ... after={out_after['responses'][i][:200]} ..."
)
finally:
if pg is not None:
try:
ray.util.remove_placement_group(pg)
except Exception:
pass
ray.shutdown()
6 changes: 6 additions & 0 deletions tests/backends/skyrl_train/gpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,8 @@ def ray_init_for_tests():
env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
env_vars["NVTE_FUSED_ATTN"] = "0"
env_vars["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH")
if _SKYRL_USE_NEW_INFERENCE:
env_vars["_SKYRL_USE_NEW_INFERENCE"] = "1"
ray.init(runtime_env={"env_vars": env_vars})


Expand Down Expand Up @@ -465,6 +467,7 @@ def create(
engine_init_kwargs: Optional[Dict[str, Any]] = None,
use_new_inference_servers: Optional[bool] = None,
distributed_executor_backend: Optional[str] = None,
expert_parallel_size: Optional[int] = None,
) -> "InferenceEngineState":
"""
Instantiates inference engines in SkyRL with the provided configuration and overrides
Expand Down Expand Up @@ -496,6 +499,8 @@ def create(
ie_cfg.engine_init_kwargs = engine_init_kwargs
if distributed_executor_backend is not None:
ie_cfg.distributed_executor_backend = distributed_executor_backend
if expert_parallel_size is not None:
ie_cfg.expert_parallel_size = expert_parallel_size

assert ie_cfg.run_engines_locally, "This test does not yet support remote engines."

Expand Down Expand Up @@ -559,6 +564,7 @@ def create(
eps = create_ray_wrapped_inference_engines(
num_inference_engines=ie_cfg.num_engines,
tensor_parallel_size=ie_cfg.tensor_parallel_size,
expert_parallel_size=ie_cfg.expert_parallel_size,
model_dtype="bfloat16",
pretrain=cfg.trainer.policy.model.path,
seed=42,
Expand Down
Loading