[llvm] goldsteinn/fold op into select (PR #84696)

via llvm-commits llvm-commits at lists.llvm.org
Sun Mar 10 14:56:39 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/84696

- **[InstCombine] Add more tests for folding rem/div/mul with select; NFC**
- **[InstCombine] Use the select condition to try to constant fold binops into select**


>From 6f597c5edf7b3cf5b975884afd77d9043b0858dd Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 10 Mar 2024 16:30:24 -0500
Subject: [PATCH 1/2] [InstCombine] Add more tests for folding rem/div/mul with
 select; NFC

---
 .../Transforms/InstCombine/binop-select.ll    | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll
index 6cd4132eadd77b..0e10de7180fec8 100644
--- a/llvm/test/Transforms/InstCombine/binop-select.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select.ll
@@ -403,3 +403,75 @@ define i32 @ashr_sel_op1_use(i1 %b) {
   %r = ashr i32 -2, %s
   ret i32 %r
 }
+
+
+define i32 @test_mul_to_const_Cmul(i32 %x) {
+; CHECK-LABEL: @test_mul_to_const_Cmul(
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 14
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[COND]], [[X]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 61
+  %cond = select i1 %c, i32 9, i32 14
+  %r = mul i32 %x, %cond
+  ret i32 %r
+}
+
+define i32 @test_mul_to_const_mul(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_mul_to_const_mul(
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[COND]], [[X]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 61
+  %cond = select i1 %c, i32 9, i32 %y
+  %r = mul i32 %x, %cond
+  ret i32 %r
+}
+
+
+define i32 @test_mul_to_const_Cmul_fail_multiuse(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_mul_to_const_Cmul_fail_multiuse(
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 14
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[COND]], [[X]]
+; CHECK-NEXT:    call void @use(i32 [[COND]])
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 61
+  %cond = select i1 %c, i32 9, i32 14
+  %r = mul i32 %x, %cond
+  call void @use(i32 %cond)
+  ret i32 %r
+}
+
+
+define i32 @test_div_to_const_div_fail_non_speculatable(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_div_to_const_div_fail_non_speculatable(
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = udiv i32 [[X]], [[COND]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 61
+  %cond = select i1 %c, i32 9, i32 %y
+  %r = udiv i32 %x, %cond
+  ret i32 %r
+}
+
+
+define i32 @test_div_to_const_Cdiv_todo(i32 %x, i32 %y) {
+; CHECK-LABEL: @test_div_to_const_Cdiv_todo(
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
+; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 14
+; CHECK-NEXT:    [[R:%.*]] = udiv i32 [[X]], [[COND]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 61
+  %cond = select i1 %c, i32 9, i32 14
+  %r = udiv i32 %x, %cond
+  ret i32 %r
+}
+

>From c330000ba52b18e1ae93d5d9e38142b4e8e1365a Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Sun, 10 Mar 2024 15:14:25 -0500
Subject: [PATCH 2/2] [InstCombine] Use the select condition to try to constant
 fold binops into select

The select condition may allow us to constant fold binops on
non-constant arms if the condition implies one of the binop operands is constant.
For example if we have:
```
%c = icmp eq i8 %y, 10
%s = select i1 %c, i8 123, i8 %x
%r = mul i8 %s, %y
```

We can substitute `10` for `%y` on the true arm and do:
```
%c = icmp eq i8 %y, 10
%mul = mul i8 %x, %y
%r = select i1 %c, i8 -50, i8 %mul
```
---
 .../InstCombine/InstCombineInternal.h         |  5 +-
 .../InstCombine/InstructionCombining.cpp      | 48 +++++++++++++++++--
 .../Transforms/InstCombine/binop-select.ll    | 10 ++--
 .../Transforms/InstCombine/extractelement.ll  |  5 +-
 llvm/test/Transforms/InstCombine/pr72433.ll   |  3 +-
 llvm/test/Transforms/InstCombine/pr80597.ll   |  9 +---
 6 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 6a1ef6edeb4077..c2558e0d882a71 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -610,7 +610,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
   Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
                                 bool FoldWithMultiUse = false);
 
-  /// This is a convenience wrapper function for the above two functions.
+  /// This is a convenience wrapper function for the above function.
+  Instruction *foldBinOpIntoSelect(BinaryOperator &I);
+
+  /// This is a convenience wrapper function for the above three functions.
   Instruction *foldBinOpIntoSelectOrPhi(BinaryOperator &I);
 
   Instruction *foldAddWithConstant(BinaryOperator &Add);
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 1a831805dc72a0..2ae2de1daf8830 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1951,14 +1951,54 @@ Instruction *InstCombinerImpl::foldBinopWithPhiOperands(BinaryOperator &BO) {
   return NewPhi;
 }
 
+// Return std::nullopt if we should not fold. Return true if we should fold
+// multi-use select and false for single-use select.
+static std::optional<bool> shouldFoldOpIntoSelect(BinaryOperator &I, Value *Op,
+                                                  Value *OpOther) {
+  if (isa<SelectInst>(Op))
+    // If we will be able to constant fold the incorporated binop, then
+    // multi-use. Otherwise single-use.
+    return match(OpOther, m_ImmConstant()) &&
+           match(Op, m_Select(m_Value(), m_ImmConstant(), m_ImmConstant()));
+
+  return std::nullopt;
+}
+
+Instruction *InstCombinerImpl::foldBinOpIntoSelect(BinaryOperator &I) {
+  std::optional<bool> CanSpeculativelyExecuteRes;
+  for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) {
+    // Slightly more involved logic for select. For select we use the condition
+    // to infer information about the arm. This allows us to constant-fold
+    // even when the select arm(s) are not constant. For example if we have: `(X
+    // == 10 ? 19 : Y) * X`, we can entirely constant fold the true arm as `X ==
+    // 10` dominates it. So we end up with `X == 10 ? 190 : (X * Y)`.
+    if (auto MultiUse = shouldFoldOpIntoSelect(I, I.getOperand(OpIdx),
+                                               I.getOperand(1 - OpIdx))) {
+      if (!*MultiUse) {
+        if (!CanSpeculativelyExecuteRes.has_value()) {
+          const SimplifyQuery Q = SQ.getWithInstruction(&I);
+          CanSpeculativelyExecuteRes =
+              isSafeToSpeculativelyExecute(&I, Q.CxtI, Q.AC, Q.DT, &TLI);
+        }
+        if (!*CanSpeculativelyExecuteRes)
+          return nullptr;
+      }
+      if (Instruction *NewSel = FoldOpIntoSelect(
+              I, cast<SelectInst>(I.getOperand(OpIdx)), *MultiUse))
+        return NewSel;
+    }
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) {
+  if (auto *SI = foldBinOpIntoSelect(I))
+    return SI;
+
   if (!isa<Constant>(I.getOperand(1)))
     return nullptr;
 
-  if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
-    if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
-      return NewSel;
-  } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
+  if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
     if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
       return NewPhi;
   }
diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll
index 0e10de7180fec8..2929edf85bd962 100644
--- a/llvm/test/Transforms/InstCombine/binop-select.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select.ll
@@ -274,7 +274,7 @@ define i32 @and_sel_op0_use(i1 %b) {
 ; CHECK-LABEL: @and_sel_op0_use(
 ; CHECK-NEXT:    [[S:%.*]] = select i1 [[B:%.*]], i32 25, i32 0
 ; CHECK-NEXT:    call void @use(i32 [[S]])
-; CHECK-NEXT:    [[R:%.*]] = and i32 [[S]], 1
+; CHECK-NEXT:    [[R:%.*]] = zext i1 [[B]] to i32
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %s = select i1 %b, i32 25, i32 0
@@ -408,8 +408,8 @@ define i32 @ashr_sel_op1_use(i1 %b) {
 define i32 @test_mul_to_const_Cmul(i32 %x) {
 ; CHECK-LABEL: @test_mul_to_const_Cmul(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 14
-; CHECK-NEXT:    [[R:%.*]] = mul i32 [[COND]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[X]], 14
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 549, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %c = icmp eq i32 %x, 61
@@ -421,8 +421,8 @@ define i32 @test_mul_to_const_Cmul(i32 %x) {
 define i32 @test_mul_to_const_mul(i32 %x, i32 %y) {
 ; CHECK-LABEL: @test_mul_to_const_mul(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 61
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 9, i32 [[Y:%.*]]
-; CHECK-NEXT:    [[R:%.*]] = mul i32 [[COND]], [[X]]
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i32 [[Y:%.*]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 549, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[R]]
 ;
   %c = icmp eq i32 %x, 61
diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll
index bc5dd060a540ae..18a26623a90c1c 100644
--- a/llvm/test/Transforms/InstCombine/extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/extractelement.ll
@@ -800,10 +800,7 @@ define i32 @extelt_vecselect_const_operand_vector(<3 x i1> %c) {
 
 define i32 @extelt_select_const_operand_extractelt_use(i1 %c) {
 ; ANY-LABEL: @extelt_select_const_operand_extractelt_use(
-; ANY-NEXT:    [[E:%.*]] = select i1 [[C:%.*]], i32 4, i32 7
-; ANY-NEXT:    [[M:%.*]] = shl nuw nsw i32 [[E]], 1
-; ANY-NEXT:    [[M_2:%.*]] = shl nuw nsw i32 [[E]], 2
-; ANY-NEXT:    [[R:%.*]] = mul nuw nsw i32 [[M]], [[M_2]]
+; ANY-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i32 128, i32 392
 ; ANY-NEXT:    ret i32 [[R]]
 ;
   %s = select i1 %c, <3 x i32> <i32 2, i32 3, i32 4>, <3 x i32> <i32 5, i32 6, i32 7>
diff --git a/llvm/test/Transforms/InstCombine/pr72433.ll b/llvm/test/Transforms/InstCombine/pr72433.ll
index c6e74582a13d30..1633885075e872 100644
--- a/llvm/test/Transforms/InstCombine/pr72433.ll
+++ b/llvm/test/Transforms/InstCombine/pr72433.ll
@@ -6,8 +6,7 @@ define i32 @widget(i32 %arg, i32 %arg1) {
 ; CHECK-SAME: i32 [[ARG:%.*]], i32 [[ARG1:%.*]]) {
 ; CHECK-NEXT:  bb:
 ; CHECK-NEXT:    [[ICMP:%.*]] = icmp ne i32 [[ARG]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = zext i1 [[ICMP]] to i32
-; CHECK-NEXT:    [[MUL:%.*]] = shl nuw nsw i32 20, [[TMP0]]
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[ICMP]], i32 40, i32 20
 ; CHECK-NEXT:    [[XOR:%.*]] = zext i1 [[ICMP]] to i32
 ; CHECK-NEXT:    [[ADD9:%.*]] = or disjoint i32 [[MUL]], [[XOR]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = zext i1 [[ICMP]] to i32
diff --git a/llvm/test/Transforms/InstCombine/pr80597.ll b/llvm/test/Transforms/InstCombine/pr80597.ll
index 5feae4a06c45c0..bf536b9ecd133e 100644
--- a/llvm/test/Transforms/InstCombine/pr80597.ll
+++ b/llvm/test/Transforms/InstCombine/pr80597.ll
@@ -5,14 +5,9 @@ define i64 @pr80597(i1 %cond) {
 ; CHECK-LABEL: define i64 @pr80597(
 ; CHECK-SAME: i1 [[COND:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[COND]], i64 0, i64 -12884901888
-; CHECK-NEXT:    [[SEXT1:%.*]] = add nsw i64 [[ADD]], 8836839514384105472
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i64 [[SEXT1]], -34359738368
-; CHECK-NEXT:    br i1 [[CMP]], label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
+; CHECK-NEXT:    br i1 true, label [[IF_THEN:%.*]], label [[IF_ELSE:%.*]]
 ; CHECK:       if.else:
-; CHECK-NEXT:    [[SEXT2:%.*]] = ashr exact i64 [[ADD]], 1
-; CHECK-NEXT:    [[ASHR:%.*]] = or i64 [[SEXT2]], 4418419761487020032
-; CHECK-NEXT:    ret i64 [[ASHR]]
+; CHECK-NEXT:    ret i64 poison
 ; CHECK:       if.then:
 ; CHECK-NEXT:    ret i64 0
 ;



More information about the llvm-commits mailing list