[llvm] [InstCombine] Missed optimization for select a%2==0, (a/2*2)*(a/2*2), 0 (PR #92658)

Jorge Botto via llvm-commits llvm-commits at lists.llvm.org
Sat Jun 15 02:09:07 PDT 2024


https://github.com/jf-botto updated https://github.com/llvm/llvm-project/pull/92658

>From 195e0fa744b57474b9220c79deee0fb2eb79159f Mon Sep 17 00:00:00 2001
From: Jorge Botto <jorge.botto.16 at ucl.ac.uk>
Date: Fri, 14 Jun 2024 23:20:18 +0100
Subject: [PATCH 1/2] precommit test

---
 llvm/test/Transforms/InstCombine/select.ll | 125 +++++++++++++++++++++
 1 file changed, 125 insertions(+)

diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index b37e9175b26a5..5addea127ee51 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -1456,6 +1456,131 @@ define <2 x i32> @select_icmp_slt0_xor_vec(<2 x i32> %x) {
   ret <2 x i32> %x.xor
 }
 
+define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_mul_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %mul = mul i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval.0
+}
+
+define i8 @select_icmp_eq_shl_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_shl_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[SHL:%.*]] = shl i8 [[DIV7]], [[DIV7]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[SHL]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %shl = shl i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %shl, i8 %b
+  ret i8 %retval.0
+}
+
+define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_and(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_icmp_eq_mul_and_undef(i8 %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_mul_and_undef(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %mul = mul i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_icmp_eq_and_undef(i8 %a, i8 %b)  {
+; CHECK-LABEL: @select_icmp_eq_and_undef(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %a, -2
+  %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_and(i8 noundef %a, i8 %b, i1 %cmp)  {
+; CHECK-LABEL: @select_and(
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A:%.*]], -2
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP:%.*]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %div7 = and i8 %a, -2
+  %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_mul_and(i8 noundef %a, i8 %b, i1 %cmp)  {
+; CHECK-LABEL: @select_mul_and(
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A:%.*]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP:%.*]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %div7 = and i8 %a, -2
+  %mul = mul i8 %div7, %div7
+  %retval.0 = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval.0
+}
+
+;negative test
+define i8 @select_icmp_eq_and_diff(i8 noundef %a, i8 %b, i8 %c)  {
+; CHECK-LABEL: @select_icmp_eq_and_diff(
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[C:%.*]], -2
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    ret i8 [[RETVAL_0]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div7 = and i8 %c, -2
+  %retval.0 = select i1 %cmp, i8 %div7, i8 %b
+  ret i8 %retval.0
+}
+
 define <4 x i32> @canonicalize_to_shuffle(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: @canonicalize_to_shuffle(
 ; CHECK-NEXT:    [[SEL:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> [[B:%.*]], <4 x i32> <i32 0, i32 5, i32 6, i32 3>

>From e210641aebfc14f738bc0a56493b9db2e5d728d7 Mon Sep 17 00:00:00 2001
From: Jorge Botto <jorge.botto.16 at ucl.ac.uk>
Date: Fri, 14 Jun 2024 23:54:32 +0100
Subject: [PATCH 2/2] Adding the missed optimisation

---
 llvm/include/llvm/Analysis/ValueTracking.h    |  4 ++
 .../Transforms/InstCombine/InstCombiner.h     |  7 ++
 llvm/lib/Analysis/ValueTracking.cpp           |  6 +-
 .../InstCombine/InstCombineSelect.cpp         | 71 ++++++++++++++++---
 llvm/test/Transforms/InstCombine/select.ll    | 21 +++---
 5 files changed, 82 insertions(+), 27 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index e577d0cc7ad41..57fa4f29d5eff 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -94,6 +94,10 @@ void computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownBits &Known);
 void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                  unsigned Depth, const SimplifyQuery &Q);
 
+void computeKnownBitsFromCond(const Value *V, Value *Cond, KnownBits &Known,
+                              unsigned Depth, const SimplifyQuery &SQ,
+                              bool Invert);
+
 /// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
 KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
                                        const KnownBits &KnownLHS,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index 855d1aeddfaee..f81150b7b31f6 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -438,6 +438,13 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
     return llvm::computeKnownBits(V, Depth, SQ.getWithInstruction(CxtI));
   }
 
+  void computeKnownBitsFromCond(const Value *V, ICmpInst *Cmp, KnownBits &Known,
+                                unsigned Depth, const Instruction *CxtI,
+                                bool Invert) const {
+    llvm::computeKnownBitsFromCond(V, Cmp, Known, Depth,
+                                   SQ.getWithInstruction(CxtI), Invert);
+  }
+
   bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
                               unsigned Depth = 0,
                               const Instruction *CxtI = nullptr) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 8126d2a1acc27..085e4a81d17ee 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -747,9 +747,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
   computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
 }
 
-static void computeKnownBitsFromCond(const Value *V, Value *Cond,
-                                     KnownBits &Known, unsigned Depth,
-                                     const SimplifyQuery &SQ, bool Invert) {
+void llvm::computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                    KnownBits &Known, unsigned Depth,
+                                    const SimplifyQuery &SQ, bool Invert) {
   Value *A, *B;
   if (Depth < MaxAnalysisRecursionDepth &&
       match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7d26807544d7e..d6f6cebec50ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1073,6 +1073,50 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }
 
+// When the lsb of A is known to be 0 given that cond is true:
+// cond ? A & -2 : B --> cond ? A : B
+// cond ? BinOp (A & -2), (A & -2) : B --> cond ? BinOp A, A : B
+static Value *foldSelectWithIcmpEqAndPattern(ICmpInst *Cmp, Value *TVal,
+                                             Value *FVal,
+                                             InstCombiner::BuilderTy &Builder,
+                                             SelectInst &SI,
+                                             InstCombinerImpl &IC) {
+
+  Value *A;
+  ConstantInt *MaskedConstant;
+
+  // Checks if the true branch matches the pattern 'A & -2'.
+  if (match(TVal,
+            m_OneUse(m_c_And(m_Value(A), m_ConstantInt(MaskedConstant)))) &&
+      MaskedConstant->getValue().getSExtValue() == -2) {
+    KnownBits Known;
+    Known = IC.computeKnownBits(A, 0, &SI);
+    IC.computeKnownBitsFromCond(A, Cmp, Known, 0, &SI, false);
+    if (Known.Zero[0])
+      return Builder.CreateSelect(Cmp, A, FVal);
+    else
+      return nullptr;
+  }
+
+  // Checks if the true branch is a binary op with both operands 'A & -2'.
+  Value *MulVal;
+  if (match(TVal, m_OneUse(m_BinOp(m_Value(MulVal), m_Deferred(MulVal)))))
+    if (match(MulVal, m_c_And(m_Value(A), m_ConstantInt(MaskedConstant))) &&
+        MaskedConstant->getValue().getSExtValue() == -2) {
+      KnownBits Known;
+      Known = IC.computeKnownBits(A, 0, &SI);
+      IC.computeKnownBitsFromCond(A, Cmp, Known, 0, &SI, false);
+      if (Known.Zero[0]) {
+        Instruction::BinaryOps OpCode = cast<BinaryOperator>(TVal)->getOpcode();
+        Value *NewTValue = Builder.CreateBinOp(OpCode, A, A);
+        return Builder.CreateSelect(Cmp, NewTValue, FVal);
+      } else
+        return nullptr;
+    }
+
+  return nullptr;
+}
+
 /// Fold the following code sequence:
 /// \code
 ///   int a = ctlz(x & -x);
@@ -1812,6 +1856,7 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
 /// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
                                                       ICmpInst *ICI) {
+
   if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
     return NewSel;
 
@@ -1951,6 +1996,10 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
   if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
+  if (Value *V = foldSelectWithIcmpEqAndPattern(ICI, TrueVal, FalseVal, Builder,
+                                                SI, *this))
+    return replaceInstUsesWith(SI, V);
+
   return Changed ? &SI : nullptr;
 }
 
@@ -2359,20 +2408,20 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel,
 /// operand, the result of the select will always be equal to its false value.
 /// For example:
 ///
-///   %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
-///   %val = extractvalue { i64, i1 } %cmpxchg, 0
-///   %success = extractvalue { i64, i1 } %cmpxchg, 1
-///   %sel = select i1 %success, i64 %compare, i64 %val
-///   ret i64 %sel
+///   %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+///   %1 = extractvalue { i64, i1 } %0, 1
+///   %2 = extractvalue { i64, i1 } %0, 0
+///   %3 = select i1 %1, i64 %compare, i64 %2
+///   ret i64 %3
 ///
-/// The returned value of the cmpxchg instruction (%val) is the original value
-/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %val
+/// The returned value of the cmpxchg instruction (%2) is the original value
+/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2
 /// must have been equal to %compare. Thus, the result of the select is always
-/// equal to %val, and the code can be simplified to:
+/// equal to %2, and the code can be simplified to:
 ///
-///   %cmpxchg = cmpxchg ptr %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
-///   %val = extractvalue { i64, i1 } %cmpxchg, 0
-///   ret i64 %val
+///   %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+///   %1 = extractvalue { i64, i1 } %0, 0
+///   ret i64 %1
 ///
 static Value *foldSelectCmpXchg(SelectInst &SI) {
   // A helper that determines if V is an extractvalue instruction whose
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 5addea127ee51..637e9d25a3d41 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -1460,9 +1460,8 @@ define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b)  {
 ; CHECK-LABEL: @select_icmp_eq_mul_and(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
-; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
 ; CHECK-NEXT:    ret i8 [[RETVAL_0]]
 ;
   %1 = and i8 %a, 1
@@ -1477,9 +1476,8 @@ define i8 @select_icmp_eq_shl_and(i8 noundef %a, i8 %b)  {
 ; CHECK-LABEL: @select_icmp_eq_shl_and(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[SHL:%.*]] = shl i8 [[DIV7]], [[DIV7]]
-; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[SHL]], i8 [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
 ; CHECK-NEXT:    ret i8 [[RETVAL_0]]
 ;
   %1 = and i8 %a, 1
@@ -1494,8 +1492,7 @@ define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b)  {
 ; CHECK-LABEL: @select_icmp_eq_and(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B:%.*]]
 ; CHECK-NEXT:    ret i8 [[RETVAL_0]]
 ;
   %1 = and i8 %a, 1
@@ -1510,9 +1507,8 @@ define i8 @select_icmp_eq_mul_and_undef(i8 %a, i8 %b)  {
 ; CHECK-LABEL: @select_icmp_eq_mul_and_undef(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV7]], [[DIV7]]
-; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i8 [[A]], [[A]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[TMP2]], i8 [[B:%.*]]
 ; CHECK-NEXT:    ret i8 [[RETVAL_0]]
 ;
   %1 = and i8 %a, 1
@@ -1528,8 +1524,7 @@ define i8 @select_icmp_eq_and_undef(i8 %a, i8 %b)  {
 ; CHECK-LABEL: @select_icmp_eq_and_undef(
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV7:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[DIV7]], i8 [[B:%.*]]
+; CHECK-NEXT:    [[RETVAL_0:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B:%.*]]
 ; CHECK-NEXT:    ret i8 [[RETVAL_0]]
 ;
   %1 = and i8 %a, 1



More information about the llvm-commits mailing list