[llvm] [InstCombine] Missed optimization for select a%2==0, (a/2*2)*(a/2*2), 0 (PR #92658)

Jorge Botto via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 3 10:57:11 PDT 2024


https://github.com/jf-botto updated https://github.com/llvm/llvm-project/pull/92658

>From a860a3dfaabb42b73aabd8286edf56bb0bb6822a Mon Sep 17 00:00:00 2001
From: Jorge Botto <jorge.botto.16 at ucl.ac.uk>
Date: Sat, 27 Jul 2024 22:56:40 +0100
Subject: [PATCH 1/2] [InstCombine] Precommit tests for select known-bits fold

---
 .../InstCombine/select-known-bits.ll          | 125 ++++++++++++++++++
 1 file changed, 125 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/select-known-bits.ll

diff --git a/llvm/test/Transforms/InstCombine/select-known-bits.ll b/llvm/test/Transforms/InstCombine/select-known-bits.ll
new file mode 100644
index 0000000000000..a4332f2f1d56e
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/select-known-bits.ll
@@ -0,0 +1,125 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+; Folds of (and %a, C) --> %a inside a select arm whose condition proves the bits C clears are already zero in %a.
+define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_mul_and(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %mul = mul i8 %div, %div
+  %retval = select i1 %cmp, i8 %mul, i8 %b
+  ret i8 %retval
+}
+
+define i8 @select_icmp_eq_mul_and_inv(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_mul_and_inv(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[MUL]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 1
+  %div = and i8 %a, -2
+  %mul = mul i8 %div, %div
+  %retval = select i1 %cmp, i8 %b, i8 %mul
+  ret i8 %retval
+}
+
+define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_and(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %retval = select i1 %cmp, i8 %div, i8 %b
+  ret i8 %retval
+}
+
+define i8 @select_icmp_eq_and_inv(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_and_inv(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 1
+  %div = and i8 %a, -2
+  %retval = select i1 %cmp, i8 %b, i8 %div
+  ret i8 %retval
+}
+
+; Negative test: %a is not noundef (it may be undef), so the fold must not fire.
+define i8 @select_icmp_eq_and_undef(i8 %a, i8 %b)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_and_undef(
+; CHECK-SAME: i8 [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %retval = select i1 %cmp, i8 %div, i8 %b
+  ret i8 %retval
+}
+
+; Negative test: the mask is applied to %c, but the condition only constrains %a.
+define i8 @select_icmp_eq_and_diff(i8 noundef %a, i8 %b, i8 %c)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_and_diff(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]], i8 [[C:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[C]], -2
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    ret i8 [[RETVAL]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %c, -2
+  %retval = select i1 %cmp, i8 %div, i8 %b
+  ret i8 %retval
+}
+
+; Negative test: %mul has a second use, so the select arm is not rewritten.
+define i8 @select_icmp_eq_mul_and_extra_use(i8 noundef %a, i8 %b)  {
+; CHECK-LABEL: define i8 @select_icmp_eq_mul_and_extra_use(
+; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
+; CHECK-NEXT:    [[SUM:%.*]] = add i8 [[MUL]], [[RETVAL]]
+; CHECK-NEXT:    ret i8 [[SUM]]
+;
+  %1 = and i8 %a, 1
+  %cmp = icmp eq i8 %1, 0
+  %div = and i8 %a, -2
+  %mul = mul i8 %div, %div
+  %retval = select i1 %cmp, i8 %mul, i8 %b
+  %sum = add i8 %mul, %retval
+  ret i8 %sum
+}

>From bf113b827fa4c34d30ae8b0075f735e24e4f46e3 Mon Sep 17 00:00:00 2001
From: Jorge Botto <jorge.botto.16 at ucl.ac.uk>
Date: Sat, 3 Aug 2024 14:43:02 +0100
Subject: [PATCH 2/2] [InstCombine] Fold redundant and-mask in select true arm

---
 llvm/include/llvm/Analysis/ValueTracking.h    |  4 ++
 .../Transforms/InstCombine/InstCombiner.h     |  7 +++
 llvm/lib/Analysis/ValueTracking.cpp           |  6 +-
 .../InstCombine/InstCombineSelect.cpp         | 63 +++++++++++++++++++
 .../InstCombine/select-known-bits.ll          | 12 ++--
 llvm/test/Transforms/InstCombine/select.ll    | 11 ++--
 6 files changed, 85 insertions(+), 18 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index 96fa16970584d..c1ee4c02e0108 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -94,6 +94,13 @@ void computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownBits &Known);
 void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                  unsigned Depth, const SimplifyQuery &Q);
 
+/// Refine \p Known, the known bits of \p V, with facts implied by \p Cond
+/// being true (or false when \p Invert is set). The refined bits are only
+/// valid at program points where that assumption about \p Cond holds.
+void computeKnownBitsFromCond(const Value *V, Value *Cond, KnownBits &Known,
+                              unsigned Depth, const SimplifyQuery &SQ,
+                              bool Invert);
+
 /// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
 KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
                                        const KnownBits &KnownLHS,
diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
index ebcbd5d9e8880..27bcaad49e5b4 100644
--- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -438,6 +438,15 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
     return llvm::computeKnownBits(V, Depth, SQ.getWithInstruction(CxtI));
   }
 
+  /// Refine \p Known for \p V with facts implied by \p Cond being true (or
+  /// false when \p Invert is set), queried in the context of \p CxtI.
+  void computeKnownBitsFromCond(const Value *V, Value *Cond, KnownBits &Known,
+                                unsigned Depth, const Instruction *CxtI,
+                                bool Invert) const {
+    llvm::computeKnownBitsFromCond(V, Cond, Known, Depth,
+                                   SQ.getWithInstruction(CxtI), Invert);
+  }
+
   bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
                               unsigned Depth = 0,
                               const Instruction *CxtI = nullptr) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 285284dc27071..0d2d2d3bbbdbf 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -752,9 +752,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
   computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
 }
 
-static void computeKnownBitsFromCond(const Value *V, Value *Cond,
-                                     KnownBits &Known, unsigned Depth,
-                                     const SimplifyQuery &SQ, bool Invert) {
+void llvm::computeKnownBitsFromCond(const Value *V, Value *Cond,
+                                    KnownBits &Known, unsigned Depth,
+                                    const SimplifyQuery &SQ, bool Invert) {
   Value *A, *B;
   if (Depth < MaxAnalysisRecursionDepth &&
       match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index a22ee1de0ac21..eaa8faaa2db0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1078,6 +1078,62 @@ static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }
 
+/// Try to fold (and A, C) --> A inside the true arm of select \p SI.
+///
+/// The mask is redundant when \p Cond being true implies that every bit C
+/// clears (every set bit of ~C) is already known zero in A. If \p V is not
+/// such a mask, look through up to MaxDepth levels of single-use operands
+/// and rewrite them in place.
+///
+/// Callers must ensure \p V feeds only the true arm of \p SI: the rewrite is
+/// valid only where \p Cond holds, and any value that changes while \p Cond
+/// is false then reaches only the unselected arm.
+///
+/// \returns the replacement for \p V (A, or \p V itself if one of its
+/// operands was rewritten in place), or nullptr if nothing changed.
+static Value *foldAndMaskPattern(Value *V, Value *Cond, SelectInst &SI,
+                                 InstCombinerImpl &IC, unsigned Depth = 0) {
+  constexpr unsigned MaxDepth = 2;
+
+  Value *A;
+  const APInt *Mask;
+  // A must not be undef: separate uses of undef may take different values,
+  // so a fact derived from Cond need not hold for the substituted A.
+  if (match(V, m_And(m_Value(A), m_APInt(Mask))) &&
+      isGuaranteedNotToBeUndef(A)) {
+    KnownBits Known = IC.computeKnownBits(A, /*Depth=*/0, &SI);
+    IC.computeKnownBitsFromCond(A, Cond, Known, /*Depth=*/0, &SI,
+                                /*Invert=*/false);
+    if ((~*Mask).isSubsetOf(Known.Zero))
+      return A;
+  }
+
+  // I may end up executing with operands that differ from the current ones
+  // whenever Cond is false, so its speculation safety must not rely on
+  // facts about those operands.
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I || Depth >= MaxDepth ||
+      !isSafeToSpeculativelyExecuteWithVariableReplaced(I))
+    return nullptr;
+
+  bool Changed = false;
+  for (unsigned OpIdx = 0, E = I->getNumOperands(); OpIdx != E; ++OpIdx) {
+    // Only rewrite instructions used exclusively by I; other users may sit
+    // outside the true arm. Skipping non-instructions also avoids walking
+    // the (potentially huge) use lists of constants.
+    auto *OpI = dyn_cast<Instruction>(I->getOperand(OpIdx));
+    if (!OpI || !all_of(OpI->users(), [I](const User *U) { return U == I; }))
+      continue;
+
+    if (Value *NewOp = foldAndMaskPattern(OpI, Cond, SI, IC, Depth + 1)) {
+      IC.replaceOperand(*I, OpIdx, NewOp);
+      Changed = true;
+    }
+  }
+
+  return Changed ? I : nullptr;
+}
+
 /// Fold the following code sequence:
 /// \code
 ///   int a = ctlz(x & -x);
@@ -4110,5 +4166,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
     }
   }
 
