Skip to content

Commit 2575864

Browse files
authored
Support sanity checking weight consistency especially for RL (#13854)
1 parent 2bc8ee8 commit 2575864

File tree

7 files changed

+156
-0
lines changed

7 files changed

+156
-0
lines changed

python/sglang/srt/entrypoints/http_server.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from sglang.srt.function_call.function_call_parser import FunctionCallParser
7777
from sglang.srt.managers.io_struct import (
7878
AbortReq,
79+
CheckWeightsReqInput,
7980
CloseSessionReqInput,
8081
ConfigureLoggingReq,
8182
ContinueGenerationReqInput,
@@ -956,6 +957,15 @@ async def resume_memory_occupation(
956957
return _create_error_response(e)
957958

958959

960+
@app.post("/weights_checker")
961+
async def check_weights(obj: CheckWeightsReqInput, request: Request):
962+
success, message = await _global_state.tokenizer_manager.check_weights(obj, request)
963+
return ORJSONResponse(
964+
{"success": success, "message": message},
965+
status_code=200 if success else HTTPStatus.BAD_REQUEST,
966+
)
967+
968+
959969
@app.api_route("/slow_down", methods=["GET", "POST"])
960970
async def slow_down(obj: SlowDownReqInput, request: Request):
961971
"""Slow down the system deliberately. Only for testing. Example scenario:

python/sglang/srt/managers/io_struct.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,6 +1311,17 @@ class ResumeMemoryOccupationReqOutput(BaseReq):
13111311
pass
13121312

13131313

1314+
@dataclass
1315+
class CheckWeightsReqInput(BaseReq):
1316+
action: str
1317+
1318+
1319+
@dataclass
1320+
class CheckWeightsReqOutput(BaseReq):
1321+
success: bool
1322+
message: str
1323+
1324+
13141325
@dataclass
13151326
class SlowDownReqInput(BaseReq):
13161327
forward_sleep_time: Optional[float]

python/sglang/srt/managers/scheduler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
BaseReq,
7272
BatchTokenizedEmbeddingReqInput,
7373
BatchTokenizedGenerateReqInput,
74+
CheckWeightsReqInput,
7475
ClearHiCacheReqInput,
7576
ClearHiCacheReqOutput,
7677
CloseSessionReqInput,
@@ -568,6 +569,7 @@ def __init__(
568569
(GetWeightsByNameReqInput, self.get_weights_by_name),
569570
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
570571
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
572+
(CheckWeightsReqInput, self.check_weights),
571573
(SlowDownReqInput, self.slow_down),
572574
(ProfileReq, self.profile),
573575
(FreezeGCReq, self.handle_freeze_gc),

python/sglang/srt/managers/scheduler_update_weights_mixin.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import traceback
45
from typing import TYPE_CHECKING, Tuple
56

67
import torch
@@ -12,6 +13,8 @@
1213
GPU_MEMORY_TYPE_WEIGHTS,
1314
)
1415
from sglang.srt.managers.io_struct import (
16+
CheckWeightsReqInput,
17+
CheckWeightsReqOutput,
1518
DestroyWeightsUpdateGroupReqInput,
1619
DestroyWeightsUpdateGroupReqOutput,
1720
GetWeightsByNameReqInput,
@@ -166,6 +169,15 @@ def resume_memory_occupation(
166169

167170
return ResumeMemoryOccupationReqOutput()
168171

172+
def check_weights(self: Scheduler, recv_req: CheckWeightsReqInput):
173+
try:
174+
self.tp_worker.model_runner.check_weights(action=recv_req.action)
175+
return CheckWeightsReqOutput(success=True, message="Success.")
176+
except Exception as e:
177+
logger.warning(f"check_weights see error: {e}")
178+
traceback.print_exc()
179+
return CheckWeightsReqOutput(success=False, message=f"{e}")
180+
169181
def save_remote_model(self: Scheduler, params):
170182
url = params["url"]
171183

python/sglang/srt/managers/tokenizer_communicator_mixin.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import zmq
2323

2424
from sglang.srt.managers.io_struct import (
25+
CheckWeightsReqInput,
26+
CheckWeightsReqOutput,
2527
ClearHiCacheReqInput,
2628
ClearHiCacheReqOutput,
2729
CloseSessionReqInput,
@@ -183,6 +185,9 @@ def init_communicators(self: TokenizerManager, server_args: ServerArgs):
183185
self.resume_memory_occupation_communicator = _Communicator(
184186
self.send_to_scheduler, server_args.dp_size
185187
)
188+
self.check_weights_communicator = _Communicator(
189+
self.send_to_scheduler, server_args.dp_size
190+
)
186191
self.slow_down_communicator = _Communicator(
187192
self.send_to_scheduler, server_args.dp_size
188193
)
@@ -256,6 +261,10 @@ def _get_communicator_dispatcher(self: TokenizerManager):
256261
ResumeMemoryOccupationReqOutput,
257262
self.resume_memory_occupation_communicator.handle_recv,
258263
),
264+
(
265+
CheckWeightsReqOutput,
266+
self.check_weights_communicator.handle_recv,
267+
),
259268
(
260269
SlowDownReqOutput,
261270
self.slow_down_communicator.handle_recv,
@@ -670,6 +679,15 @@ async def resume_memory_occupation(
670679
self.auto_create_handle_loop()
671680
await self.resume_memory_occupation_communicator(obj)
672681

682+
async def check_weights(
683+
self: TokenizerManager,
684+
obj: CheckWeightsReqInput,
685+
request: Optional[fastapi.Request] = None,
686+
) -> CheckWeightsReqOutput:
687+
self.auto_create_handle_loop()
688+
results = await self.check_weights_communicator(obj)
689+
return _Communicator.merge_results(results)
690+
673691
async def slow_down(
674692
self: TokenizerManager,
675693
obj: SlowDownReqInput,

python/sglang/srt/model_executor/model_runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@
170170
)
171171
from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
172172
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
173+
from sglang.srt.utils.weight_checker import WeightChecker
173174
from sglang.srt.weight_sync.tensor_bucket import (
174175
FlattenedTensorBucket,
175176
FlattenedTensorMetadata,
@@ -328,6 +329,8 @@ def __init__(
328329
# CPU offload
329330
set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))
330331

332+
self._weight_checker = WeightChecker(model_runner=self)
333+
331334
if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
332335
slow_rank_detector.execute()
333336
# Init mindspore running environment when model impl is "mindspore"
@@ -2508,6 +2511,9 @@ def save_sharded_model(
25082511
)
25092512
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
25102513

2514+
def check_weights(self, action: str):
2515+
self._weight_checker.handle(action=action)
2516+
25112517
def update_weights_from_ipc(self, recv_req):
25122518
"""Update weights from IPC for checkpoint-engine integration."""
25132519
try:
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import logging
2+
from typing import Dict
3+
4+
import torch
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class WeightChecker:
10+
def __init__(self, model_runner):
11+
self._model_runner = model_runner
12+
self._snapshot_tensors = None
13+
14+
def handle(self, action: str):
15+
logger.info(f"[WeightChecker] handle action={action}")
16+
if action == "snapshot":
17+
self._snapshot()
18+
elif action == "reset_tensors":
19+
self._reset_tensors()
20+
elif action == "compare":
21+
self._compare()
22+
else:
23+
raise Exception(f"Unsupported {action=}")
24+
25+
def _snapshot(self):
26+
named_tensors = [
27+
(name, param.data.detach().cpu()) for name, param in self._model_state()
28+
]
29+
self._snapshot_tensors = dict(named_tensors)
30+
assert len(self._snapshot_tensors) == len(
31+
named_tensors
32+
), f"should not have duplicated tensor name"
33+
34+
def _reset_tensors(self):
35+
for name, param in self._model_state():
36+
param.copy_(_random_like(param))
37+
38+
def _compare(self):
39+
assert self._snapshot_tensors is not None
40+
41+
_check_tensors(
42+
expect_tensors=self._snapshot_tensors,
43+
actual_tensors=dict(self._model_state()),
44+
)
45+
46+
def _model_state(self):
47+
# TODO: support EAGLE etc (e.g. yield from both main model and draft model)
48+
yield from self._model_runner.model.named_parameters()
49+
yield from self._model_runner.model.named_buffers()
50+
51+
52+
def _check_tensors(
53+
expect_tensors: Dict[str, torch.Tensor], actual_tensors: Dict[str, torch.Tensor]
54+
):
55+
from sglang.srt.debug_utils.dumper import get_tensor_info
56+
57+
assert len(expect_tensors) == len(actual_tensors)
58+
59+
good_names = []
60+
error_messages = []
61+
62+
for name in expect_tensors:
63+
expect = expect_tensors[name].cuda()
64+
actual = actual_tensors[name].cuda()
65+
66+
if torch.all(expect == actual):
67+
good_names.append(name)
68+
else:
69+
abs_diff = (actual.float() - expect.float()).abs()
70+
error_messages.append(
71+
f"name={name} "
72+
f"max_abs_err={abs_diff.max()} "
73+
f"mean_abs_err={abs_diff.mean()} "
74+
f"{get_tensor_info(expect)=} "
75+
f"{get_tensor_info(actual)=} "
76+
)
77+
78+
logger.info(f"[check_tensors] passed: {good_names}")
79+
if len(error_messages) > 0:
80+
raise Exception(f"check tensor equality failed:\n" + "\n".join(error_messages))
81+
82+
83+
def _random_like(t: torch.Tensor):
84+
device = t.device
85+
shape = t.shape
86+
dtype = t.dtype
87+
88+
if dtype.is_floating_point:
89+
return torch.rand(shape, device=device, dtype=torch.float32).to(dtype)
90+
91+
if dtype == torch.bool:
92+
return torch.rand(shape, device=device) > 0.5
93+
94+
info = torch.iinfo(dtype)
95+
return torch.randint(
96+
low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype
97+
)

0 commit comments

Comments
 (0)