diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index a05179ed5a6..c9448554880 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -76,6 +76,7 @@ from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, + CheckWeightsReqInput, CloseSessionReqInput, ConfigureLoggingReq, ContinueGenerationReqInput, @@ -956,6 +957,15 @@ async def resume_memory_occupation( return _create_error_response(e) +@app.post("/weights_checker") +async def check_weights(obj: CheckWeightsReqInput, request: Request): + success, message = await _global_state.tokenizer_manager.check_weights(obj, request) + return ORJSONResponse( + {"success": success, "message": message}, + status_code=200 if success else HTTPStatus.BAD_REQUEST, + ) + + @app.api_route("/slow_down", methods=["GET", "POST"]) async def slow_down(obj: SlowDownReqInput, request: Request): """Slow down the system deliberately. Only for testing. Example scenario: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 46647d01f57..4814786ddef 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1311,6 +1311,17 @@ class ResumeMemoryOccupationReqOutput(BaseReq): pass +@dataclass +class CheckWeightsReqInput(BaseReq): + action: str + + +@dataclass +class CheckWeightsReqOutput(BaseReq): + success: bool + message: str + + @dataclass class SlowDownReqInput(BaseReq): forward_sleep_time: Optional[float] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e7651cd9b5b..920e8529205 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -71,6 +71,7 @@ BaseReq, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, + CheckWeightsReqInput, ClearHiCacheReqInput, ClearHiCacheReqOutput, CloseSessionReqInput, @@ -568,6 +569,7 @@ def __init__( (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), + (CheckWeightsReqInput, self.check_weights), (SlowDownReqInput, self.slow_down), (ProfileReq, self.profile), (FreezeGCReq, self.handle_freeze_gc), diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 1d2965a8680..f8ebfc1f4a1 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import traceback from typing import TYPE_CHECKING, Tuple import torch @@ -12,6 +13,8 @@ GPU_MEMORY_TYPE_WEIGHTS, ) from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqOutput, GetWeightsByNameReqInput, @@ -166,6 +169,15 @@ def resume_memory_occupation( return ResumeMemoryOccupationReqOutput() + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): + try: + self.tp_worker.model_runner.check_weights(action=recv_req.action) + return CheckWeightsReqOutput(success=True, message="Success.") + except Exception as e: + logger.warning(f"check_weights see error: {e}") + traceback.print_exc() + return CheckWeightsReqOutput(success=False, message=f"{e}") + def save_remote_model(self: Scheduler, params): url = params["url"] diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 558b78756c6..d6e237c71ca 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -22,6 +22,8 @@ import zmq from sglang.srt.managers.io_struct import ( + CheckWeightsReqInput, + CheckWeightsReqOutput, ClearHiCacheReqInput, ClearHiCacheReqOutput, CloseSessionReqInput, @@ -183,6 +185,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.check_weights_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.slow_down_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -256,6 +261,10 @@ def _get_communicator_dispatcher(self: TokenizerManager): ResumeMemoryOccupationReqOutput, self.resume_memory_occupation_communicator.handle_recv, ), + ( + CheckWeightsReqOutput, + self.check_weights_communicator.handle_recv, + ), ( SlowDownReqOutput, self.slow_down_communicator.handle_recv, @@ -670,6 +679,15 @@ async def resume_memory_occupation( self.auto_create_handle_loop() await self.resume_memory_occupation_communicator(obj) + async def check_weights( + self: TokenizerManager, + obj: CheckWeightsReqInput, + request: Optional[fastapi.Request] = None, + ) -> CheckWeightsReqOutput: + self.auto_create_handle_loop() + results = await self.check_weights_communicator(obj) + return _Communicator.merge_results(results) + async def slow_down( self: TokenizerManager, obj: SlowDownReqInput, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c4a4511ba53..92f9e878836 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -170,6 +170,7 @@ ) from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils.weight_checker import WeightChecker from sglang.srt.weight_sync.tensor_bucket import ( FlattenedTensorBucket, FlattenedTensorMetadata, @@ -328,6 +329,8 @@ def __init__( # CPU offload set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank)) + self._weight_checker = WeightChecker(model_runner=self) + if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"): slow_rank_detector.execute() # Init mindspore running environment when model impl is "mindspore" @@ -2508,6 +2511,9 @@ def save_sharded_model( ) ShardedStateLoader.save_model(self.model, path, pattern, max_size) + def check_weights(self, action: str): + self._weight_checker.handle(action=action) + def update_weights_from_ipc(self, recv_req): """Update weights from IPC for checkpoint-engine integration.""" try: diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py new file mode 100644 index 00000000000..98ae631df61 --- /dev/null +++ b/python/sglang/srt/utils/weight_checker.py @@ -0,0 +1,97 @@ +import logging +from typing import Dict + +import torch + +logger = logging.getLogger(__name__) + + +class WeightChecker: + def __init__(self, model_runner): + self._model_runner = model_runner + self._snapshot_tensors = None + + def handle(self, action: str): + logger.info(f"[WeightChecker] handle action={action}") + if action == "snapshot": + self._snapshot() + elif action == "reset_tensors": + self._reset_tensors() + elif action == "compare": + self._compare() + else: + raise Exception(f"Unsupported {action=}") + + def _snapshot(self): + named_tensors = [ + (name, param.data.detach().cpu()) for name, param in self._model_state() + ] + self._snapshot_tensors = dict(named_tensors) + assert len(self._snapshot_tensors) == len( + named_tensors + ), f"should not have duplicated tensor name" + + def _reset_tensors(self): + for name, param in self._model_state(): + param.copy_(_random_like(param)) + + def _compare(self): + assert self._snapshot_tensors is not None + + _check_tensors( + expect_tensors=self._snapshot_tensors, + actual_tensors=dict(self._model_state()), + ) + + def _model_state(self): + # TODO: support EAGLE etc (e.g. yield from both main model and draft model) + yield from self._model_runner.model.named_parameters() + yield from self._model_runner.model.named_buffers() + + +def _check_tensors( + expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor] +): + from sglang.srt.debug_utils.dumper import get_tensor_info + + assert len(expect_tensors) == len(actual_tensors) + + good_names = [] + error_messages = [] + + for name in expect_tensors: + expect = expect_tensors[name].cuda() + actual = actual_tensors[name].cuda() + + if torch.all(expect == actual): + good_names.append(name) + else: + abs_diff = (actual.float() - expect.float()).abs() + error_messages.append( + f"name={name} " + f"max_abs_err={abs_diff.max()} " + f"mean_abs_err={abs_diff.mean()} " + f"{get_tensor_info(expect)=} " + f"{get_tensor_info(actual)=} " + ) + + logger.info(f"[check_tensors] passed: {good_names}") + if len(error_messages) > 0: + raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages)) + + +def _random_like(t: torch.Tensor): + device = t.device + shape = t.shape + dtype = t.dtype + + if dtype.is_floating_point: + return torch.rand(shape, device=device, dtype=torch.float32).to(dtype) + + if dtype == torch.bool: + return torch.rand(shape, device=device) > 0.5 + + info = torch.iinfo(dtype) + return torch.randint( + low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype + )