|
45 | 45 | logger = logging.getLogger(__name__) |
46 | 46 |
|
47 | 47 |
|
| 48 | +if _is_npu: |
| 49 | + import torch_npu |
| 50 | + |
| 51 | + |
48 | 52 | class DeepEPMoE(FusedMoE): |
49 | 53 | """ |
50 | 54 | MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main) |
@@ -411,9 +415,142 @@ def npu_fused_moe_without_routing_weights_bf16( |
411 | 415 | return hidden_states |
412 | 416 |
|
413 | 417 |
|
class NpuFuseEPMoE(DeepEPMoE):
    """DeepEPMoE variant for Ascend NPU using a fused EP dispatcher.

    The fused dispatcher performs dispatch, both grouped matmuls, and
    combine in a single call, so ``forward`` delegates everything to it.
    Weight post-processing is replaced with an NPU-specific routine that
    re-lays-out the w13/w2 weights and scales for the fused kernels.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        layer_id: int,
        num_fused_shared_experts: int = 0,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
    ):
        super().__init__(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            layer_id=layer_id,
            num_fused_shared_experts=num_fused_shared_experts,
            params_dtype=params_dtype,
            quant_config=quant_config,
            prefix=prefix,
            activation=activation,
            routed_scaling_factor=routed_scaling_factor,
        )

        # Monkey-patch the quant method so weight finalization runs the
        # NPU-specific layout transformations defined on this class.
        self.quant_method.process_weights_after_loading = (
            self._process_weights_after_loading
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_output: TopKOutput,
        forward_shared_experts=None,
        alt_stream=None,
        disable_sbo=False,
    ):
        # forward_shared_experts / alt_stream / disable_sbo are accepted for
        # interface compatibility with DeepEPMoE but are unused here.
        dispatch_result = self.dispatcher.dispatch(
            hidden_states=hidden_states,
            topk_output=topk_output,
            gmm1_permuted_weight=self.w13_weight,
            gmm1_permuted_weight_scale=self.w13_weight_scale,
            gmm2_weight=self.w2_weight,
            gmm2_weight_scale=self.w2_weight_scale,
        )
        return dispatch_result.hidden_state

    def release_weight_cache(self, weight: torch.Tensor):
        """Return a contiguous copy of ``weight`` with dims 1 and 2 swapped,
        then free the original storage.

        ``.contiguous()`` allocates a fresh buffer, so the transposed view's
        (shared) storage is shrunk to zero with ``resize_(0)`` to avoid
        briefly holding two full copies of the weight.
        """
        transposed_view = weight.data.transpose(1, 2)
        compacted = transposed_view.contiguous()
        transposed_view.untyped_storage().resize_(0)
        return compacted

    def permute_w13_weight_scale(self, w: torch.Tensor, tile_n: int):
        """Interleave the last dimension of ``w`` in ``tile_n // 2`` chunks.

        Splits the last axis into (2, n // tile_n, tile_n // 2), swaps the
        first two of those sub-axes, and flattens back — matching the layout
        the fused NPU kernel expects for the gate/up weight scales.
        """
        if tile_n % 2:
            raise ValueError(f"tile_n must be even, got {tile_n}")

        *lead, n = w.shape
        if n % tile_n:
            raise ValueError(f"Last dimension {n} must be divisible by tile_n {tile_n}")

        split = w.reshape(*lead, 2, n // tile_n, tile_n // 2)

        # Swapping the (2,) and (n // tile_n,) sub-axes is equivalent to the
        # permutation [*lead_dims, -2, -3, -1].
        swapped = split.transpose(-3, -2)

        return swapped.reshape(*lead, n)

    def reshape_w13_weight(self, weight: torch.Tensor, dim: int, chunk_size: int = 64):
        """Chunk-interleave ``weight`` along ``dim`` in blocks of ``chunk_size``.

        Views ``dim`` as (2, size // (2 * chunk_size), chunk_size), swaps the
        first two sub-axes, and flattens back. Achieves greater computing
        power through this reshape on Ascend.
        """
        shape = weight.shape
        axis = dim + len(shape) if dim < 0 else dim

        group = 2 * chunk_size
        if shape[axis] % group != 0:
            raise ValueError(
                f"Dimension {axis} size {shape[axis]} must be divisible by {group}"
            )

        prefix, suffix = shape[:axis], shape[axis + 1 :]
        weight = weight.view(*prefix, 2, shape[axis] // group, chunk_size, *suffix)
        weight = weight.transpose(axis, axis + 1).contiguous()

        return weight.view(*prefix, -1, *suffix)

    def _process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize ``layer``'s MoE weights for the fused NPU kernels.

        Order matters: w13 is compacted (freeing the original storage),
        format-cast, chunk-reshaped on CPU, then cast again on device.
        Format codes 2 and 29 are ACL tensor formats (presumably ND and
        FRACTAL_NZ — confirm against torch_npu docs).
        """
        gate_up = self.release_weight_cache(layer.w13_weight)
        torch_npu.npu_format_cast_(gate_up, 2)
        # The chunked reshape is done on CPU, then moved back to device.
        gate_up = self.reshape_w13_weight(gate_up.cpu(), -1).npu()
        torch_npu.npu_format_cast_(gate_up, 29)
        layer.w13_weight = torch.nn.Parameter(gate_up, requires_grad=False)

        down = torch_npu.npu_format_cast(layer.w2_weight.data, 29)
        layer.w2_weight = torch.nn.Parameter(down, requires_grad=False)

        gate_up_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous()
        gate_up_scale = self.permute_w13_weight_scale(gate_up_scale, 128)
        layer.w13_weight_scale = torch.nn.Parameter(
            gate_up_scale.to(torch.float32), requires_grad=False
        )

        down_scale = layer.w2_weight_scale.data.squeeze(-1).contiguous()
        layer.w2_weight_scale = torch.nn.Parameter(
            down_scale.to(torch.float32), requires_grad=False
        )

        # Offsets are optional; flatten their trailing singleton dim if present.
        for attr in ("w13_weight_offset", "w2_weight_offset"):
            if hasattr(layer, attr):
                flattened = getattr(layer, attr).data.squeeze(-1).contiguous()
                setattr(
                    layer,
                    attr,
                    torch.nn.Parameter(flattened, requires_grad=False),
                )
| 547 | + |
| 548 | + |
414 | 549 | def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): |
415 | 550 | if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): |
416 | 551 | return DeepEPMoE |
| 552 | + if get_moe_a2a_backend().is_ascend_fuseep(): |
| 553 | + return NpuFuseEPMoE |
417 | 554 |
|
418 | 555 | # NEW: Direct FP4 detection (bypasses EP requirements) |
419 | 556 | # Check for FP4 quantization with TRTLLM flag, regardless of EP |
|
0 commit comments