From c0fb7e04e06a605aa99179ef041d9827e771211c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:14:18 +0800 Subject: [PATCH 01/59] more --- python/sglang/srt/managers/io_struct.py | 10 ++++++++++ .../srt/managers/scheduler_update_weights_mixin.py | 6 ++++++ python/sglang/srt/model_executor/model_runner.py | 3 +++ 3 files changed, 19 insertions(+) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 01ffab062c..d5ff16b84e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1272,6 +1272,16 @@ class ResumeMemoryOccupationReqOutput(BaseReq): pass +@dataclass +class WeightCheckerReqInput(BaseReq): + tags: str + + +@dataclass +class WeightCheckerReqOutput(BaseReq): + pass + + @dataclass class SlowDownReqInput(BaseReq): forward_sleep_time: Optional[float] diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index fa0d612e2e..03e19dc199 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -30,6 +30,8 @@ UpdateWeightsFromIPCReqOutput, UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, + WeightCheckerReqInput, + WeightCheckerReqOutput, ) if TYPE_CHECKING: @@ -184,6 +186,10 @@ def save_sharded_model(self: Scheduler, params): max_size=params["max_size"], ) + def weight_checker(self: Scheduler, recv_req: WeightCheckerReqInput): + self.tp_worker.model_runner.handle_weight_checker(action=recv_req.action) + return WeightCheckerReqOutput() + def _export_static_state(model): return dict( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b95759af48..b98d1b87ec 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2444,6 +2444,9 @@ def save_sharded_model( ) ShardedStateLoader.save_model(self.model, path, pattern, max_size) + def handle_weight_checker(self, action: str): + TODO + def update_weights_from_ipc(self, recv_req): """Update weights from IPC for checkpoint-engine integration.""" try: From e1f1a6321f5be83044dfd98a65866161dba1888e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:15:17 +0800 Subject: [PATCH 02/59] more --- python/sglang/srt/entrypoints/http_server.py | 11 +++++++++++ python/sglang/srt/managers/io_struct.py | 6 +++--- .../srt/managers/scheduler_update_weights_mixin.py | 12 ++++++------ 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 87197e5b7c..29d8f67f31 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -953,6 +953,17 @@ async def resume_memory_occupation( except Exception as e: return _create_error_response(e) +@app.api_route("/weight_checker", methods=["POST"]) +async def weight_checker( + obj: ResumeMemoryOccupationReqInput, request: Request +): + """Resume GPU memory occupation.""" + try: + await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + @app.api_route("/slow_down", methods=["GET", "POST"]) async def slow_down(obj: SlowDownReqInput, request: Request): diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index d5ff16b84e..2aecee1aa9 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1273,12 +1273,12 @@ class ResumeMemoryOccupationReqOutput(BaseReq): @dataclass -class WeightCheckerReqInput(BaseReq): - tags: str +class CheckWeightReqInput(BaseReq): + action: str @dataclass -class WeightCheckerReqOutput(BaseReq): +class CheckWeightReqOutput(BaseReq): pass diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 03e19dc199..2a8d1234cc 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -30,8 +30,8 @@ UpdateWeightsFromIPCReqOutput, UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, - WeightCheckerReqInput, - WeightCheckerReqOutput, + CheckWeightReqInput, + CheckWeightReqOutput, ) if TYPE_CHECKING: @@ -167,6 +167,10 @@ def resume_memory_occupation( return ResumeMemoryOccupationReqOutput() + def check_weight(self: Scheduler, recv_req: CheckWeightReqInput): + self.tp_worker.model_runner.handle_weight_checker(action=recv_req.action) + return CheckWeightReqOutput() + def save_remote_model(self: Scheduler, params): url = params["url"] @@ -186,10 +190,6 @@ def save_sharded_model(self: Scheduler, params): max_size=params["max_size"], ) - def weight_checker(self: Scheduler, recv_req: WeightCheckerReqInput): - self.tp_worker.model_runner.handle_weight_checker(action=recv_req.action) - return WeightCheckerReqOutput() - def _export_static_state(model): return dict( From 92cd75e5feacafff7f15eeddaafba87f71f0469e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:15:29 +0800 Subject: [PATCH 03/59] more --- python/sglang/srt/managers/scheduler_update_weights_mixin.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 2a8d1234cc..2bb4011afa 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -168,7 +168,7 @@ def resume_memory_occupation( return ResumeMemoryOccupationReqOutput() def check_weight(self: Scheduler, recv_req: CheckWeightReqInput): - self.tp_worker.model_runner.handle_weight_checker(action=recv_req.action) + self.tp_worker.model_runner.check_weight(action=recv_req.action) return CheckWeightReqOutput() def save_remote_model(self: Scheduler, params): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b98d1b87ec..e02cf8b068 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2444,7 +2444,7 @@ def save_sharded_model( ) ShardedStateLoader.save_model(self.model, path, pattern, max_size) - def handle_weight_checker(self, action: str): + def check_weight(self, action: str): TODO def update_weights_from_ipc(self, recv_req): From 2ecbb75241c5855c5a777c0bd457ac4e2cadd049 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:15:56 +0800 Subject: [PATCH 04/59] more --- python/sglang/srt/entrypoints/http_server.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 29d8f67f31..e3dbe69f03 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -90,6 +90,7 @@ ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, + CheckWeightReqInput, SendWeightsToRemoteInstanceReqInput, SeparateReasoningReqInput, SetInternalStateReq, @@ -953,13 +954,10 @@ async def resume_memory_occupation( except Exception as e: return _create_error_response(e) -@app.api_route("/weight_checker", methods=["POST"]) -async def weight_checker( - obj: ResumeMemoryOccupationReqInput, request: Request -): - """Resume GPU memory occupation.""" +@app.api_route("/check_weight", methods=["POST"]) +async def check_weight(obj: CheckWeightReqInput, request: Request): try: - await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + await _global_state.tokenizer_manager.check_weight(obj, request) except Exception as e: return _create_error_response(e) From 616fa04c1f3aebea678a273d87045fbbbe14e665 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:17:47 +0800 Subject: [PATCH 05/59] more --- .../managers/tokenizer_communicator_mixin.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 70129ea8c8..19c61d6487 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -53,6 +53,8 @@ ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, + CheckWeightReqInput, + CheckWeightReqOutput, SendWeightsToRemoteInstanceReqInput, SendWeightsToRemoteInstanceReqOutput, SetInternalStateReq, @@ -183,6 +185,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.check_weight_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.slow_down_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -256,6 +261,10 @@ def _get_communicator_dispatcher(self: TokenizerManager): ResumeMemoryOccupationReqOutput, self.resume_memory_occupation_communicator.handle_recv, ), + ( + CheckWeightReqOutput, + self.check_weight_communicator.handle_recv, + ), ( SlowDownReqOutput, self.slow_down_communicator.handle_recv, @@ -656,6 +665,14 @@ async def resume_memory_occupation( self.auto_create_handle_loop() await self.resume_memory_occupation_communicator(obj) + async def check_weight( + self: TokenizerManager, + obj: CheckWeightReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.check_weight_communicator(obj) + async def slow_down( self: TokenizerManager, obj: SlowDownReqInput, From e40e5d460de2306343bdad2e529e8c9245e41b00 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:17:52 +0800 Subject: [PATCH 06/59] more --- python/sglang/srt/managers/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e1bd793314..94029822c2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -561,6 +561,7 @@ def __init__( (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), + (CheckWeightReqInput, self.check_weight), (SlowDownReqInput, self.slow_down), (ProfileReq, self.profile), (FreezeGCReq, self.handle_freeze_gc), From 3f4f5c76c23e4d87ce4f75419d6e265fdb13c86b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:17:54 +0800 Subject: [PATCH 07/59] fmt --- python/sglang/srt/entrypoints/http_server.py | 4 ++-- python/sglang/srt/managers/scheduler_update_weights_mixin.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index e3dbe69f03..a56fab46e8 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -76,6 +76,7 @@ from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, + CheckWeightReqInput, CloseSessionReqInput, ConfigureLoggingReq, DestroyWeightsUpdateGroupReqInput, @@ -90,7 +91,6 @@ ProfileReqInput, ReleaseMemoryOccupationReqInput, ResumeMemoryOccupationReqInput, - CheckWeightReqInput, SendWeightsToRemoteInstanceReqInput, SeparateReasoningReqInput, SetInternalStateReq, @@ -954,6 +954,7 @@ async def resume_memory_occupation( except Exception as e: return _create_error_response(e) + @app.api_route("/check_weight", methods=["POST"]) async def check_weight(obj: CheckWeightReqInput, request: Request): try: @@ -962,7 +963,6 @@ async def check_weight(obj: CheckWeightReqInput, request: Request): return _create_error_response(e) - @app.api_route("/slow_down", methods=["GET", "POST"]) async def slow_down(obj: SlowDownReqInput, request: Request): """Slow down the system deliberately. Only for testing. Example scenario: diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 2bb4011afa..3f907141d3 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -12,6 +12,8 @@ GPU_MEMORY_TYPE_WEIGHTS, ) from sglang.srt.managers.io_struct import ( + CheckWeightReqInput, + CheckWeightReqOutput, DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqOutput, GetWeightsByNameReqInput, @@ -30,8 +32,6 @@ UpdateWeightsFromIPCReqOutput, UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, - CheckWeightReqInput, - CheckWeightReqOutput, ) if TYPE_CHECKING: From 11f4041643f24bf1147310b58ba78805f75b170a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:19:09 +0800 Subject: [PATCH 08/59] fmt --- python/sglang/srt/managers/tokenizer_communicator_mixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 19c61d6487..570023e49e 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -22,6 +22,8 @@ import zmq from sglang.srt.managers.io_struct import ( + CheckWeightReqInput, + CheckWeightReqOutput, ClearHiCacheReqInput, ClearHiCacheReqOutput, CloseSessionReqInput, @@ -53,8 +55,6 @@ ReleaseMemoryOccupationReqOutput, ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqOutput, - CheckWeightReqInput, - CheckWeightReqOutput, SendWeightsToRemoteInstanceReqInput, SendWeightsToRemoteInstanceReqOutput, SetInternalStateReq, From de67d82c00531d57c58e10c48240014d02a31661 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:19:10 +0800 Subject: [PATCH 09/59] more --- python/sglang/srt/utils/weight_checker.py | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 python/sglang/srt/utils/weight_checker.py diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py new file mode 100644 index 0000000000..c7fca944ac --- /dev/null +++ b/python/sglang/srt/utils/weight_checker.py @@ -0,0 +1,10 @@ +class WeightChecker: + def handle(self, action: str): + if action == "snapshot": + self._snapshot() + elif action == "reset_param": + self._reset_param() + elif action == "compare": + self._compare() + else: + raise Exception(f"Unsupported {action=}") From 3544da292f13b195a13e3406ac10becb33294fcc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:19:23 +0800 Subject: [PATCH 10/59] more --- python/sglang/srt/utils/weight_checker.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index c7fca944ac..553c684843 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -8,3 +8,12 @@ def handle(self, action: str): self._compare() else: raise Exception(f"Unsupported {action=}") + + def _snapshot(self): + TODO + + def _reset_param(self): + TODO + + def _compare(self): + TODO From fca66f1a3c803cfc1ef650135fe1e3f80155ea09 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:21:46 +0800 Subject: [PATCH 11/59] more --- .../sglang/srt/managers/scheduler_update_weights_mixin.py | 2 +- python/sglang/srt/model_executor/model_runner.py | 6 +++++- python/sglang/srt/utils/weight_checker.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 3f907141d3..63a3aed004 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -168,7 +168,7 @@ def resume_memory_occupation( return ResumeMemoryOccupationReqOutput() def check_weight(self: Scheduler, recv_req: CheckWeightReqInput): - self.tp_worker.model_runner.check_weight(action=recv_req.action) + self.tp_worker.check_weight(action=recv_req.action) return CheckWeightReqOutput() def save_remote_model(self: Scheduler, params): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e02cf8b068..d84a58e530 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -172,6 +172,8 @@ FlattenedTensorMetadata, ) +from sglang.srt.utils.weight_checker import WeightChecker + MLA_ATTENTION_BACKENDS = [ "aiter", "flashinfer", @@ -324,6 +326,8 @@ def __init__( # CPU offload set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank)) + self._weight_checker = WeightChecker(model_runner=self) + if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"): slow_rank_detector.execute() # Init mindspore running environment when model impl is "mindspore" @@ -2445,7 +2449,7 @@ def save_sharded_model( ShardedStateLoader.save_model(self.model, path, pattern, max_size) def check_weight(self, action: str): - TODO + self._weight_checker.handle(action=action) def update_weights_from_ipc(self, recv_req): """Update weights from IPC for checkpoint-engine integration.""" diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 553c684843..1d8baa09b9 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -1,4 +1,7 @@ class WeightChecker: + def __init__(self, model_runner): + self._model_runner = model_runner + def handle(self, action: str): if action == "snapshot": self._snapshot() From 4ded2c63a6685172c9b76365bd4b1826908e0b64 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:22:39 +0800 Subject: [PATCH 12/59] morr --- python/sglang/srt/utils/weight_checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 1d8baa09b9..73f1d0bc7e 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -20,3 +20,7 @@ def _reset_param(self): def _compare(self): TODO + + def _model_state(self): + # TODO: support EAGLE etc (e.g. yield from both main model and draft model) + yield from self._model_runner.model.named_parameters() From 715f84953b8efd771320b01c6ebdb6a3fe671246 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:24:58 +0800 Subject: [PATCH 13/59] more --- python/sglang/srt/utils/weight_checker.py | 76 ++++++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 73f1d0bc7e..eaada4f817 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -16,7 +16,8 @@ def _snapshot(self): TODO def _reset_param(self): - TODO + for name, param in self._model_state(): + TODO def _compare(self): TODO @@ -24,3 +25,76 @@ def _compare(self): def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) yield from self._model_runner.model.named_parameters() + +import torch + +def fill_tensor_with_random(t: torch.Tensor, *, low=None, high=None, dist='uniform'): + device = t.device + shape = t.shape + dtype = t.dtype + + if dtype.is_floating_point: + gen_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype + if dist == 'normal': + tmp = torch.randn(shape, device=device, dtype=gen_dtype) + else: + tmp = torch.rand(shape, device=device, dtype=gen_dtype) + t.copy_(tmp.to(dtype, copy=False)) + return + + # Complex types + if dtype.is_complex: + # choose real dtype for components + comp_dtype = torch.float32 if dtype == torch.complex64 else torch.float64 + if dist == 'normal': + real = torch.randn(shape, device=device, dtype=comp_dtype) + imag = torch.randn(shape, device=device, dtype=comp_dtype) + else: + real = torch.rand(shape, device=device, dtype=comp_dtype) + imag = torch.rand(shape, device=device, dtype=comp_dtype) + comp = torch.complex(real, imag).to(dtype) + t.copy_(comp) + return + + # Bool + if dtype == torch.bool: + # Bernoulli p=0.5 + mask = torch.rand(shape, device=device) > 0.5 + t.copy_(mask) + return + + # Integer types (signed/unsigned) + # Use torch.iinfo to get dtype range; pick sensible subrange if full-range is too large for randint. + if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + info = torch.iinfo(dtype) + minv = int(info.min) + maxv = int(info.max) + + if low is None: + low = minv + if high is None: + # torch.randint's high is exclusive; default to maxv+1 if safe + # but if range is gigantic (e.g., full uint64-like), clamp to a safe 32-bit window + try: + range_size = maxv - minv + 1 + except OverflowError: + range_size = 1 << 63 # fallback big + if range_size <= (1 << 31): + high = maxv + 1 + else: + # choose a centered 32-bit window to avoid overflowing torch.randint's internal limits + low = max(minv, -2**31) + high = min(maxv, 2**31 - 1) + 1 + + # torch.randint requires low < high + if not (low < high): + raise ValueError(f"invalid integer bounds: low={low}, high={high}") + + # produce as int64 then cast if necessary (torch.randint supports dtype arg for integer types, + # but generating in int64 then cast is robust) + rand = torch.randint(low=low, high=high, size=shape, device=device, dtype=torch.int64) + # cast to target dtype and copy + t.copy_(rand.to(dtype)) + return + + raise TypeError(f"unsupported dtype: {dtype}") From 5ed352d05fb220a1d1a8da2224c5853ba927d6e9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:25:45 +0800 Subject: [PATCH 14/59] more --- python/sglang/srt/utils/weight_checker.py | 61 +++++++++-------------- 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index eaada4f817..c5b4ca67ff 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -26,75 +26,60 @@ def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) yield from self._model_runner.model.named_parameters() -import torch - -def fill_tensor_with_random(t: torch.Tensor, *, low=None, high=None, dist='uniform'): +def _random_fill_tensor(t: torch.Tensor, *, low=None, high=None): + """ + Fill tensor `t` in-place with uniform random values. + - Floating: U(0,1) + - Complex: real/imag both U(0,1) + - Bool: Bernoulli with p=0.5 + - Integer: uniform integer in [low, high) + """ device = t.device shape = t.shape dtype = t.dtype + # Floating types (float32/64/16 + bfloat16) if dtype.is_floating_point: gen_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype - if dist == 'normal': - tmp = torch.randn(shape, device=device, dtype=gen_dtype) - else: - tmp = torch.rand(shape, device=device, dtype=gen_dtype) + tmp = torch.rand(shape, device=device, dtype=gen_dtype) t.copy_(tmp.to(dtype, copy=False)) return # Complex types if dtype.is_complex: - # choose real dtype for components comp_dtype = torch.float32 if dtype == torch.complex64 else torch.float64 - if dist == 'normal': - real = torch.randn(shape, device=device, dtype=comp_dtype) - imag = torch.randn(shape, device=device, dtype=comp_dtype) - else: - real = torch.rand(shape, device=device, dtype=comp_dtype) - imag = torch.rand(shape, device=device, dtype=comp_dtype) - comp = torch.complex(real, imag).to(dtype) - t.copy_(comp) + real = torch.rand(shape, device=device, dtype=comp_dtype) + imag = torch.rand(shape, device=device, dtype=comp_dtype) + t.copy_(torch.complex(real, imag).to(dtype)) return # Bool if dtype == torch.bool: - # Bernoulli p=0.5 mask = torch.rand(shape, device=device) > 0.5 t.copy_(mask) return - # Integer types (signed/unsigned) - # Use torch.iinfo to get dtype range; pick sensible subrange if full-range is too large for randint. + # Integer types if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): info = torch.iinfo(dtype) - minv = int(info.min) - maxv = int(info.max) - - if low is None: - low = minv + # Default integer range: full dtype range + if low is None: low = int(info.min) if high is None: - # torch.randint's high is exclusive; default to maxv+1 if safe - # but if range is gigantic (e.g., full uint64-like), clamp to a safe 32-bit window - try: - range_size = maxv - minv + 1 - except OverflowError: - range_size = 1 << 63 # fallback big - if range_size <= (1 << 31): + # torch.randint high is exclusive; make maxv+1 if safe + maxv = int(info.max) + if maxv - low + 1 <= (1 << 31): high = maxv + 1 else: - # choose a centered 32-bit window to avoid overflowing torch.randint's internal limits - low = max(minv, -2**31) - high = min(maxv, 2**31 - 1) + 1 + # huge range fallback to a safe 32-bit window + low = max(low, -2**31) + high = 2**31 - 1 - # torch.randint requires low < high if not (low < high): raise ValueError(f"invalid integer bounds: low={low}, high={high}") - # produce as int64 then cast if necessary (torch.randint supports dtype arg for integer types, - # but generating in int64 then cast is robust) rand = torch.randint(low=low, high=high, size=shape, device=device, dtype=torch.int64) - # cast to target dtype and copy t.copy_(rand.to(dtype)) return raise TypeError(f"unsupported dtype: {dtype}") + From 5f9c887958f8be79ab110b79cabbbb7dcc51f581 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:25:58 +0800 Subject: [PATCH 15/59] more --- python/sglang/srt/utils/weight_checker.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index c5b4ca67ff..8f37de8c55 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -1,3 +1,5 @@ +import torch + class WeightChecker: def __init__(self, model_runner): self._model_runner = model_runner @@ -27,22 +29,14 @@ def _model_state(self): yield from self._model_runner.model.named_parameters() def _random_fill_tensor(t: torch.Tensor, *, low=None, high=None): - """ - Fill tensor `t` in-place with uniform random values. - - Floating: U(0,1) - - Complex: real/imag both U(0,1) - - Bool: Bernoulli with p=0.5 - - Integer: uniform integer in [low, high) - """ device = t.device shape = t.shape dtype = t.dtype - # Floating types (float32/64/16 + bfloat16) if dtype.is_floating_point: gen_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype tmp = torch.rand(shape, device=device, dtype=gen_dtype) - t.copy_(tmp.to(dtype, copy=False)) + t.copy_(tmp.to(dtype)) return # Complex types From 78ed1351a684fd9b3ffcd07b5e4ad774e0750036 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:26:14 +0800 Subject: [PATCH 16/59] more --- python/sglang/srt/utils/weight_checker.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 8f37de8c55..0bcad0180d 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -39,18 +39,8 @@ def _random_fill_tensor(t: torch.Tensor, *, low=None, high=None): t.copy_(tmp.to(dtype)) return - # Complex types - if dtype.is_complex: - comp_dtype = torch.float32 if dtype == torch.complex64 else torch.float64 - real = torch.rand(shape, device=device, dtype=comp_dtype) - imag = torch.rand(shape, device=device, dtype=comp_dtype) - t.copy_(torch.complex(real, imag).to(dtype)) - return - - # Bool if dtype == torch.bool: - mask = torch.rand(shape, device=device) > 0.5 - t.copy_(mask) + t.copy_(torch.rand(shape, device=device) > 0.5) return # Integer types From 7dc4f5b6cf6ad7d706a99f498e7e5a2767a722d6 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:26:40 +0800 Subject: [PATCH 17/59] more --- python/sglang/srt/utils/weight_checker.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 0bcad0180d..81f7c2f19b 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -28,7 +28,7 @@ def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) yield from self._model_runner.model.named_parameters() -def _random_fill_tensor(t: torch.Tensor, *, low=None, high=None): +def _random_like(t: torch.Tensor, *, low=None, high=None): device = t.device shape = t.shape dtype = t.dtype @@ -36,18 +36,17 @@ def _random_fill_tensor(t: torch.Tensor, *, low=None, high=None): if dtype.is_floating_point: gen_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype tmp = torch.rand(shape, device=device, dtype=gen_dtype) - t.copy_(tmp.to(dtype)) - return + return tmp.to(dtype) if dtype == torch.bool: - t.copy_(torch.rand(shape, device=device) > 0.5) - return + return torch.rand(shape, device=device) > 0.5 # Integer types if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): info = torch.iinfo(dtype) # Default integer range: full dtype range - if low is None: low = int(info.min) + if low is None: + low = int(info.min) if high is None: # torch.randint high is exclusive; make maxv+1 if safe maxv = int(info.max) @@ -61,9 +60,7 @@ def _random_fill_tensor(t: torch.Tensor, *, low=None, high=None): if not (low < high): raise ValueError(f"invalid integer bounds: low={low}, high={high}") - rand = torch.randint(low=low, high=high, size=shape, device=device, dtype=torch.int64) - t.copy_(rand.to(dtype)) - return + return torch.randint(low=low, high=high, size=shape, device=device, dtype=torch.int64) raise TypeError(f"unsupported dtype: {dtype}") From f7e95d164e7f0911cf902c50f691c275445ba60c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:27:11 +0800 Subject: [PATCH 18/59] more --- python/sglang/srt/utils/weight_checker.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 81f7c2f19b..3104568270 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -44,22 +44,8 @@ def _random_like(t: torch.Tensor, *, low=None, high=None): # Integer types if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): info = torch.iinfo(dtype) - # Default integer range: full dtype range - if low is None: - low = int(info.min) - if high is None: - # torch.randint high is exclusive; make maxv+1 if safe - maxv = int(info.max) - if maxv - low + 1 <= (1 << 31): - high = maxv + 1 - else: - # huge range fallback to a safe 32-bit window - low = max(low, -2**31) - high = 2**31 - 1 - - if not (low < high): - raise ValueError(f"invalid integer bounds: low={low}, high={high}") - + low = int(info.min) + high = int(info.max) return torch.randint(low=low, high=high, size=shape, device=device, dtype=torch.int64) raise TypeError(f"unsupported dtype: {dtype}") From 14b821d8f8d7682f1417773983995c9d10a3c75e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:27:25 +0800 Subject: [PATCH 19/59] more --- python/sglang/srt/utils/weight_checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 3104568270..2b4f66a454 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -19,7 +19,7 @@ def _snapshot(self): def _reset_param(self): for name, param in self._model_state(): - TODO + param.copy_(_random_like(param)) def _compare(self): TODO @@ -28,7 +28,7 @@ def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) yield from self._model_runner.model.named_parameters() -def _random_like(t: torch.Tensor, *, low=None, high=None): +def _random_like(t: torch.Tensor): device = t.device shape = t.shape dtype = t.dtype From 9791468f1ecc7f7170fcb312b20b4d35b89395c8 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:27:52 +0800 Subject: [PATCH 20/59] more --- python/sglang/srt/utils/weight_checker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 2b4f66a454..412da4dd1d 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -1,10 +1,15 @@ +import logging + import torch +logger = logging.getLogger(__name__) + class WeightChecker: def __init__(self, model_runner): self._model_runner = model_runner def handle(self, action: str): + logger.info(f"[WeightChecker] handle action={action}") if action == "snapshot": self._snapshot() elif action == "reset_param": From 720956f96e25d8af5b716baa7adcf7c4ca3d7314 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:28:25 +0800 Subject: [PATCH 21/59] more --- python/sglang/srt/utils/weight_checker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 412da4dd1d..5d56123da2 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -7,6 +7,7 @@ class WeightChecker: def __init__(self, model_runner): self._model_runner = model_runner + self._snapshot_tensors = None def handle(self, action: str): logger.info(f"[WeightChecker] handle action={action}") @@ -20,7 +21,10 @@ def handle(self, action: str): raise Exception(f"Unsupported {action=}") def _snapshot(self): - TODO + self._snapshot_tensors = [ + (name, param.data.detach().cpu()) + for name, param in self._model_state() + ] def _reset_param(self): for name, param in self._model_state(): From bc6830d224b3ae0e61f536c52439532c7e86926d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:28:42 +0800 Subject: [PATCH 22/59] more --- python/sglang/srt/utils/weight_checker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 5d56123da2..e35989e07d 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -31,6 +31,7 @@ def _reset_param(self): param.copy_(_random_like(param)) def _compare(self): + assert self._snapshot_tensors is not None TODO def _model_state(self): From e6aed98c7ede3dc3396d204c4d361d9b43f51355 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:29:11 +0800 Subject: [PATCH 23/59] fmt --- python/sglang/srt/model_executor/model_runner.py | 3 +-- python/sglang/srt/utils/weight_checker.py | 10 ++++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d84a58e530..aeb7b2bcb9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -167,13 +167,12 @@ ) from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils.weight_checker import WeightChecker from sglang.srt.weight_sync.tensor_bucket import ( FlattenedTensorBucket, FlattenedTensorMetadata, ) -from sglang.srt.utils.weight_checker import WeightChecker - MLA_ATTENTION_BACKENDS = [ "aiter", "flashinfer", diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index e35989e07d..64d3ba3978 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -4,6 +4,7 @@ logger = logging.getLogger(__name__) + class WeightChecker: def __init__(self, model_runner): self._model_runner = model_runner @@ -22,8 +23,7 @@ def handle(self, action: str): def _snapshot(self): self._snapshot_tensors = [ - (name, param.data.detach().cpu()) - for name, param in self._model_state() + (name, param.data.detach().cpu()) for name, param in self._model_state() ] def _reset_param(self): @@ -38,6 +38,7 @@ def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) yield from self._model_runner.model.named_parameters() + def _random_like(t: torch.Tensor): device = t.device shape = t.shape @@ -56,7 +57,8 @@ def _random_like(t: torch.Tensor): info = torch.iinfo(dtype) low = int(info.min) high = int(info.max) - return torch.randint(low=low, high=high, size=shape, device=device, dtype=torch.int64) + return torch.randint( + low=low, high=high, size=shape, device=device, dtype=torch.int64 + ) raise TypeError(f"unsupported dtype: {dtype}") - From a374cb87b0c985e09e625dcd873893f832f630ae Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:29:54 +0800 Subject: [PATCH 24/59] more --- python/sglang/srt/utils/weight_checker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 64d3ba3978..27baafbf4f 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -22,9 +22,9 @@ def handle(self, action: str): raise Exception(f"Unsupported {action=}") def _snapshot(self): - self._snapshot_tensors = [ - (name, param.data.detach().cpu()) for name, param in self._model_state() - ] + named_tensors = [(name, param.data.detach().cpu()) for name, param in self._model_state()] + self._snapshot_tensors = dict(named_tensors) + assert len(self._snapshot_tensors) == len(named_tensors), f"should not have duplicated tensor name" def _reset_param(self): for name, param in self._model_state(): From 2490c9d5c939ac17d8644c78933fff2db79df0fd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:32:35 +0800 Subject: [PATCH 25/59] more --- python/sglang/srt/utils/weight_checker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 27baafbf4f..30ec2681c7 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -32,7 +32,14 @@ def _reset_param(self): def _compare(self): assert self._snapshot_tensors is not None - TODO + + curr_tensors = dict(self._model_state()) + assert len(curr_tensors) == len(self._snapshot_tensors) + + for name in curr_tensors: + curr_tensor = curr_tensors[name] + snapshot_tensor = self._snapshot_tensors[name] + TODO def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) From 721e4e8328787e18818a58d818016e26483429f0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:32:52 +0800 Subject: [PATCH 26/59] more --- python/sglang/srt/utils/weight_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 30ec2681c7..4f6c15616e 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -38,7 +38,7 @@ def _compare(self): for name in curr_tensors: curr_tensor = curr_tensors[name] - snapshot_tensor = self._snapshot_tensors[name] + snapshot_tensor = self._snapshot_tensors[name].cuda() TODO def _model_state(self): From c4015e597c2a6fb029487d932804c4a7d36aa505 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:33:14 +0800 Subject: [PATCH 27/59] more --- python/sglang/srt/utils/weight_checker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 4f6c15616e..7b2dd38c84 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -39,7 +39,8 @@ def _compare(self): for name in curr_tensors: curr_tensor = curr_tensors[name] snapshot_tensor = self._snapshot_tensors[name].cuda() - TODO + if not torch.all(curr_tensor == snapshot_tensor): + TODO def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) From 8990e09cd8d73fa9a659364d18a7069c313191b4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:33:35 +0800 Subject: [PATCH 28/59] more --- python/sglang/srt/utils/weight_checker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 7b2dd38c84..7d6d59fcb4 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -47,6 +47,10 @@ def _model_state(self): yield from self._model_runner.model.named_parameters() +def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]): + TODO + + def _random_like(t: torch.Tensor): device = t.device shape = t.shape From 33fd97162b34edfdbdafb23231ec790ab54221f3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:34:15 +0800 Subject: [PATCH 29/59] more --- python/sglang/srt/utils/weight_checker.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 7d6d59fcb4..ca41e2d2be 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -1,4 +1,5 @@ import logging +from typing import Dict import torch @@ -33,14 +34,10 @@ def _reset_param(self): def _compare(self): assert self._snapshot_tensors is not None - curr_tensors = dict(self._model_state()) - assert len(curr_tensors) == len(self._snapshot_tensors) - - for name in curr_tensors: - curr_tensor = curr_tensors[name] - snapshot_tensor = self._snapshot_tensors[name].cuda() - if not torch.all(curr_tensor == snapshot_tensor): - TODO + _check_tensors( + expect_tensors=self._snapshot_tensors, + actual_tensors=dict(self._model_state()), + ) def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) @@ -48,7 +45,13 @@ def _model_state(self): def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]): - TODO + assert len(expect_tensors) == len(actual_tensors) + + for name in expect_tensors: + expect = expect_tensors[name] + actual = actual_tensors[name] + if not torch.all(curr_tensor == snapshot_tensor): + TODO def _random_like(t: torch.Tensor): From c3c7932796b22c6791dcce52b8092a5e40b4fabe Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:34:28 +0800 Subject: [PATCH 30/59] more --- python/sglang/srt/utils/weight_checker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index ca41e2d2be..5d31b88b68 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -50,8 +50,11 @@ def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict for name in expect_tensors: expect = expect_tensors[name] actual = actual_tensors[name] - if not torch.all(curr_tensor == snapshot_tensor): - TODO + + if torch.all(expect == actual): + continue + + TODO def _random_like(t: torch.Tensor): From c599ea696e819a88bd55d5ea02961e1eecb79c44 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:35:13 +0800 Subject: [PATCH 31/59] more --- python/sglang/srt/utils/weight_checker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 5d31b88b68..9455879e63 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -47,6 +47,8 @@ def _model_state(self): def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]): assert len(expect_tensors) == len(actual_tensors) + error_messages = [] + for name in expect_tensors: expect = expect_tensors[name] actual = actual_tensors[name] @@ -54,7 +56,9 @@ def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict if torch.all(expect == actual): continue - TODO + error_messages.append() + + raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages)) def _random_like(t: torch.Tensor): From 8555a5de148bef4f0186ebdf6e97728b9faf574b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:36:51 +0800 Subject: [PATCH 32/59] more --- python/sglang/srt/utils/weight_checker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 9455879e63..6e1305aa72 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -56,7 +56,10 @@ def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict if torch.all(expect == actual): continue - error_messages.append() + error_messages.append( + f"name={name} " + f"{TODO}" + ) raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages)) From 94fae9bc8ce42f25dd270895c1b7a9ec46212000 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:39:06 +0800 Subject: [PATCH 33/59] more --- python/sglang/srt/utils/weight_checker.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 6e1305aa72..d9c5cfa7c7 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -3,6 +3,7 @@ import torch + logger = logging.getLogger(__name__) @@ -45,6 +46,8 @@ def _model_state(self): def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]): + from sglang.srt.debug_utils.dumper import get_tensor_info + assert len(expect_tensors) == len(actual_tensors) error_messages = [] @@ -58,7 +61,10 @@ def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict error_messages.append( f"name={name} " - f"{TODO}" + f"max_abs_err={(actual - expect).abs().max()} " + f"mean_abs_err={(actual - expect).abs().mean()} " + f"{get_tensor_info(expect)=} " + f"{get_tensor_info(actual)=} " ) raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages)) From 7a43c3d52d585626f0141c00715ae59568146951 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:39:47 +0800 Subject: [PATCH 34/59] fmt --- python/sglang/srt/utils/weight_checker.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index d9c5cfa7c7..931b872234 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -3,7 +3,6 @@ import torch - logger = logging.getLogger(__name__) @@ -24,9 +23,13 @@ def handle(self, action: str): raise Exception(f"Unsupported {action=}") def _snapshot(self): - named_tensors = [(name, param.data.detach().cpu()) for name, param in self._model_state()] + named_tensors = [ + (name, param.data.detach().cpu()) for name, param in self._model_state() + ] self._snapshot_tensors = dict(named_tensors) - assert len(self._snapshot_tensors) == len(named_tensors), f"should not have duplicated tensor name" + assert len(self._snapshot_tensors) == len( + named_tensors + ), f"should not have duplicated tensor name" def _reset_param(self): for name, param in self._model_state(): @@ -45,7 +48,9 @@ def _model_state(self): yield from self._model_runner.model.named_parameters() -def _check_tensors(expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]): +def _check_tensors( + expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor] +): from sglang.srt.debug_utils.dumper import get_tensor_info assert len(expect_tensors) == len(actual_tensors) From 2e7e157e721058fcffbd560630e423304d541012 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:44:14 +0800 Subject: [PATCH 35/59] more --- python/sglang/srt/entrypoints/http_server.py | 8 ++++---- python/sglang/srt/managers/io_struct.py | 4 ++-- python/sglang/srt/managers/scheduler.py | 4 +++- .../srt/managers/scheduler_update_weights_mixin.py | 10 +++++----- .../srt/managers/tokenizer_communicator_mixin.py | 14 +++++++------- python/sglang/srt/model_executor/model_runner.py | 2 +- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index a56fab46e8..03625770e7 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -76,7 +76,7 @@ from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.managers.io_struct import ( AbortReq, - CheckWeightReqInput, + CheckWeightsReqInput, CloseSessionReqInput, ConfigureLoggingReq, DestroyWeightsUpdateGroupReqInput, @@ -955,10 +955,10 @@ async def resume_memory_occupation( return _create_error_response(e) -@app.api_route("/check_weight", methods=["POST"]) -async def check_weight(obj: CheckWeightReqInput, request: Request): +@app.api_route("/check_weights", methods=["POST"]) +async def check_weights(obj: CheckWeightsReqInput, request: Request): try: - await _global_state.tokenizer_manager.check_weight(obj, request) + await _global_state.tokenizer_manager.check_weights(obj, request) except Exception as e: return _create_error_response(e) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 2aecee1aa9..4466d22f9c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1273,12 +1273,12 @@ class ResumeMemoryOccupationReqOutput(BaseReq): @dataclass -class CheckWeightReqInput(BaseReq): +class CheckWeightsReqInput(BaseReq): action: str @dataclass -class CheckWeightReqOutput(BaseReq): +class CheckWeightsReqOutput(BaseReq): pass diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 94029822c2..d0f7f469cb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -112,6 +112,7 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, + CheckWeightsReqInput, ) from sglang.srt.managers.mm_utils import init_mm_embedding_cache from sglang.srt.managers.overlap_utils import FutureMap @@ -194,6 +195,7 @@ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import TypeBasedDispatcher, get_exception_traceback + logger = logging.getLogger(__name__) # Test retract decode for debugging purposes @@ -561,7 +563,7 @@ def __init__( (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), - (CheckWeightReqInput, self.check_weight), + (CheckWeightsReqInput, self.check_weights), (SlowDownReqInput, self.slow_down), (ProfileReq, self.profile), (FreezeGCReq, self.handle_freeze_gc), diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 63a3aed004..f8d21c409b 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -12,8 +12,8 @@ GPU_MEMORY_TYPE_WEIGHTS, ) from sglang.srt.managers.io_struct import ( - CheckWeightReqInput, - CheckWeightReqOutput, + CheckWeightsReqInput, + CheckWeightsReqOutput, DestroyWeightsUpdateGroupReqInput, DestroyWeightsUpdateGroupReqOutput, GetWeightsByNameReqInput, @@ -167,9 +167,9 @@ def resume_memory_occupation( return ResumeMemoryOccupationReqOutput() - def check_weight(self: Scheduler, recv_req: CheckWeightReqInput): - self.tp_worker.check_weight(action=recv_req.action) - return CheckWeightReqOutput() + def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): + self.tp_worker.check_weights(action=recv_req.action) + return CheckWeightsReqOutput() def save_remote_model(self: Scheduler, params): url = params["url"] diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 570023e49e..c632a4f797 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -22,8 +22,8 @@ import zmq from sglang.srt.managers.io_struct import ( - CheckWeightReqInput, - CheckWeightReqOutput, + CheckWeightsReqInput, + CheckWeightsReqOutput, ClearHiCacheReqInput, ClearHiCacheReqOutput, CloseSessionReqInput, @@ -185,7 +185,7 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs): self.resume_memory_occupation_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) - self.check_weight_communicator = _Communicator( + self.check_weights_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) self.slow_down_communicator = _Communicator( @@ -262,8 +262,8 @@ def _get_communicator_dispatcher(self: TokenizerManager): self.resume_memory_occupation_communicator.handle_recv, ), ( - CheckWeightReqOutput, - self.check_weight_communicator.handle_recv, + CheckWeightsReqOutput, + self.check_weights_communicator.handle_recv, ), ( SlowDownReqOutput, @@ -665,9 +665,9 @@ async def resume_memory_occupation( self.auto_create_handle_loop() await self.resume_memory_occupation_communicator(obj) - async def check_weight( + async def check_weights( self: TokenizerManager, - obj: CheckWeightReqInput, + obj: CheckWeightsReqInput, request: Optional[fastapi.Request] = None, ): self.auto_create_handle_loop() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index aeb7b2bcb9..bff5abc373 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2447,7 +2447,7 @@ def save_sharded_model( ) ShardedStateLoader.save_model(self.model, path, pattern, max_size) - def check_weight(self, action: str): + def check_weights(self, action: str): self._weight_checker.handle(action=action) def update_weights_from_ipc(self, recv_req): From 7256e33a578a69c318ef42a6fda8854f94774713 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:44:47 +0800 Subject: [PATCH 36/59] fmt --- python/sglang/srt/managers/scheduler.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d0f7f469cb..6e82b6850d 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -70,6 +70,7 @@ BaseReq, BatchTokenizedEmbeddingReqInput, BatchTokenizedGenerateReqInput, + CheckWeightsReqInput, ClearHiCacheReqInput, ClearHiCacheReqOutput, CloseSessionReqInput, @@ -112,7 +113,6 @@ UpdateWeightsFromDistributedReqInput, UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, - CheckWeightsReqInput, ) from sglang.srt.managers.mm_utils import init_mm_embedding_cache from sglang.srt.managers.overlap_utils import FutureMap @@ -195,7 +195,6 @@ from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.utils import TypeBasedDispatcher, get_exception_traceback - logger = logging.getLogger(__name__) # Test retract decode for debugging purposes From 299df306290d5b304787be329b2668e6b7586b1c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 09:48:51 +0800 Subject: [PATCH 37/59] more --- python/sglang/srt/utils/weight_checker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 931b872234..af145d0963 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -15,8 +15,8 @@ def handle(self, action: str): logger.info(f"[WeightChecker] handle action={action}") if action == "snapshot": self._snapshot() - elif action == "reset_param": - self._reset_param() + elif action == "reset_tensors": + self._reset_tensors() elif action == "compare": self._compare() else: @@ -31,7 +31,7 @@ def _snapshot(self): named_tensors ), f"should not have duplicated tensor name" - def _reset_param(self): + def _reset_tensors(self): for name, param in self._model_state(): param.copy_(_random_like(param)) From df17ccfd4c3e7f8f3e7cf315f7ba79a1f2d4ebb9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 17:47:33 +0800 Subject: [PATCH 38/59] fix moe check --- python/sglang/srt/model_executor/model_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bff5abc373..4a1d8d3107 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -596,8 +596,11 @@ def check_quantized_moe_compatibility(self): moe_tp_size = self.tp_size // self.moe_ep_size moe_intermediate_size = ( - self.model_config.hf_text_config.moe_intermediate_size + getattr(self.model_config.hf_text_config, "moe_intermediate_size", None) ) + if moe_intermediate_size is None: + return + if moe_intermediate_size % moe_tp_size != 0: raise ValueError( f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})." From 399073b064f8ea5906526f7085004c343d1c8b31 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:32:37 +0800 Subject: [PATCH 39/59] more --- python/sglang/srt/utils/weight_checker.py | 26 +++++++++++++---------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index af145d0963..1eae5c100e 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -55,6 +55,7 @@ def _check_tensors( assert len(expect_tensors) == len(actual_tensors) + good_names = [] error_messages = [] for name in expect_tensors: @@ -62,17 +63,20 @@ def _check_tensors( actual = actual_tensors[name] if torch.all(expect == actual): - continue - - error_messages.append( - f"name={name} " - f"max_abs_err={(actual - expect).abs().max()} " - f"mean_abs_err={(actual - expect).abs().mean()} " - f"{get_tensor_info(expect)=} " - f"{get_tensor_info(actual)=} " - ) - - raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages)) + good_names.append(name) + else: + error_messages.append( + f"name={name} " + f"max_abs_err={(actual - expect).abs().max()} " + f"mean_abs_err={(actual - expect).abs().mean()} " + f"{get_tensor_info(expect)=} " + f"{get_tensor_info(actual)=} " + ) + + logger.info(f"[check_tensors] passed: {good_names}") + if len(error_messages) > 0: + msg = f"check tensor equality failed:\n" + "\n".join(error_messages) + raise Exception(msg) def _random_like(t: torch.Tensor): From 2dbafd462ff8f33efb65503977bdee4c748184ed Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:49:55 +0800 Subject: [PATCH 40/59] another fix --- python/sglang/srt/model_executor/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4a1d8d3107..a19c2e63b1 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -595,8 +595,8 @@ def check_quantized_moe_compatibility(self): ) moe_tp_size = self.tp_size // self.moe_ep_size - moe_intermediate_size = ( - getattr(self.model_config.hf_text_config, "moe_intermediate_size", None) + moe_intermediate_size = getattr( + self.model_config.hf_text_config, "moe_intermediate_size", None ) if moe_intermediate_size is None: return From 648b118a5baa0e47d9d0049f84e4f0464a4ac3fa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:49:57 +0800 Subject: [PATCH 41/59] more --- python/sglang/srt/entrypoints/http_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 03625770e7..0e2d69a043 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -955,7 +955,7 @@ async def resume_memory_occupation( return _create_error_response(e) -@app.api_route("/check_weights", methods=["POST"]) +@app.post("/check_weights") async def check_weights(obj: CheckWeightsReqInput, request: Request): try: await _global_state.tokenizer_manager.check_weights(obj, request) From 029c377a9d6cd53240614433e0e82e151cf778ce Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:51:26 +0800 Subject: [PATCH 42/59] more --- python/sglang/srt/managers/tokenizer_communicator_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index c632a4f797..3e15d23b38 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -671,7 +671,7 @@ async def check_weights( request: Optional[fastapi.Request] = None, ): self.auto_create_handle_loop() - await self.check_weight_communicator(obj) + await self.check_weights_communicator(obj) async def slow_down( self: TokenizerManager, From ed02bdbd99627339be78645b5bf8d01bde644059 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:54:25 +0800 Subject: [PATCH 43/59] more --- python/sglang/srt/entrypoints/http_server.py | 7 +++---- python/sglang/srt/managers/io_struct.py | 3 ++- python/sglang/srt/managers/tokenizer_communicator_mixin.py | 5 +++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 0e2d69a043..bf5d429d6a 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -957,10 +957,9 @@ async def resume_memory_occupation( @app.post("/check_weights") async def check_weights(obj: CheckWeightsReqInput, request: Request): - try: - await _global_state.tokenizer_manager.check_weights(obj, request) - except Exception as e: - return _create_error_response(e) + resp = await _global_state.tokenizer_manager.check_weights(obj, request) + return ORJSONResponse({"success": resp.success, "message": resp.message}, + status_code=200 if resp.success else HTTPStatus.BAD_REQUEST) @app.api_route("/slow_down", methods=["GET", "POST"]) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 4466d22f9c..ba061b484c 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1279,7 +1279,8 @@ class CheckWeightsReqInput(BaseReq): @dataclass class CheckWeightsReqOutput(BaseReq): - pass + success: bool + message: str @dataclass diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index 3e15d23b38..6a4d67c02e 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -669,9 +669,10 @@ async def check_weights( self: TokenizerManager, obj: CheckWeightsReqInput, request: Optional[fastapi.Request] = None, - ): + ) -> CheckWeightsReqOutput: self.auto_create_handle_loop() - await self.check_weights_communicator(obj) + results = await self.check_weights_communicator(obj) + return _Communicator.merge_results(results) async def slow_down( self: TokenizerManager, From 3ec8f77505be2c35650fb55d1bcbfb77b52bb2ea Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:55:07 +0800 Subject: [PATCH 44/59] more --- .../sglang/srt/managers/scheduler_update_weights_mixin.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index f8d21c409b..a211b8bfbf 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -168,8 +168,11 @@ def resume_memory_occupation( return ResumeMemoryOccupationReqOutput() def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): - self.tp_worker.check_weights(action=recv_req.action) - return CheckWeightsReqOutput() + try: + self.tp_worker.check_weights(action=recv_req.action) + return CheckWeightsReqOutput(success=True, message="") + except Exception as e: + return CheckWeightsReqOutput(success=False, message=f"{e}") def save_remote_model(self: Scheduler, params): url = params["url"] From 8b25bca9a33c4b032556c1e5a7b23fc19a6d838b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:55:41 +0800 Subject: [PATCH 45/59] fmt --- python/sglang/srt/entrypoints/http_server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index bf5d429d6a..fe8ff720e7 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -958,8 +958,10 @@ async def resume_memory_occupation( @app.post("/check_weights") async def check_weights(obj: CheckWeightsReqInput, request: Request): resp = await _global_state.tokenizer_manager.check_weights(obj, request) - return ORJSONResponse({"success": resp.success, "message": resp.message}, - status_code=200 if resp.success else HTTPStatus.BAD_REQUEST) + return ORJSONResponse( + {"success": resp.success, "message": resp.message}, + status_code=200 if resp.success else HTTPStatus.BAD_REQUEST, + ) @app.api_route("/slow_down", methods=["GET", "POST"]) From aa2cdf2befc28b9fd2719d72adfc851417c293bd Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:56:40 +0800 Subject: [PATCH 46/59] more --- python/sglang/srt/entrypoints/http_server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index fe8ff720e7..39c9c27740 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -957,10 +957,10 @@ async def resume_memory_occupation( @app.post("/check_weights") async def check_weights(obj: CheckWeightsReqInput, request: Request): - resp = await _global_state.tokenizer_manager.check_weights(obj, request) + success, message = await _global_state.tokenizer_manager.check_weights(obj, request) return ORJSONResponse( - {"success": resp.success, "message": resp.message}, - status_code=200 if resp.success else HTTPStatus.BAD_REQUEST, + {"success": success, "message": message}, + status_code=200 if success else HTTPStatus.BAD_REQUEST, ) From a51838cff8e52ab1e156a3e23767f7ad0091e5b3 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:56:58 +0800 Subject: [PATCH 47/59] more --- python/sglang/srt/managers/scheduler_update_weights_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index a211b8bfbf..027d45016d 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -170,7 +170,7 @@ def resume_memory_occupation( def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): try: self.tp_worker.check_weights(action=recv_req.action) - return CheckWeightsReqOutput(success=True, message="") + return CheckWeightsReqOutput(success=True, message="Success.") except Exception as e: return CheckWeightsReqOutput(success=False, message=f"{e}") From 996366bd925168ea0c308b78a0947e954eabcc2d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 18:58:44 +0800 Subject: [PATCH 48/59] more --- python/sglang/srt/managers/scheduler_update_weights_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 027d45016d..7646b31fa6 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -169,7 +169,7 @@ def resume_memory_occupation( def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): try: - self.tp_worker.check_weights(action=recv_req.action) + self.tp_worker.model_runner.check_weights(action=recv_req.action) return CheckWeightsReqOutput(success=True, message="Success.") except Exception as e: return CheckWeightsReqOutput(success=False, message=f"{e}") From c64b43066891fa86cebcd4dd7dc29d7a5a544bd7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 19:00:07 +0800 Subject: [PATCH 49/59] more --- python/sglang/srt/managers/scheduler_update_weights_mixin.py | 3 +++ python/sglang/srt/utils/weight_checker.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 7646b31fa6..baf85f19b5 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import traceback from typing import TYPE_CHECKING, Tuple import torch @@ -172,6 +173,8 @@ def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): self.tp_worker.model_runner.check_weights(action=recv_req.action) return CheckWeightsReqOutput(success=True, message="Success.") except Exception as e: + logger.warning(f"check_weights see error: {e}") + traceback.print_stack() return CheckWeightsReqOutput(success=False, message=f"{e}") def save_remote_model(self: Scheduler, params): diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 1eae5c100e..21c41cd4b9 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -75,8 +75,7 @@ def _check_tensors( logger.info(f"[check_tensors] passed: {good_names}") if len(error_messages) > 0: - msg = f"check tensor equality failed:\n" + "\n".join(error_messages) - raise Exception(msg) + raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages)) def _random_like(t: torch.Tensor): From f8bacdeaa58c3fbfee0f5c72e1e82f033cabfc52 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 19:00:28 +0800 Subject: [PATCH 50/59] more --- python/sglang/srt/utils/weight_checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 21c41cd4b9..2af5b3c9ea 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -59,8 +59,8 @@ def _check_tensors( error_messages = [] for name in expect_tensors: - expect = expect_tensors[name] - actual = actual_tensors[name] + expect = expect_tensors[name].cuda() + actual = actual_tensors[name].cuda() if torch.all(expect == actual): good_names.append(name) From 27df0cdb9b2f4cf0cb74eb6c994bbd2247283a86 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 20:03:24 +0800 Subject: [PATCH 51/59] more --- python/sglang/srt/utils/weight_checker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 2af5b3c9ea..022235f7e6 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -96,8 +96,6 @@ def _random_like(t: torch.Tensor): info = torch.iinfo(dtype) low = int(info.min) high = int(info.max) - return torch.randint( - low=low, high=high, size=shape, device=device, dtype=torch.int64 - ) + return torch.randint(low=low, high=high, size=shape, device=device, dtype=dtype) raise TypeError(f"unsupported dtype: {dtype}") From 39d5c3d7a0e4dec35dce1de1c8382fe7631777ee Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 20:07:04 +0800 Subject: [PATCH 52/59] more --- python/sglang/srt/utils/weight_checker.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 022235f7e6..e862d2fabe 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -85,17 +85,13 @@ def _random_like(t: torch.Tensor): if dtype.is_floating_point: gen_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype - tmp = torch.rand(shape, device=device, dtype=gen_dtype) - return tmp.to(dtype) + return torch.rand(shape, device=device, dtype=gen_dtype).to(dtype) if dtype == torch.bool: return torch.rand(shape, device=device) > 0.5 - # Integer types if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): info = torch.iinfo(dtype) - low = int(info.min) - high = int(info.max) - return torch.randint(low=low, high=high, size=shape, device=device, dtype=dtype) + return torch.randint(low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype) raise TypeError(f"unsupported dtype: {dtype}") From 018c9109680f6f5b0654114b9159101256e1e18b Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 20:08:55 +0800 Subject: [PATCH 53/59] more --- python/sglang/srt/managers/scheduler_update_weights_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index baf85f19b5..5c1f91fa4d 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -174,7 +174,7 @@ def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput): return CheckWeightsReqOutput(success=True, message="Success.") except Exception as e: logger.warning(f"check_weights see error: {e}") - traceback.print_stack() + traceback.print_exc() return CheckWeightsReqOutput(success=False, message=f"{e}") def save_remote_model(self: Scheduler, params): From c111c29b4574a83100fca596109a3cbd0114b4c4 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 20:10:37 +0800 Subject: [PATCH 54/59] more --- python/sglang/srt/utils/weight_checker.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index e862d2fabe..72d1204ec6 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -84,14 +84,10 @@ def _random_like(t: torch.Tensor): dtype = t.dtype if dtype.is_floating_point: - gen_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype - return torch.rand(shape, device=device, dtype=gen_dtype).to(dtype) + return torch.rand(shape, device=device, dtype=torch.float32).to(dtype) if dtype == torch.bool: return torch.rand(shape, device=device) > 0.5 - if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): - info = torch.iinfo(dtype) - return torch.randint(low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype) - - raise TypeError(f"unsupported dtype: {dtype}") + info = torch.iinfo(dtype) + return torch.randint(low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype) From 80b77a7700800c93e87d5aaa421017a77a89bb20 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 20:15:18 +0800 Subject: [PATCH 55/59] more --- python/sglang/srt/utils/weight_checker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 72d1204ec6..4a75167031 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -65,10 +65,11 @@ def _check_tensors( if torch.all(expect == actual): good_names.append(name) else: + abs_diff = (actual.float() - expect.float()).abs() error_messages.append( f"name={name} " - f"max_abs_err={(actual - expect).abs().max()} " - f"mean_abs_err={(actual - expect).abs().mean()} " + f"max_abs_err={abs_diff.max()} " + f"mean_abs_err={abs_diff.mean()} " f"{get_tensor_info(expect)=} " f"{get_tensor_info(actual)=} " ) From ae06c782c2d4dfaeb53a7fd80cb30c911818c95c Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 20:32:11 +0800 Subject: [PATCH 56/59] more --- python/sglang/srt/utils/weight_checker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 4a75167031..3a9113de11 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -46,6 +46,7 @@ def _compare(self): def _model_state(self): # TODO: support EAGLE etc (e.g. yield from both main model and draft model) yield from self._model_runner.model.named_parameters() + yield from self._model_runner.model.named_buffers() def _check_tensors( From ce71a2c629cc785a35ed869a93bbc528d8af9a51 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 22:40:04 +0800 Subject: [PATCH 57/59] more --- python/sglang/srt/model_executor/model_runner.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a19c2e63b1..bff5abc373 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -595,12 +595,9 @@ def check_quantized_moe_compatibility(self): ) moe_tp_size = self.tp_size // self.moe_ep_size - moe_intermediate_size = getattr( - self.model_config.hf_text_config, "moe_intermediate_size", None + moe_intermediate_size = ( + self.model_config.hf_text_config.moe_intermediate_size ) - if moe_intermediate_size is None: - return - if moe_intermediate_size % moe_tp_size != 0: raise ValueError( f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})." From cf448d378e8baba359e1c0cb835bb2ce22b96db7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Mon, 24 Nov 2025 22:49:09 +0800 Subject: [PATCH 58/59] fmt --- python/sglang/srt/utils/weight_checker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 3a9113de11..98ae631df6 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -92,4 +92,6 @@ def _random_like(t: torch.Tensor): return torch.rand(shape, device=device) > 0.5 info = torch.iinfo(dtype) - return torch.randint(low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype) + return torch.randint( + low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype + ) From 5e1902fe0e3e06237ffdcc89d04f017e9c8934f5 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Thu, 27 Nov 2025 09:36:28 +0800 Subject: [PATCH 59/59] Update http_server.py --- python/sglang/srt/entrypoints/http_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 39c9c27740..c677dce4c1 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -955,7 +955,7 @@ async def resume_memory_occupation( return _create_error_response(e) -@app.post("/check_weights") +@app.post("/weights_checker") async def check_weights(obj: CheckWeightsReqInput, request: Request): success, message = await _global_state.tokenizer_manager.check_weights(obj, request) return ORJSONResponse(