Skip to content

Commit 52a1e8f

Browse files
committed
cp
1 parent aeac622 commit 52a1e8f

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

python/sglang/srt/layers/linear.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,8 @@ def weight_loader_v2(

         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
-            block_n, _ = weight_block_size[0], weight_block_size[1]
+            raw_block_n, _ = weight_block_size[0], weight_block_size[1]
+            block_n = 1 if getattr(param, "format_ue8m0", False) else raw_block_n
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
             ) // self.tp_size
@@ -966,7 +967,8 @@ def weight_loader_v2(

         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
-            block_n, _ = weight_block_size[0], weight_block_size[1]
+            raw_block_n, _ = weight_block_size[0], weight_block_size[1]
+            block_n = 1 if getattr(param, "format_ue8m0", False) else raw_block_n
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n


python/sglang/srt/layers/quantization/fp8.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ def create_weights(
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.orig_dtype = params_dtype
-        layer.executed_weight_requant_ue8m0 = False

         # WEIGHT
         weight_dtype = (
@@ -300,6 +299,7 @@ def create_weights(
                 output_dim=0,
                 weight_loader=weight_loader,
             )
+            scale.format_ue8m0 = False
             scale[:] = torch.finfo(torch.float32).min
             layer.register_parameter("weight_scale_inv", scale)
         else:
@@ -367,14 +367,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
                 self.w8a8_block_fp8_linear
                 is deepgemm_w8a8_block_fp8_linear_with_fallback
             )
-            and (not layer.executed_weight_requant_ue8m0)
+            and (not layer.weight_scale_inv.format_ue8m0)
         ):
             requant_weight_ue8m0_inplace(
                 layer.weight,
                 layer.weight_scale_inv,
                 self.quant_config.weight_block_size,
             )
-            layer.executed_weight_requant_ue8m0 = True
+            layer.weight_scale_inv.format_ue8m0 = True
         weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data

         layer.weight.data = weight.data

0 commit comments

Comments (0)