Commit bc5bc97

Fix compatibility with vLLM 0.10.x
MultiStepModelRunner was removed in vLLM 0.10.x and merged into ModelRunner, which now has built-in LoRA support. This commit updates the code to handle both old (< 0.10) and new (>= 0.10) versions of vLLM:

- Make the MultiStepModelRunner import optional with try/except
- Update all isinstance checks to handle the None case
- Convert patch_multi_step_model_runner to a no-op for vLLM 0.10.x
- Add deprecation notes explaining the changes

This fixes the ModuleNotFoundError raised when using vLLM 0.10.2.
1 parent 5229cf4 commit bc5bc97
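
For context, the fix boils down to an optional import plus a guard at every use site, so the same code runs on both vLLM generations. A minimal standalone sketch of that pattern (illustrative only, not part of the diff; the describe_runner_support helper does not exist in the repo):

try:
    # Present in vLLM < 0.10.x; on 0.10.x and later the import fails.
    from vllm.worker.multi_step_model_runner import MultiStepModelRunner
except ImportError:
    MultiStepModelRunner = None  # removed in vLLM 0.10.x


def describe_runner_support(model_runner: object) -> str:
    # Guard on the optional symbol before any isinstance check, exactly as
    # the updated call sites below do.
    if MultiStepModelRunner is not None and isinstance(
        model_runner, MultiStepModelRunner
    ):
        return "vLLM < 0.10: multi-step runner, LoRA patch needed"
    return "vLLM >= 0.10 or plain runner: built-in LoRA support, nothing to patch"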

2 files changed: +23, -13 lines changed


src/art/unsloth/state.py

Lines changed: 12 additions & 4 deletions
@@ -17,9 +17,14 @@
 from trl import GRPOConfig, GRPOTrainer
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.worker.multi_step_model_runner import MultiStepModelRunner
 from vllm.worker.worker_base import WorkerWrapperBase
 
+try:
+    from vllm.worker.multi_step_model_runner import MultiStepModelRunner
+except ImportError:
+    # MultiStepModelRunner was removed in vLLM 0.10.x
+    MultiStepModelRunner = None  # type: ignore
+
 from ..dev.model import InternalModelConfig
 from .train import gc_and_empty_cuda_cache
 

@@ -41,8 +46,8 @@ class ModelState:
     def __init__(self, config: InternalModelConfig) -> None:
         from vllm.engine import async_llm_engine
 
-        # Patch MultiStepModelRunner for Unsloth compatibility
-        if not hasattr(MultiStepModelRunner, "model"):
+        # Patch MultiStepModelRunner for Unsloth compatibility (vLLM < 0.10.x)
+        if MultiStepModelRunner is not None and not hasattr(MultiStepModelRunner, "model"):
             MultiStepModelRunner.model = property(  # type: ignore
                 lambda self: self._base_model_runner.model
             )

@@ -143,7 +148,10 @@ def __init__(self, async_engine: AsyncLLMEngine, enable_sleep_mode: bool) -> None:
             "WorkerWrapperBase",
             getattr(self.async_engine.engine.model_executor, "driver_worker"),
         )
-        if isinstance(self.driver_worker.model_runner, MultiStepModelRunner):
+        # Patch MultiStepModelRunner if it exists (vLLM < 0.10.x)
+        if MultiStepModelRunner is not None and isinstance(
+            self.driver_worker.model_runner, MultiStepModelRunner
+        ):
             patch_multi_step_model_runner(self.driver_worker.model_runner)
 
     @asynccontextmanager
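
As a quick sanity check (not part of this commit), one way to see which branch the new guards will take is to read the installed vLLM version from package metadata; this sketch assumes the version string starts with numeric major.minor components:

from importlib.metadata import PackageNotFoundError, version

try:
    major, minor = (int(part) for part in version("vllm").split(".")[:2])
except (PackageNotFoundError, ValueError):
    print("vLLM not installed, or version string is not plain major.minor")
else:
    if (major, minor) >= (0, 10):
        print("vLLM >= 0.10: MultiStepModelRunner is gone; the guards skip patching")
    else:
        print("vLLM < 0.10: MultiStepModelRunner exists; the LoRA patch still applies")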

src/art/vllm/patches.py

Lines changed: 11 additions & 9 deletions
@@ -1,10 +1,12 @@
 """Monkey patches and modifications for vLLM."""
 
 import ctypes
-from typing import Any
+from typing import Any, TYPE_CHECKING
 
 import torch
-from vllm.worker.multi_step_model_runner import MultiStepModelRunner
+
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import ModelRunner
 
 
 def patch_allocator() -> None:

@@ -185,13 +187,13 @@ def patch(
     ToolParserManager.get_tool_parser = patched_get_tool_parser
 
 
-def patch_multi_step_model_runner(runner: MultiStepModelRunner) -> None:
+def patch_multi_step_model_runner(runner: "ModelRunner") -> None:
     """
     Patches a MultiStepModelRunner to support LoRA adapters.
+
+    Note: This function is deprecated as of vLLM 0.10.x. MultiStepModelRunner
+    was merged into ModelRunner, which now has built-in LoRA support.
+    This function is kept for backward compatibility but does nothing.
     """
-    base_runner = runner._base_model_runner
-    runner.set_active_loras = base_runner.set_active_loras
-    runner.add_lora = base_runner.add_lora
-    runner.remove_lora = base_runner.remove_lora
-    runner.pin_lora = base_runner.pin_lora
-    runner.list_loras = base_runner.list_loras
+    # No-op: ModelRunner in vLLM 0.10.x already has LoRA methods built-in
+    pass
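
A note on the typing change above: importing ModelRunner only under TYPE_CHECKING keeps the annotation available to type checkers without adding a runtime import, and the quoted "ModelRunner" forward reference means the module still loads even when that class is not imported at runtime. A minimal sketch of this standard pattern (names illustrative, not from the repo):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from vllm.worker.model_runner import ModelRunner


def patch_runner(runner: "ModelRunner") -> None:
    # The quoted annotation resolves lazily, so no runtime import is required.
    ...

Keeping patch_multi_step_model_runner itself as a documented no-op preserves the public API, so callers on either vLLM version can import and call it without extra branching.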
