[llvm] [SimplifyCFG] Add support for hoisting commutative instructions (PR #104805)

via llvm-commits llvm-commits at lists.llvm.org
Mon Aug 19 08:45:46 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

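This extends SimplifyCFG hoisting to also hoist instructions that are identical up to commutativity: commutative operations with their first two operands swapped (for example `a + b` on one side and `b + a` on the other side), and comparisons with swapped operands and the correspondingly swapped predicate (for example `icmp ult a, b` on one side and `icmp ugt b, a` on the other side).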

This should address the issue mentioned in:
https://github.com/llvm/llvm-project/pull/91185#issuecomment-2097447927

---
Full diff: https://github.com/llvm/llvm-project/pull/104805.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+21-1) 
- (modified) llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll (+10-25) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index ebdf760bda7f1..00efd3c0eb72e 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1583,6 +1583,26 @@ static void hoistLockstepIdenticalDbgVariableRecords(
   }
 }
 
+static bool areIdenticalUpToCommutativity(const Instruction *I1,
+                                          const Instruction *I2) {
+  if (I1->isIdenticalToWhenDefined(I2))
+    return true;
+
+  if (auto *Cmp1 = dyn_cast<CmpInst>(I1))
+    if (auto *Cmp2 = dyn_cast<CmpInst>(I2))
+      return Cmp1->getPredicate() == Cmp2->getSwappedPredicate() &&
+             Cmp1->getOperand(0) == Cmp2->getOperand(1) &&
+             Cmp1->getOperand(1) == Cmp2->getOperand(0);
+
+  if (I1->isCommutative() && I1->isSameOperationAs(I2)) {
+    return I1->getOperand(0) == I2->getOperand(1) &&
+           I1->getOperand(1) == I2->getOperand(0) &&
+           equal(drop_begin(I1->operands(), 2), drop_begin(I2->operands(), 2));
+  }
+
+  return false;
+}
+
 /// Hoist any common code in the successor blocks up into the block. This
 /// function guarantees that BB dominates all successors. If EqTermsOnly is
 /// given, only perform hoisting in case both blocks only contain a terminator.
@@ -1676,7 +1696,7 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB,
     for (auto &SuccIter : OtherSuccIterRange) {
       Instruction *I2 = &*SuccIter;
       HasTerminator |= I2->isTerminator();
-      if (AllInstsAreIdentical && (!I1->isIdenticalToWhenDefined(I2) ||
+      if (AllInstsAreIdentical && (!areIdenticalUpToCommutativity(I1, I2) ||
                                    MMRAMetadata(*I1) != MMRAMetadata(*I2)))
         AllInstsAreIdentical = false;
     }
diff --git a/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll b/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll
index e2705dc95ab97..8ce94d1cf5b4e 100644
--- a/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll
+++ b/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll
@@ -159,16 +159,14 @@ declare void @foo()
 
 define i1 @test_icmp_simple(i1 %c, i32 %a, i32 %b) {
 ; CHECK-LABEL: @test_icmp_simple(
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       common.ret:
-; CHECK-NEXT:    [[COMMON_RET_OP:%.*]] = phi i1 [ [[CMP1:%.*]], [[IF]] ], [ [[CMP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT:    ret i1 [[COMMON_RET_OP]]
+; CHECK-NEXT:    ret i1 [[CMP1]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[CMP1]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label [[COMMON_RET:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[CMP2]] = icmp ugt i32 [[B]], [[A]]
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[COMMON_RET]]
 ;
@@ -187,13 +185,8 @@ else:
 
 define void @test_icmp_complex(i1 %c, i32 %a, i32 %b) {
 ; CHECK-LABEL: @test_icmp_complex(
-; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
-; CHECK:       if:
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    br i1 [[CMP1]], label [[IF2:%.*]], label [[ELSE2:%.*]]
-; CHECK:       else:
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[B]], [[A]]
-; CHECK-NEXT:    br i1 [[CMP2]], label [[IF2]], label [[ELSE2]]
 ; CHECK:       common.ret:
 ; CHECK-NEXT:    ret void
 ; CHECK:       if2:
@@ -280,16 +273,14 @@ else:
 
 define i32 @test_binop(i1 %c, i32 %a, i32 %b) {
 ; CHECK-LABEL: @test_binop(
+; CHECK-NEXT:    [[OP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       common.ret:
-; CHECK-NEXT:    [[COMMON_RET_OP:%.*]] = phi i32 [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT:    ret i32 [[COMMON_RET_OP]]
+; CHECK-NEXT:    ret i32 [[OP1]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OP1]] = add i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label [[COMMON_RET:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[OP2]] = add i32 [[B]], [[A]]
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[COMMON_RET]]
 ;
@@ -308,16 +299,14 @@ else:
 
 define i32 @test_binop_flags(i1 %c, i32 %a, i32 %b) {
 ; CHECK-LABEL: @test_binop_flags(
+; CHECK-NEXT:    [[OP1:%.*]] = add nsw i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       common.ret:
-; CHECK-NEXT:    [[COMMON_RET_OP:%.*]] = phi i32 [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT:    ret i32 [[COMMON_RET_OP]]
+; CHECK-NEXT:    ret i32 [[OP1]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OP1]] = add nuw nsw i32 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label [[COMMON_RET:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[OP2]] = add nsw i32 [[B]], [[A]]
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[COMMON_RET]]
 ;
@@ -392,16 +381,14 @@ else:
 
 define i32 @test_intrin(i1 %c, i32 %a, i32 %b) {
 ; CHECK-LABEL: @test_intrin(
+; CHECK-NEXT:    [[OP1:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       common.ret:
-; CHECK-NEXT:    [[COMMON_RET_OP:%.*]] = phi i32 [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT:    ret i32 [[COMMON_RET_OP]]
+; CHECK-NEXT:    ret i32 [[OP1]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OP1]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 [[B:%.*]])
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label [[COMMON_RET:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[OP2]] = call i32 @llvm.umin.i32(i32 [[B]], i32 [[A]])
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[COMMON_RET]]
 ;
@@ -448,16 +435,14 @@ else:
 
 define float @test_intrin_3arg(i1 %c, float %a, float %b, float %d) {
 ; CHECK-LABEL: @test_intrin_3arg(
+; CHECK-NEXT:    [[OP1:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[D:%.*]])
 ; CHECK-NEXT:    br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
 ; CHECK:       common.ret:
-; CHECK-NEXT:    [[COMMON_RET_OP:%.*]] = phi float [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT:    ret float [[COMMON_RET_OP]]
+; CHECK-NEXT:    ret float [[OP1]]
 ; CHECK:       if:
-; CHECK-NEXT:    [[OP1]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[D:%.*]])
 ; CHECK-NEXT:    call void @foo()
 ; CHECK-NEXT:    br label [[COMMON_RET:%.*]]
 ; CHECK:       else:
-; CHECK-NEXT:    [[OP2]] = call float @llvm.fma.f32(float [[B]], float [[A]], float [[D]])
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[COMMON_RET]]
 ;

``````````

</details>


https://github.com/llvm/llvm-project/pull/104805


More information about the llvm-commits mailing list