
Commit d30dc51

jianglan89 authored and Nancheng-11 committed
fix - fp8 per block load quant need skip [W.mla_kc, W.mla_vc]
1 parent: 32db79d

1 file changed: +6 additions, −4 deletions


rtp_llm/model_loader/per_block_fp8_quant_weight.py (6 additions, 4 deletions)

```diff
@@ -8,13 +8,13 @@
 from rtp_llm.model_loader.attn_weight import AttnAtomicWeight, MlaAttnAtomicWeight
 from rtp_llm.model_loader.ffn_weight import FfnAtomicWeight, MoeAtomicWeight
 from rtp_llm.model_loader.load_config import LoadConfig
+from rtp_llm.model_loader.tensor_source import TensorSource
 from rtp_llm.model_loader.weight_module import (
     AtomicWeight,
     CompositeWeight,
     QuantWeight,
     WeightModule,
 )
-from rtp_llm.model_loader.tensor_source import TensorSource
 from rtp_llm.utils.model_weight import (
     FP8_E4M3_MAX,
     CkptWeightInfo,
@@ -710,7 +710,7 @@ def support(
         ):
             return False
         name = src_weight_info.name
-        return name in cls.w8a8_weight_list
+        return name in cls.w8a8_weight_list and name not in [W.mla_kc, W.mla_vc]
 
     def __init__(
         self,
@@ -750,7 +750,9 @@ def _load_raw_tensor(
         device: str,
         load_config: LoadConfig,
     ):
-        kernel = self.kernel._load_raw_tensor(tensor_source, layer_id, device, load_config)
+        kernel = self.kernel._load_raw_tensor(
+            tensor_source, layer_id, device, load_config
+        )
 
         res = {}
         scale = None
@@ -774,7 +776,7 @@ def _load_raw_tensor(
         res.update({self.scale.name: scale.contiguous().to(device)})
 
         return res
-
+
     def get_tensor_names(
         self, layer_id: Optional[int], load_config: LoadConfig
     ) -> set[str]:
```
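The behavioral change is the support() gate: per-block FP8 (w8a8) quantization is now refused for the MLA kc/vc weights even though they appear in w8a8_weight_list, so those tensors are loaded without per-block quantization. Below is a minimal runnable sketch of the fixed predicate, using hypothetical stand-ins for cls.w8a8_weight_list and the W.mla_kc / W.mla_vc name constants (their real values live in rtp_llm.utils.model_weight and are not shown in this diff):

```python
# Hypothetical stand-ins; the real check is a classmethod on the
# per-block FP8 quant weight class in per_block_fp8_quant_weight.py.
MLA_KC = "mla_kc"  # assumed string value of W.mla_kc
MLA_VC = "mla_vc"  # assumed string value of W.mla_vc
W8A8_WEIGHT_LIST = {"attn_qkv_w", "attn_o_w", MLA_KC, MLA_VC}  # illustrative subset


def supports_per_block_fp8(name: str) -> bool:
    """Mirror the fixed support() check: a weight is per-block FP8
    quantized only if it is w8a8-eligible AND not an MLA kc/vc weight."""
    return name in W8A8_WEIGHT_LIST and name not in (MLA_KC, MLA_VC)


assert supports_per_block_fp8("attn_o_w")
assert not supports_per_block_fp8(MLA_KC)  # skipped after this commit
```

The other two hunks are mechanical: a line wrap in _load_raw_tensor and moving the TensorSource import above the weight_module import block.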

0 commit comments

Comments
 (0)
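For background on what this file loads: per-block FP8 quantization stores an FP8 kernel plus one scale per weight tile, which is what _load_raw_tensor collects into res under the kernel name and self.scale.name. Below is a rough, self-contained sketch of such a scheme, assuming a 128×128 tile size and torch's float8_e4m3fn dtype (whose largest finite value, 448.0, matches the FP8_E4M3_MAX constant imported here); it illustrates the general technique, not the loader's actual code:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
BLOCK = 128           # assumed tile size; common in per-block FP8 schemes


def per_block_fp8_quant(w: torch.Tensor, block: int = BLOCK):
    """Quantize a 2-D weight with one float32 scale per (block x block) tile."""
    n, k = w.shape
    nb, kb = (n + block - 1) // block, (k + block - 1) // block
    # Pad up to whole tiles, then view as (nb, block, kb, block) to reduce per tile.
    padded = torch.zeros(nb * block, kb * block, dtype=w.dtype)
    padded[:n, :k] = w
    tiles = padded.view(nb, block, kb, block)
    amax = tiles.abs().amax(dim=(1, 3)).clamp(min=1e-12)
    scale = (amax / FP8_E4M3_MAX).to(torch.float32)
    # Scale each tile into FP8 range and cast down; keep the original shape.
    q = (tiles / scale[:, None, :, None]).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    kernel = q.reshape(nb * block, kb * block)[:n, :k].to(torch.float8_e4m3fn)
    return kernel, scale


# Round trip: dequantize and check the error is small.
w = torch.randn(256, 512)
kernel, scale = per_block_fp8_quant(w)
deq = kernel.to(torch.float32) * scale.repeat_interleave(BLOCK, 0).repeat_interleave(BLOCK, 1)
print((deq - w).abs().max())  # small per-tile quantization error
```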