@@ -324,6 +324,35 @@ def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[
324324 return tensor_consumer_map
325325
326326
327+ def _is_following_cask_partition (
328+ node : Node , cask_partition_nodes : set [str ], max_depth : int = 10
329+ ) -> bool :
330+ """Check if a CASK fusible partition can be reached by traversing backward through copy ops.
331+
332+ Args:
333+ node: The node to check.
334+ cask_partition_nodes: Set of node names belonging to CASK partitions.
335+ max_depth: Maximum recursion depth to guard against pathological graphs.
336+
337+ Returns:
338+ True if the node belongs to or follows a CASK partition through copy ops.
339+ """
340+ if node .name in cask_partition_nodes :
341+ return True
342+
343+ if max_depth <= 0 or not is_copy_op (node .op ):
344+ return False
345+
346+ parent_nodes = get_parent_nodes (node )
347+ if len (parent_nodes ) == 0 :
348+ return False
349+
350+ return all (
351+ _is_following_cask_partition (parent , cask_partition_nodes , max_depth - 1 )
352+ for parent in parent_nodes
353+ )
354+
355+
327356def find_conv_to_layernorm_nodes (
328357 graph : Graph ,
329358 cask_fusible_partitions : list [list [Node ]],
@@ -342,77 +371,44 @@ def find_conv_to_layernorm_nodes(
342371 Returns:
343372 List of LayerNormalization nodes that consume CASK partition outputs.
344373 """
345- # Collect the output tensor names from CASK partitions
346- cask_output_tensor_names = set ()
374+ cask_partition_nodes : set [str ] = set ()
347375 for partition in cask_fusible_partitions :
348- last_node = partition [- 1 ]
349- for output_tensor in last_node .outputs :
350- cask_output_tensor_names .add (output_tensor .name )
376+ cask_partition_nodes .update (node .name for node in partition )
351377
352378 conv_to_ln_nodes = []
353379 for node in graph .nodes :
354380 if node .op != "LayerNormalization" :
355381 continue
356382
357- # Check if the first input (activation) comes from a CASK partition output
383+ # Check if the first input (activation) comes from a CASK partition
358384 # possibly through copy ops (Reshape, Transpose, etc.)
359- if _is_input_from_cask_partition (node .inputs [0 ], cask_output_tensor_names ):
360- conv_to_ln_nodes .append (node )
361- logger .debug (
362- f"Found Conv->LayerNorm pattern: LayerNorm node '{ node .name } ' "
363- f"consumes CASK partition output"
364- )
385+ inp_tensor = node .inputs [0 ]
386+ if inp_tensor .inputs :
387+ producer = inp_tensor .inputs [0 ]
388+ if _is_following_cask_partition (producer , cask_partition_nodes ):
389+ conv_to_ln_nodes .append (node )
390+ logger .debug (
391+ f"Found Conv->LayerNorm pattern: LayerNorm node '{ node .name } ' "
392+ f"consumes CASK partition output"
393+ )
365394
366395 logger .info (f"Found { len (conv_to_ln_nodes )} Conv->LayerNorm patterns to quantize" )
367396 return conv_to_ln_nodes
368397
369398
370- def _is_input_from_cask_partition (tensor : Tensor , cask_output_tensor_names : set [str ]) -> bool :
371- """Check if a tensor originates from a CASK partition output, traversing through copy ops."""
372- if tensor .name in cask_output_tensor_names :
373- return True
374-
375- # Traverse backward through copy ops (Reshape, Transpose, etc.)
376- if tensor .inputs :
377- producer = tensor .inputs [0 ]
378- if is_copy_op (producer .op ):
379- for inp in producer .inputs :
380- if not isinstance (inp , Constant ) and _is_input_from_cask_partition (
381- inp , cask_output_tensor_names
382- ):
383- return True
384-
385- return False
386-
387-
388399def filter_quantizable_kgen_heads (
389400 cask_fusible_partitions : list [list [Node ]],
390401 kgen_partitions : list [list [Node ]],
391402 quantizable_op_types : list [str ],
392403 graph : Graph ,
393404) -> tuple [list [Node ], list [tuple [Node , Node , str ]]]:
394405 """Returns the list of kgen head names if it follows a CASK partition."""
395- cask_partition_nodes = set ()
406+ cask_partition_nodes : set [ str ] = set ()
396407 for partition in cask_fusible_partitions :
397- cask_partition_nodes .update ([ node .name for node in partition ] )
408+ cask_partition_nodes .update (node .name for node in partition )
398409
399410 cask_partition_heads = [partition [0 ] for partition in cask_fusible_partitions ]
400411
401- def _is_following_cask_partition (node : Node ):
402- # Checking if cask fusible partition can be reached backward
403- # ignoring the copy ops
404- if node .name in cask_partition_nodes :
405- return True
406-
407- if not is_copy_op (node .op ):
408- return False
409-
410- parent_nodes = get_parent_nodes (node )
411- if len (parent_nodes ) == 0 :
412- return False
413-
414- return all (_is_following_cask_partition (parent ) for parent in parent_nodes )
415-
416412 def _is_mha_epilogue_pattern (node : Node , graph : Graph ):
417413 if head_node .op != "Add" :
418414 return False
@@ -483,7 +479,10 @@ def _has_other_quantizable_consumer(
483479 # and decide which input of kgen head needs quantization
484480 for parent in head_parents :
485481 # If the head is consuming output of any quantizable op, then it is quantizable
486- if _is_following_cask_partition (parent ) or parent .op in output_quantization_candidates :
482+ if (
483+ _is_following_cask_partition (parent , cask_partition_nodes )
484+ or parent .op in output_quantization_candidates
485+ ):
487486 # The mask add of MHA should not be quantized
488487 if _is_mha_epilogue_pattern (head_node , graph ):
489488 no_quantize_inputs_of_head .append (
0 commit comments