perf[next-dace]: Allow more fusion of ConditionalBlocks #2517

Open
iomaganaris wants to merge 8 commits into main from allow_more_conditionals_fusion

@iomaganaris iomaganaris commented Mar 10, 2026

This PR allows more fusion of ConditionalBlocks in 2 ways:

  • Allows MoveDataflowIntoIfBody to move dataflow inside the States of the ConditionalBlock even if the number of outputs differs between those States
  • Allows FuseHorizontalConditionalBlocks to fuse ConditionalBlocks that have different numbers of states, matching the branches based on their conditions

These changes are necessary to fuse the ConditionalBlocks of the first kernel of the graupel code while removing the copies of the false branches of the ConditionalBlocks of the output AccessNodes

Comment on lines 129 to 155
@@ -140,11 +140,17 @@ def can_be_applied(
         fused_conditional_block_state_names = [
             state.name for state in fused_conditional_block.all_states()
         ]
+        # Allow the states of conditional blocks to have either "true_branch" or "false_branch" in their name. This check is related to the function `_find_corresponding_state_in_second` below
+        # TODO(iomaganaris): Raise this restriction if there's any need for that. Currently where statements generate only true/false branchs so this check is sufficient
         if not (
-            any("true_branch" in name for name in extended_conditional_block_state_names)
-            and any("false_branch" in name for name in extended_conditional_block_state_names)
-            and any("true_branch" in name for name in fused_conditional_block_state_names)
-            and any("false_branch" in name for name in fused_conditional_block_state_names)
+            all(
+                "true_branch" in state_name or "false_branch" in state_name
+                for state_name in extended_conditional_block_state_names
+            )
+            and all(
+                "true_branch" in state_name or "false_branch" in state_name
+                for state_name in fused_conditional_block_state_names
+            )
         ):
             return False
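
For readers comparing the two conditions in the hunk above, here is a minimal sketch that evaluates both on plain lists of state names. The helper names `old_check` and `new_check` and the sample state names are hypothetical and not part of the PR; the bodies only mirror the `any`/`all` expressions shown in the diff.

```python
def old_check(extended_names, fused_names):
    """Removed condition: each block needs at least one true and one false branch."""
    return (
        any("true_branch" in name for name in extended_names)
        and any("false_branch" in name for name in extended_names)
        and any("true_branch" in name for name in fused_names)
        and any("false_branch" in name for name in fused_names)
    )


def new_check(extended_names, fused_names):
    """Added condition: every state must be named as a true or a false branch."""
    return all(
        "true_branch" in name or "false_branch" in name for name in extended_names
    ) and all(
        "true_branch" in name or "false_branch" in name for name in fused_names
    )


# Hypothetical state names: the fused block has only a true branch.
two_states = ["true_branch_0", "false_branch_0"]
one_state = ["true_branch_1"]
# Hypothetical state names: one state is neither a true nor a false branch.
with_extra = ["true_branch_2", "false_branch_2", "other_state"]

print(old_check(two_states, one_state), new_check(two_states, one_state))
print(old_check(two_states, with_extra), new_check(two_states, with_extra))
```

Under these assumptions the new condition accepts blocks with different numbers of states, while rejecting blocks that contain a state with an unrecognized name, which the old condition let through.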
Contributor

I think that the generic check to merge a conditional block CB1 into CB2 is that:

  • number of subregions (branches) in CB1 <= number of subregions in CB2
  • all conditions used for the subregions in CB1 exist in CB2
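
The two criteria above can be sketched on plain Python data. This is only an illustration: each conditional block is modelled as a list of condition strings (with `None` standing for an else-branch), and `can_merge_into` is a hypothetical helper, not a function from the PR or from DaCe.

```python
def can_merge_into(cb1_conditions, cb2_conditions):
    """Return True if CB1 could be merged into CB2 under the two criteria.

    Both arguments are lists with one entry per branch: a condition string,
    or None for an else-branch (an assumption made for this sketch).
    """
    # Criterion 1: CB1 must not have more branches than CB2.
    if len(cb1_conditions) > len(cb2_conditions):
        return False
    # Criterion 2: every condition used in CB1 must also appear in CB2.
    return set(cb1_conditions) <= set(cb2_conditions)


# Hypothetical condition strings, following the `__cond` naming used in this thread.
print(can_merge_into(["__cond"], ["__cond", "(not __cond)"]))
print(can_merge_into(["__cond", "(not __cond)"], ["__cond"]))
print(can_merge_into(["__other"], ["__cond", "(not __cond)"]))
```

The comparison here is purely textual, so two conditions that mean the same thing but are written differently would not match.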

Contributor Author

The first check is handled by _order_conditional_blocks_based_on_number_of_states_and_label.
The second check is difficult to figure out because state names can be like the attached picture. That's why I restrict to true and false branches for now

Image

Contributor

Not really. _order_conditional_blocks_based_on_number_of_states_and_label counts the number of states, not the number of branches. There is a fundamental problem here, that you assume branch ~= state.
As for the true and false branch, there is no such thing as a true or false branch in the SDFG representation. Each branch has a symbolic condition which can be evaluated to True/False. Optionally, there is a branch with condition=None which behaves as the else-branch.
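
To make the branch/state distinction concrete, here is a minimal sketch with stand-in classes. `Region` and `Block` are hypothetical simplifications (they are not DaCe types); the only structure taken from this discussion is that a block holds branches, each pairing an optional condition with a region that may contain several states.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Region:
    """Stand-in for a control-flow region: just a list of state names."""
    states: List[str] = field(default_factory=list)


@dataclass
class Block:
    """Stand-in for a conditional block: (condition, region) pairs.

    A condition of None models the else-branch.
    """
    branches: List[Tuple[Optional[str], Region]]


def count_branches(block: Block) -> int:
    return len(block.branches)


def count_states(block: Block) -> int:
    return sum(len(region.states) for _, region in block.branches)


def has_one_state_per_branch(block: Block) -> bool:
    """Check the 'branch ~= state' assumption explicitly."""
    return all(len(region.states) == 1 for _, region in block.branches)


# Two branches, but three states: counting states overestimates the branches.
block = Block(
    branches=[
        ("__cond", Region(states=["s0", "s1"])),
        (None, Region(states=["s2"])),
    ]
)
print(count_branches(block), count_states(block), has_one_state_per_branch(block))
```

A guard like `has_one_state_per_branch` is one way the assumption could be checked before relying on state counts.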

Contributor

Anyway, this comment applies to the code in baseline, not to the change in this PR. The current implementation will work anyway, but it is tightly dependent on the way the SDFG is generated (e.g. state names).

Contributor Author

> Not really. _order_conditional_blocks_based_on_number_of_states_and_label counts the number of states, not the number of branches. There is a fundamental problem here, that you assume branch ~= state.

You're right. I assume that without checking it. I should add a check for this assumption otherwise things will get a lot more complicated and it's something we haven't come across yet.

> Each branch has a symbolic condition which can be evaluated to True/False. Optionally, there is a branch with condition=None which behaves as else-branch.

Would it be safer to distinguish the matching branches/states based on the condition? Something like:

        def _find_corresponding_branch_in_fused(
            fused_conditional_block: ConditionalBlock,
        ) -> dace.sdfg.state.ControlFlowRegion | None:
            # `extended_branch` comes from the enclosing scope. `extended_branch[0]` is a
            # `dace.properties.CodeBlock`, so `as_string` will be `__cond` or `(not __cond)`.
            extended_branch_condition = extended_branch[0].as_string
            for condition, region in fused_conditional_block.branches:
                # An else-branch has `condition=None`, so guard before reading `as_string`.
                if condition is not None and condition.as_string == extended_branch_condition:
                    return region
            return None

Contributor Author

I have updated the matching logic to compare only the conditions of the branches, which I expect to be identical between ConditionalBlocks that share the same condition. The condition could still have a different name, or have the same meaning but be written differently, but I hope the lowering never produces that 🙈
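
One way the "written differently" case could be softened is to compare parsed expressions instead of raw strings. The sketch below is an assumption, not what the PR does: it uses Python's standard `ast` module on condition strings such as those returned by `as_string`, and the helper names `normalize_condition` and `same_condition` are hypothetical. It only absorbs differences in whitespace and redundant parentheses; renamed symbols or logically equivalent but structurally different expressions still compare as different.

```python
import ast


def normalize_condition(condition: str) -> str:
    """Parse a condition string and dump its AST, dropping layout details."""
    return ast.dump(ast.parse(condition.strip(), mode="eval"))


def same_condition(first: str, second: str) -> bool:
    """Compare two condition strings structurally rather than textually."""
    return normalize_condition(first) == normalize_condition(second)


# Redundant parentheses and extra whitespace do not matter.
print(same_condition("(not __cond)", "not   __cond"))
# Different conditions stay different.
print(same_condition("__cond", "not __cond"))
# Logically equivalent but structurally different expressions are NOT matched.
print(same_condition("not __cond", "__cond == False"))
```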
