Merged
62 commits
c0fb7e0 more (fzyzcjy, Nov 24, 2025)
e1f1a63 more (fzyzcjy, Nov 24, 2025)
92cd75e more (fzyzcjy, Nov 24, 2025)
2ecbb75 more (fzyzcjy, Nov 24, 2025)
616fa04 more (fzyzcjy, Nov 24, 2025)
e40e5d4 more (fzyzcjy, Nov 24, 2025)
3f4f5c7 fmt (fzyzcjy, Nov 24, 2025)
11f4041 fmt (fzyzcjy, Nov 24, 2025)
de67d82 more (fzyzcjy, Nov 24, 2025)
3544da2 more (fzyzcjy, Nov 24, 2025)
fca66f1 more (fzyzcjy, Nov 24, 2025)
4ded2c6 morr (fzyzcjy, Nov 24, 2025)
715f849 more (fzyzcjy, Nov 24, 2025)
5ed352d more (fzyzcjy, Nov 24, 2025)
5f9c887 more (fzyzcjy, Nov 24, 2025)
78ed135 more (fzyzcjy, Nov 24, 2025)
7dc4f5b more (fzyzcjy, Nov 24, 2025)
f7e95d1 more (fzyzcjy, Nov 24, 2025)
14b821d more (fzyzcjy, Nov 24, 2025)
9791468 more (fzyzcjy, Nov 24, 2025)
720956f more (fzyzcjy, Nov 24, 2025)
bc6830d more (fzyzcjy, Nov 24, 2025)
e6aed98 fmt (fzyzcjy, Nov 24, 2025)
a374cb8 more (fzyzcjy, Nov 24, 2025)
2490c9d more (fzyzcjy, Nov 24, 2025)
721e4e8 more (fzyzcjy, Nov 24, 2025)
c4015e5 more (fzyzcjy, Nov 24, 2025)
8990e09 more (fzyzcjy, Nov 24, 2025)
33fd971 more (fzyzcjy, Nov 24, 2025)
c3c7932 more (fzyzcjy, Nov 24, 2025)
c599ea6 more (fzyzcjy, Nov 24, 2025)
8555a5d more (fzyzcjy, Nov 24, 2025)
94fae9b more (fzyzcjy, Nov 24, 2025)
7a43c3d fmt (fzyzcjy, Nov 24, 2025)
2e7e157 more (fzyzcjy, Nov 24, 2025)
7256e33 fmt (fzyzcjy, Nov 24, 2025)
299df30 more (fzyzcjy, Nov 24, 2025)
df17ccf fix moe check (fzyzcjy, Nov 24, 2025)
399073b more (fzyzcjy, Nov 24, 2025)
2dbafd4 another fix (fzyzcjy, Nov 24, 2025)
648b118 more (fzyzcjy, Nov 24, 2025)
029c377 more (fzyzcjy, Nov 24, 2025)
ed02bdb more (fzyzcjy, Nov 24, 2025)
3ec8f77 more (fzyzcjy, Nov 24, 2025)
8b25bca fmt (fzyzcjy, Nov 24, 2025)
aa2cdf2 more (fzyzcjy, Nov 24, 2025)
a51838c more (fzyzcjy, Nov 24, 2025)
996366b more (fzyzcjy, Nov 24, 2025)
c64b430 more (fzyzcjy, Nov 24, 2025)
f8bacde more (fzyzcjy, Nov 24, 2025)
27df0cd more (fzyzcjy, Nov 24, 2025)
39d5c3d more (fzyzcjy, Nov 24, 2025)
018c910 more (fzyzcjy, Nov 24, 2025)
c111c29 more (fzyzcjy, Nov 24, 2025)
80b77a7 more (fzyzcjy, Nov 24, 2025)
ae06c78 more (fzyzcjy, Nov 24, 2025)
a016280 Merge branch 'main-upstream' into feat/dev_20251124 (fzyzcjy, Nov 24, 2025)
ce71a2c more (fzyzcjy, Nov 24, 2025)
cf448d3 fmt (fzyzcjy, Nov 24, 2025)
647455a Merge branch 'main' into feat/weight_checker (fzyzcjy, Nov 25, 2025)
5e1902f Update http_server.py (fzyzcjy, Nov 27, 2025)
39e0588 Merge branch 'main' into feat/weight_checker (fzyzcjy, Nov 27, 2025)
10 changes: 10 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -76,6 +76,7 @@
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.managers.io_struct import (
AbortReq,
CheckWeightsReqInput,
CloseSessionReqInput,
ConfigureLoggingReq,
ContinueGenerationReqInput,
@@ -956,6 +957,15 @@ async def resume_memory_occupation(
return _create_error_response(e)


@app.post("/weights_checker")
async def check_weights(obj: CheckWeightsReqInput, request: Request):
success, message = await _global_state.tokenizer_manager.check_weights(obj, request)
return ORJSONResponse(
{"success": success, "message": message},
status_code=200 if success else HTTPStatus.BAD_REQUEST,
)


@app.api_route("/slow_down", methods=["GET", "POST"])
async def slow_down(obj: SlowDownReqInput, request: Request):
"""Slow down the system deliberately. Only for testing. Example scenario:
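For reference, a minimal sketch of how the new endpoint might be exercised; the server address is an assumption, while the route, request body, and response shape come from the diff above:

import requests

base_url = "http://localhost:30000"  # assumed local SGLang server address

# Valid actions (see weight_checker.py below): "snapshot", "reset_tensors", "compare".
resp = requests.post(f"{base_url}/weights_checker", json={"action": "snapshot"})
print(resp.status_code, resp.json())  # 200 {"success": true, "message": "Success."}

# An unsupported action is reported back with HTTP 400.
resp = requests.post(f"{base_url}/weights_checker", json={"action": "bogus"})
print(resp.status_code, resp.json())  # 400 {"success": false, "message": "Unsupported action='bogus'"}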
11 changes: 11 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -1311,6 +1311,17 @@ class ResumeMemoryOccupationReqOutput(BaseReq):
pass


@dataclass
class CheckWeightsReqInput(BaseReq):
action: str


@dataclass
class CheckWeightsReqOutput(BaseReq):
success: bool
message: str


@dataclass
class SlowDownReqInput(BaseReq):
forward_sleep_time: Optional[float]
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -71,6 +71,7 @@
BaseReq,
BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput,
CheckWeightsReqInput,
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput,
@@ -568,6 +569,7 @@ def __init__(
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(CheckWeightsReqInput, self.check_weights),
(SlowDownReqInput, self.slow_down),
(ProfileReq, self.profile),
(FreezeGCReq, self.handle_freeze_gc),
12 changes: 12 additions & 0 deletions python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import traceback
from typing import TYPE_CHECKING, Tuple

import torch
@@ -12,6 +13,8 @@
GPU_MEMORY_TYPE_WEIGHTS,
)
from sglang.srt.managers.io_struct import (
CheckWeightsReqInput,
CheckWeightsReqOutput,
DestroyWeightsUpdateGroupReqInput,
DestroyWeightsUpdateGroupReqOutput,
GetWeightsByNameReqInput,
@@ -166,6 +169,15 @@ def resume_memory_occupation(

return ResumeMemoryOccupationReqOutput()

def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
try:
self.tp_worker.model_runner.check_weights(action=recv_req.action)
return CheckWeightsReqOutput(success=True, message="Success.")
except Exception as e:
logger.warning(f"check_weights saw error: {e}")
traceback.print_exc()
return CheckWeightsReqOutput(success=False, message=f"{e}")

def save_remote_model(self: Scheduler, params):
url = params["url"]

18 changes: 18 additions & 0 deletions python/sglang/srt/managers/tokenizer_communicator_mixin.py
@@ -22,6 +22,8 @@
import zmq

from sglang.srt.managers.io_struct import (
CheckWeightsReqInput,
CheckWeightsReqOutput,
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput,
@@ -183,6 +185,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs):
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.check_weights_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -256,6 +261,10 @@ def _get_communicator_dispatcher(self: TokenizerManager):
ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv,
),
(
CheckWeightsReqOutput,
self.check_weights_communicator.handle_recv,
),
(
SlowDownReqOutput,
self.slow_down_communicator.handle_recv,
@@ -670,6 +679,15 @@ async def resume_memory_occupation(
self.auto_create_handle_loop()
await self.resume_memory_occupation_communicator(obj)

async def check_weights(
self: TokenizerManager,
obj: CheckWeightsReqInput,
request: Optional[fastapi.Request] = None,
) -> CheckWeightsReqOutput:
self.auto_create_handle_loop()
results = await self.check_weights_communicator(obj)
return _Communicator.merge_results(results)

async def slow_down(
self: TokenizerManager,
obj: SlowDownReqInput,
6 changes: 6 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -170,6 +170,7 @@
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils.weight_checker import WeightChecker
from sglang.srt.weight_sync.tensor_bucket import (
FlattenedTensorBucket,
FlattenedTensorMetadata,
@@ -328,6 +329,8 @@ def __init__(
# CPU offload
set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

self._weight_checker = WeightChecker(model_runner=self)

if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
slow_rank_detector.execute()
# Init mindspore running environment when model impl is "mindspore"
@@ -2508,6 +2511,9 @@ def save_sharded_model(
)
ShardedStateLoader.save_model(self.model, path, pattern, max_size)

def check_weights(self, action: str):
self._weight_checker.handle(action=action)

def update_weights_from_ipc(self, recv_req):
"""Update weights from IPC for checkpoint-engine integration."""
try:
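For in-process tests that bypass the HTTP and scheduler layers, the same checks can presumably be driven on the runner directly; a minimal sketch, assuming `runner` is an already-initialized ModelRunner:

# runner: an initialized ModelRunner instance (assumption).
runner.check_weights(action="snapshot")       # copy every param/buffer to CPU
runner.check_weights(action="reset_tensors")  # overwrite weights with random data
runner.check_weights(action="compare")        # raises, since weights no longer match the snapshot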
97 changes: 97 additions & 0 deletions python/sglang/srt/utils/weight_checker.py
@@ -0,0 +1,97 @@
import logging
from typing import Dict

import torch

logger = logging.getLogger(__name__)


class WeightChecker:
def __init__(self, model_runner):
self._model_runner = model_runner
self._snapshot_tensors = None

def handle(self, action: str):
logger.info(f"[WeightChecker] handle action={action}")
if action == "snapshot":
self._snapshot()
elif action == "reset_tensors":
self._reset_tensors()
elif action == "compare":
self._compare()
else:
raise Exception(f"Unsupported {action=}")

def _snapshot(self):
named_tensors = [
(name, param.data.detach().cpu()) for name, param in self._model_state()
]
self._snapshot_tensors = dict(named_tensors)
assert len(self._snapshot_tensors) == len(
named_tensors
), "should not have duplicated tensor names"

def _reset_tensors(self):
for name, param in self._model_state():
param.copy_(_random_like(param))

def _compare(self):
assert self._snapshot_tensors is not None

_check_tensors(
expect_tensors=self._snapshot_tensors,
actual_tensors=dict(self._model_state()),
)

def _model_state(self):
# TODO: support EAGLE etc (e.g. yield from both main model and draft model)
yield from self._model_runner.model.named_parameters()
yield from self._model_runner.model.named_buffers()


def _check_tensors(
expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]
):
from sglang.srt.debug_utils.dumper import get_tensor_info

assert len(expect_tensors) == len(actual_tensors)

good_names = []
error_messages = []

for name in expect_tensors:
expect = expect_tensors[name].cuda()
actual = actual_tensors[name].cuda()

if torch.all(expect == actual):
good_names.append(name)
else:
abs_diff = (actual.float() - expect.float()).abs()
error_messages.append(
f"name={name} "
f"max_abs_err={abs_diff.max()} "
f"mean_abs_err={abs_diff.mean()} "
f"{get_tensor_info(expect)=} "
f"{get_tensor_info(actual)=} "
)

logger.info(f"[check_tensors] passed: {good_names}")
if len(error_messages) > 0:
raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages))


def _random_like(t: torch.Tensor):
device = t.device
shape = t.shape
dtype = t.dtype

if dtype.is_floating_point:
return torch.rand(shape, device=device, dtype=torch.float32).to(dtype)

if dtype == torch.bool:
return torch.rand(shape, device=device) > 0.5

info = torch.iinfo(dtype)
return torch.randint(
low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype
)
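Putting the pieces together, the checker supports a round-trip test of any weight-reload path: snapshot the weights, scramble them so a broken reload cannot pass by accident, run the reload under test, then compare. A sketch against a running server, assuming a localhost address and a hypothetical call helper; the reload step stands in for whatever update mechanism is being validated:

import requests

def call(action: str, base_url: str = "http://localhost:30000"):
    # Hypothetical helper; fails loudly if the server reports an error.
    body = requests.post(f"{base_url}/weights_checker", json={"action": action}).json()
    assert body["success"], body["message"]

call("snapshot")       # store a CPU copy of all params and buffers
call("reset_tensors")  # randomize GPU weights
# ... trigger the weight-update path under test here (e.g. an update_weights call) ...
call("compare")        # fails with per-tensor error stats if any tensor differs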