[llvm] [SimplifyCFG] Add support for hoisting commutative instructions (PR #104805)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 19 08:45:46 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Nikita Popov (nikic)
<details>
<summary>Changes</summary>
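This extends SimplifyCFG hoisting to also hoist instructions whose operands are commuted between the successor blocks, for example a+b on one side and b+a on the other side. Compares are also matched when both their operands and their predicate are swapped (e.g. `icmp ult a, b` on one side and `icmp ugt b, a` on the other).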
This should address the issue mentioned in:
https://github.com/llvm/llvm-project/pull/91185#issuecomment-2097447927
---
Full diff: https://github.com/llvm/llvm-project/pull/104805.diff
2 Files Affected:
- (modified) llvm/lib/Transforms/Utils/SimplifyCFG.cpp (+21-1)
- (modified) llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll (+10-25)
``````````diff
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index ebdf760bda7f1..00efd3c0eb72e 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1583,6 +1583,26 @@ static void hoistLockstepIdenticalDbgVariableRecords(
}
}
+static bool areIdenticalUpToCommutativity(const Instruction *I1,
+ const Instruction *I2) { // Returns true if I1 and I2 may be treated as the same instruction for hoisting: identical per isIdenticalToWhenDefined, compares with swapped operands and swapped predicate, or commutative ops with their first two operands swapped.
+ if (I1->isIdenticalToWhenDefined(I2)) // Fast path: no commutation needed.
+ return true;
+ // cmp P X, Y is equivalent to cmp swapped(P) Y, X (e.g. icmp ult a, b vs. icmp ugt b, a).
+ if (auto *Cmp1 = dyn_cast<CmpInst>(I1))
+ if (auto *Cmp2 = dyn_cast<CmpInst>(I2))
+ return Cmp1->getPredicate() == Cmp2->getSwappedPredicate() &&
+ Cmp1->getOperand(0) == Cmp2->getOperand(1) &&
+ Cmp1->getOperand(1) == Cmp2->getOperand(0);
+ // Commutative ops (including commutative intrinsic calls such as umin/fma): only the first two operands may be swapped; all later operands (e.g. fma's third argument) must match positionally.
+ if (I1->isCommutative() && I1->isSameOperationAs(I2)) {
+ return I1->getOperand(0) == I2->getOperand(1) &&
+ I1->getOperand(1) == I2->getOperand(0) &&
+ equal(drop_begin(I1->operands(), 2), drop_begin(I2->operands(), 2));
+ }
+
+ return false;
+}
+
/// Hoist any common code in the successor blocks up into the block. This
/// function guarantees that BB dominates all successors. If EqTermsOnly is
/// given, only perform hoisting in case both blocks only contain a terminator.
@@ -1676,7 +1696,7 @@ bool SimplifyCFGOpt::hoistCommonCodeFromSuccessors(BasicBlock *BB,
for (auto &SuccIter : OtherSuccIterRange) {
Instruction *I2 = &*SuccIter;
HasTerminator |= I2->isTerminator();
- if (AllInstsAreIdentical && (!I1->isIdenticalToWhenDefined(I2) ||
+ if (AllInstsAreIdentical && (!areIdenticalUpToCommutativity(I1, I2) ||
MMRAMetadata(*I1) != MMRAMetadata(*I2)))
AllInstsAreIdentical = false;
}
diff --git a/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll b/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll
index e2705dc95ab97..8ce94d1cf5b4e 100644
--- a/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll
+++ b/llvm/test/Transforms/SimplifyCFG/hoist-common-code.ll
@@ -159,16 +159,14 @@ declare void @foo()
define i1 @test_icmp_simple(i1 %c, i32 %a, i32 %b) {
; CHECK-LABEL: @test_icmp_simple(
+; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
; CHECK: common.ret:
-; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i1 [ [[CMP1:%.*]], [[IF]] ], [ [[CMP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT: ret i1 [[COMMON_RET_OP]]
+; CHECK-NEXT: ret i1 [[CMP1]]
; CHECK: if:
-; CHECK-NEXT: [[CMP1]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
-; CHECK-NEXT: [[CMP2]] = icmp ugt i32 [[B]], [[A]]
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[COMMON_RET]]
;
@@ -187,13 +185,8 @@ else:
define void @test_icmp_complex(i1 %c, i32 %a, i32 %b) {
; CHECK-LABEL: @test_icmp_complex(
-; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
-; CHECK: if:
; CHECK-NEXT: [[CMP1:%.*]] = icmp ult i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: br i1 [[CMP1]], label [[IF2:%.*]], label [[ELSE2:%.*]]
-; CHECK: else:
-; CHECK-NEXT: [[CMP2:%.*]] = icmp ugt i32 [[B]], [[A]]
-; CHECK-NEXT: br i1 [[CMP2]], label [[IF2]], label [[ELSE2]]
; CHECK: common.ret:
; CHECK-NEXT: ret void
; CHECK: if2:
@@ -280,16 +273,14 @@ else:
define i32 @test_binop(i1 %c, i32 %a, i32 %b) {
; CHECK-LABEL: @test_binop(
+; CHECK-NEXT: [[OP1:%.*]] = add i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
; CHECK: common.ret:
-; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
+; CHECK-NEXT: ret i32 [[OP1]]
; CHECK: if:
-; CHECK-NEXT: [[OP1]] = add i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
-; CHECK-NEXT: [[OP2]] = add i32 [[B]], [[A]]
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[COMMON_RET]]
;
@@ -308,16 +299,14 @@ else:
define i32 @test_binop_flags(i1 %c, i32 %a, i32 %b) {
; CHECK-LABEL: @test_binop_flags(
+; CHECK-NEXT: [[OP1:%.*]] = add nsw i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
; CHECK: common.ret:
-; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
+; CHECK-NEXT: ret i32 [[OP1]]
; CHECK: if:
-; CHECK-NEXT: [[OP1]] = add nuw nsw i32 [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
-; CHECK-NEXT: [[OP2]] = add nsw i32 [[B]], [[A]]
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[COMMON_RET]]
;
@@ -392,16 +381,14 @@ else:
define i32 @test_intrin(i1 %c, i32 %a, i32 %b) {
; CHECK-LABEL: @test_intrin(
+; CHECK-NEXT: [[OP1:%.*]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
; CHECK: common.ret:
-; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi i32 [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT: ret i32 [[COMMON_RET_OP]]
+; CHECK-NEXT: ret i32 [[OP1]]
; CHECK: if:
-; CHECK-NEXT: [[OP1]] = call i32 @llvm.umin.i32(i32 [[A:%.*]], i32 [[B:%.*]])
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
-; CHECK-NEXT: [[OP2]] = call i32 @llvm.umin.i32(i32 [[B]], i32 [[A]])
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[COMMON_RET]]
;
@@ -448,16 +435,14 @@ else:
define float @test_intrin_3arg(i1 %c, float %a, float %b, float %d) {
; CHECK-LABEL: @test_intrin_3arg(
+; CHECK-NEXT: [[OP1:%.*]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[D:%.*]])
; CHECK-NEXT: br i1 [[C:%.*]], label [[IF:%.*]], label [[ELSE:%.*]]
; CHECK: common.ret:
-; CHECK-NEXT: [[COMMON_RET_OP:%.*]] = phi float [ [[OP1:%.*]], [[IF]] ], [ [[OP2:%.*]], [[ELSE]] ]
-; CHECK-NEXT: ret float [[COMMON_RET_OP]]
+; CHECK-NEXT: ret float [[OP1]]
; CHECK: if:
-; CHECK-NEXT: [[OP1]] = call float @llvm.fma.f32(float [[A:%.*]], float [[B:%.*]], float [[D:%.*]])
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: br label [[COMMON_RET:%.*]]
; CHECK: else:
-; CHECK-NEXT: [[OP2]] = call float @llvm.fma.f32(float [[B]], float [[A]], float [[D]])
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: br label [[COMMON_RET]]
;
``````````
</details>
https://github.com/llvm/llvm-project/pull/104805
More information about the llvm-commits mailing list