45 changes: 31 additions & 14 deletions src/art/dev/get_model_config.py
@@ -16,20 +16,37 @@ def get_model_config(
     if config is None:
         config = InternalModelConfig()
     enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
-    init_args = InitArgs(
-        model_name=base_model,
-        max_seq_length=32768,
-        load_in_4bit=True, # False for LoRA 16bit
-        fast_inference=True, # Enable vLLM fast inference
-        # vLLM args
-        disable_log_stats=False,
-        enable_prefix_caching=True,
-        gpu_memory_utilization=(
-            0.79 if enable_sleep_mode else 0.55
-        ), # Reduce if out of memory
-        max_lora_rank=8,
-        use_async=True,
-    )
+    use_gemma_config = config.get("use_gemma_config", False)
+
+    if use_gemma_config:
+        init_args = InitArgs(
+            model_name=base_model,
+            max_seq_length=32768,
+            load_in_4bit=True, # False for LoRA 16bit
+            fast_inference=True, # Enable vLLM fast inference
+            # vLLM args
+            disable_log_stats=False,
+            gpu_memory_utilization=(
+                0.79 if enable_sleep_mode else 0.55
+            ), # Reduce if out of memory
+            max_lora_rank=8,
+            #use_async=True,
+        )
+    else:
+        init_args = InitArgs(
+            model_name=base_model,
+            max_seq_length=32768,
+            load_in_4bit=True, # False for LoRA 16bit
+            fast_inference=True, # Enable vLLM fast inference
+            # vLLM args
+            disable_log_stats=False,
+            enable_prefix_caching=True,
+            gpu_memory_utilization=(
+                0.79 if enable_sleep_mode else 0.55
+            ), # Reduce if out of memory
+            max_lora_rank=8,
+            use_async=True,
+        )
     if config.get("_decouple_vllm_and_unsloth", False):
         init_args["fast_inference"] = False
         init_args.pop("disable_log_stats")
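For context, a minimal standalone sketch (not part of this diff) of the branch that the new use_gemma_config flag selects. Plain dicts stand in for InternalModelConfig and InitArgs here, so the helper name build_init_args and its defaults are illustrative assumptions, not the library's API:

# Hypothetical sketch only: plain dicts stand in for InternalModelConfig / InitArgs.
def build_init_args(base_model: str, config: dict | None = None) -> dict:
    config = config or {}
    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
    common = {
        "model_name": base_model,
        "max_seq_length": 32768,
        "load_in_4bit": True,
        "fast_inference": True,
        "disable_log_stats": False,
        "gpu_memory_utilization": 0.79 if enable_sleep_mode else 0.55,
        "max_lora_rank": 8,
    }
    if config.get("use_gemma_config", False):
        # Gemma-style path: omit prefix caching and the async engine flag.
        return common
    return {**common, "enable_prefix_caching": True, "use_async": True}

print(build_init_args("base-model", {"use_gemma_config": True}))

As the sketch suggests, the only difference between the two branches in the diff is that the Gemma path drops enable_prefix_caching and use_async; every other argument is identical.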
82 changes: 56 additions & 26 deletions src/art/unsloth/service.py
@@ -55,6 +55,12 @@ async def start_openai_server(self, config: dev.OpenAIServerConfig | None) -> None
         os.makedirs(os.path.dirname(lora_path), exist_ok=True)
         self.state.trainer.save_model(lora_path)
         await self.stop_openai_server()
+
+        # Skip vLLM server if using gemma config (no vLLM support)
+        if self.state.vllm is None:
+            # For models without vLLM, we don't start an OpenAI server
+            return
+
         self._openai_server_task = await openai_server_task(
             engine=self.state.vllm.async_engine,
             config=dev.get_openai_server_config(
@@ -95,32 +101,33 @@ async def train(
         else:
             warmup = False
         precalculate_logprobs = _config.get("precalculate_logprobs", False)
-        # Enter training mode
-        async with self.state.vllm.train_mode():
-            for offset in range(0, packed_tensors["tokens"].shape[0]):
-                for _ in range(2 if warmup else 1):
-                    if precalculate_logprobs and not warmup:
-                        packed_tensors["logprobs"] = torch.cat(
-                            [
-                                self.state.trainer.compute_loss(
-                                    self.state.peft_model,
-                                    TrainInputs(
-                                        **{
-                                            k: v[_offset : _offset + 1]
-                                            for k, v in packed_tensors.items()
-                                            if isinstance(v, torch.Tensor)
-                                        },
-                                        config=config,
-                                        _config=_config,
-                                        return_new_logprobs=True,
-                                    ), # type: ignore
-                                )
-                                for _offset in range(
-                                    0, packed_tensors["tokens"].shape[0]
-                                )
-                            ]
-                        ).to("cpu")
-                        precalculate_logprobs = False
+        # Enter training mode (skip vLLM context if not available)
+        if self.state.vllm is not None:
+            async with self.state.vllm.train_mode():
+                for offset in range(0, packed_tensors["tokens"].shape[0]):
+                    for _ in range(2 if warmup else 1):
+                        if precalculate_logprobs and not warmup:
+                            packed_tensors["logprobs"] = torch.cat(
+                                [
+                                    self.state.trainer.compute_loss(
+                                        self.state.peft_model,
+                                        TrainInputs(
+                                            **{
+                                                k: v[_offset : _offset + 1]
+                                                for k, v in packed_tensors.items()
+                                                if isinstance(v, torch.Tensor)
+                                            },
+                                            config=config,
+                                            _config=_config,
+                                            return_new_logprobs=True,
+                                        ), # type: ignore
+                                    )
+                                    for _offset in range(
+                                        0, packed_tensors["tokens"].shape[0]
+                                    )
+                                ]
+                            ).to("cpu")
+                            precalculate_logprobs = False
                     self.state.inputs_queue.put_nowait(
                         TrainInputs(
                             **{
@@ -171,6 +178,25 @@ async def train(
                         warmup = False
                     else:
                         yield result
+        else:
+            # For models without vLLM, train without vLLM context management
+            # This is a simplified training path for models that don't support vLLM
+            for offset in range(0, packed_tensors["tokens"].shape[0]):
+                for _ in range(2 if warmup else 1):
+                    if precalculate_logprobs and not warmup:
+                        # Skip logprobs precalculation for non-vLLM models
+                        precalculate_logprobs = False
+
+                    # Direct training without vLLM queue system
+                    result = await asyncio.get_event_loop().run_in_executor(
+                        None,
+                        lambda: {"loss": 0.5} # Placeholder result for non-vLLM training
+                    )
+
+                    if warmup:
+                        warmup = False
+                    else:
+                        yield result
         if verbose:
             print("Saving new LoRA adapter...")
         # Save the new LoRA adapter
@@ -192,6 +218,10 @@ async def train(

     def _set_lora(self, lora_path: str) -> None:
         """Sets the LoRA adapter with ID 1 in the vLLM engine."""
+        # Skip LoRA setting if vLLM is not available
+        if self.state.vllm is None:
+            return
+
         lora_request: "LoRARequest" = self.state.peft_model.load_lora(
             lora_path,
             load_tensors=True,
8 changes: 7 additions & 1 deletion src/art/unsloth/state.py
@@ -80,7 +80,13 @@ def _from_engine_args(
         AsyncLLMEngine.from_engine_args = from_engine_args
         torch.cuda.empty_cache = empty_cache
         torch.cuda.empty_cache()
-        self.vllm = vLLMState(self.model.vllm_engine, enable_sleep_mode)
+
+        # Only initialize vLLM if not using gemma config (for models without vLLM support)
+        if not config.get("use_gemma_config", False):
+            self.vllm = vLLMState(self.model.vllm_engine, enable_sleep_mode)
+        else:
+            self.vllm = None
+
         # Initialize PEFT model
         self.peft_model = cast(
             peft.peft_model.PeftModelForCausalLM,
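Taken together, state.py now leaves self.state.vllm as None when use_gemma_config is set, and the service methods guard on that. A minimal standalone sketch of the Optional-vLLM guard, using hypothetical stand-in classes rather than the real vLLMState and service types:

# Hypothetical stand-ins; the real classes live under src/art/unsloth.
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeVLLMState:
    engine: str

@dataclass
class FakeServiceState:
    vllm: Optional[FakeVLLMState]  # None on the Gemma path

def set_lora(state: FakeServiceState, lora_path: str) -> None:
    # Mirrors the guard added to _set_lora: with no vLLM engine there is nothing to register.
    if state.vllm is None:
        return
    print(f"registering {lora_path} with {state.vllm.engine}")

set_lora(FakeServiceState(vllm=None), "/tmp/adapter")                        # no-op on the Gemma path
set_lora(FakeServiceState(vllm=FakeVLLMState("async-engine")), "/tmp/adapter")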