@@ -480,30 +480,22 @@ def load_lora_weight_tensor(
             load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
-        self, target_module: str, layer_id: int, lora_type: LoRAType, context: str = None
+        self, target_module: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
         """
         Get LoRA tensor buffer (automatically handles both 3D and 4D tensors).
 
         Args:
-            target_module: Target module name (e.g., 'gate_up_proj')
+            target_module: Target module name (e.g., 'gate_up_proj', or 'gate_up_proj_moe' for MoE)
             layer_id: Layer index
             lora_type: LoRAType.LORA_A or LoRAType.LORA_B
-            context: Optional context hint ('moe' or None for auto-detect)
 
         Returns:
             - 3D tensor [num_loras, rank, hidden] for standard modules
             - 4D tensor [num_loras, num_experts, rank, hidden] for MoE modules
         """
         buffer_dict = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
-
-        # Handle context-specific buffer selection for ambiguous modules
-        ambiguous_modules = {"gate_up_proj", "down_proj"}
-        if target_module in ambiguous_modules:
-            if context == "moe" and f"{target_module}_moe" in buffer_dict:
-                return buffer_dict[f"{target_module}_moe"][layer_id]
-
-        # Fall back to original key for non-ambiguous modules
+
         return buffer_dict[target_module][layer_id]
 
     def get_buffer_id(self, lora_uid: str):
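For reference, here is a minimal, self-contained sketch of the simplified get_tensor contract after this change: callers select the MoE buffer by passing the explicit "_moe" key, instead of supplying a context hint that the pool resolves internally. DummyPool and all tensor sizes below are illustrative assumptions, not code from this repository.

import torch
from enum import Enum, auto


class LoRAType(Enum):
    LORA_A = auto()
    LORA_B = auto()


class DummyPool:
    # Stand-in for the real memory pool: one standard (3D) and one MoE (4D)
    # buffer per layer, with made-up sizes.
    def __init__(self):
        num_loras, num_experts, rank, hidden = 2, 4, 8, 16
        self.A_buffer = {
            "gate_up_proj": [torch.zeros(num_loras, rank, hidden)],
            "gate_up_proj_moe": [torch.zeros(num_loras, num_experts, rank, hidden)],
        }
        self.B_buffer = {}

    def get_tensor(self, target_module: str, layer_id: int, lora_type: LoRAType) -> torch.Tensor:
        # Mirrors the simplified lookup from the diff: no context hint,
        # just a direct key into the buffer dict.
        buffer_dict = self.A_buffer if lora_type == LoRAType.LORA_A else self.B_buffer
        return buffer_dict[target_module][layer_id]


pool = DummyPool()
# Standard module: the plain key returns the 3D buffer [num_loras, rank, hidden].
assert pool.get_tensor("gate_up_proj", 0, LoRAType.LORA_A).dim() == 3
# MoE module: the caller disambiguates with the explicit "_moe" key
# instead of passing context="moe".
assert pool.get_tensor("gate_up_proj_moe", 0, LoRAType.LORA_A).dim() == 4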