+  // Fold (and A, C) --> A in the single-use true arm when CondVal implies
+  // that every bit C clears is already zero in A. If TrueVal was rewritten
+  // in place, NewTrueOp == TrueVal and replaceOperand merely reports SI.
+  if (TrueVal->hasOneUse())
+    if (Value *NewTrueOp = foldAndMaskPattern(TrueVal, CondVal, SI, *this))
+      return replaceOperand(SI, 1, NewTrueOp);
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/select-known-bits.ll b/llvm/test/Transforms/InstCombine/select-known-bits.ll
index a4332f2f1d56e..52c56bed429a0 100644
--- a/llvm/test/Transforms/InstCombine/select-known-bits.ll
+++ b/llvm/test/Transforms/InstCombine/select-known-bits.ll
@@ -6,8 +6,7 @@ define i8 @select_icmp_eq_mul_and(i8 noundef %a, i8 %b)  {
 ; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[A]], [[A]]
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 [[B]]
 ; CHECK-NEXT:    ret i8 [[RETVAL]]
 ;
@@ -24,8 +23,7 @@ define i8 @select_icmp_eq_mul_and_inv(i8 noundef %a, i8 %b)  {
 ; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
 ; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[DIV]], [[DIV]]
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[A]], [[A]]
 ; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[MUL]], i8 [[B]]
 ; CHECK-NEXT:    ret i8 [[RETVAL]]
 ;
@@ -42,8 +40,7 @@ define i8 @select_icmp_eq_and(i8 noundef %a, i8 %b)  {
 ; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP]], i8 [[A]], i8 [[B]]
 ; CHECK-NEXT:    ret i8 [[RETVAL]]
 ;
   %1 = and i8 %a, 1
@@ -58,8 +55,7 @@ define i8 @select_icmp_eq_and_inv(i8 noundef %a, i8 %b)  {
 ; CHECK-SAME: i8 noundef [[A:%.*]], i8 [[B:%.*]]) {
 ; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[A]], 1
 ; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i8 [[TMP1]], 0
-; CHECK-NEXT:    [[DIV:%.*]] = and i8 [[A]], -2
-; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[DIV]], i8 [[B]]
+; CHECK-NEXT:    [[RETVAL:%.*]] = select i1 [[CMP_NOT]], i8 [[A]], i8 [[B]]
 ; CHECK-NEXT:    ret i8 [[RETVAL]]
 ;
   %1 = and i8 %a, 1
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 1369be305ec13..1c7247ec6a8b3 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2989,9 +2989,8 @@ define i8 @select_replacement_loop3(i32 noundef %x) {
 
 define i16 @select_replacement_loop4(i16 noundef %p_12) {
 ; CHECK-LABEL: @select_replacement_loop4(
-; CHECK-NEXT:    [[AND1:%.*]] = and i16 [[P_12:%.*]], 1
-; CHECK-NEXT:    [[CMP21:%.*]] = icmp ult i16 [[P_12]], 2
-; CHECK-NEXT:    [[AND3:%.*]] = select i1 [[CMP21]], i16 [[AND1]], i16 0
+; CHECK-NEXT:    [[CMP21:%.*]] = icmp ult i16 [[P_12:%.*]], 2
+; CHECK-NEXT:    [[AND3:%.*]] = select i1 [[CMP21]], i16 [[P_12]], i16 0
 ; CHECK-NEXT:    ret i16 [[AND3]]
 ;
   %cmp1 = icmp ult i16 %p_12, 2
@@ -4671,8 +4670,7 @@ define i8 @select_knownbits_simplify(i8 noundef %x)  {
 ; CHECK-LABEL: @select_knownbits_simplify(
 ; CHECK-NEXT:    [[X_LO:%.*]] = and i8 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X_LO]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], -2
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[AND]], i8 0
+; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[X]], i8 0
 ; CHECK-NEXT:    ret i8 [[RES]]
 ;
   %x.lo = and i8 %x, 1
@@ -4686,8 +4684,7 @@ define i8 @select_knownbits_simplify_nested(i8 noundef %x)  {
 ; CHECK-LABEL: @select_knownbits_simplify_nested(
 ; CHECK-NEXT:    [[X_LO:%.*]] = and i8 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X_LO]], 0
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], -2
-; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[AND]], [[AND]]
+; CHECK-NEXT:    [[MUL:%.*]] = mul i8 [[X]], [[X]]
 ; CHECK-NEXT:    [[RES:%.*]] = select i1 [[CMP]], i8 [[MUL]], i8 0
 ; CHECK-NEXT:    ret i8 [[RES]]
 ;



More information about the llvm-commits mailing list