[llvm] [CmpInstAnalysis] Return decomposed bit test as struct (NFC) (PR #109819)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 08:51:15 PDT 2024


https://github.com/nikic created https://github.com/llvm/llvm-project/pull/109819

decomposeBitTestICmp() currently returns the result via two out parameters plus an in-place modification of Pred. This changes it to return an optional struct instead.

The motivation here is twofold. First, I'd like to extend this code to handle cases where the comparison is against a value other than zero, which would mean yet another out parameter. Second, while doing that I was badly bitten by the in-place modification, so I'd like to get rid of it.

From 2dc6fdc5f71c25c2895e73fddf3aa73901cc7de9 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 24 Sep 2024 17:44:09 +0200
Subject: [PATCH] [CmpInstAnalysis] Return decomposed bit test as struct (NFC)

decomposeBitTestICmp() currently returns the result via two out
parameters plus an in-place modification of Pred. This changes it
to return an optional struct instead.

The motivation here is twofold. First, I'd like to extend this code
to handle cases where the comparison is against a value other than
zero, which would mean yet another out parameter. Second, while
doing that I was badly bitten by the in-place modification, so I'd
like to get rid of it.
---
 llvm/include/llvm/Analysis/CmpInstAnalysis.h  | 19 ++++--
 llvm/lib/Analysis/CmpInstAnalysis.cpp         | 67 ++++++++++---------
 llvm/lib/Analysis/InstructionSimplify.cpp     | 10 ++-
 .../InstCombine/InstCombineAndOrXor.cpp       | 20 ++++--
 .../InstCombine/InstCombineCompares.cpp       |  9 ++-
 .../InstCombine/InstCombineSelect.cpp         | 20 +++---
 .../Transforms/Scalar/LoopIdiomRecognize.cpp  | 14 ++--
 7 files changed, 91 insertions(+), 68 deletions(-)

diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index 1d07a0c22887bb..406dacd930605e 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -14,6 +14,7 @@
 #ifndef LLVM_ANALYSIS_CMPINSTANALYSIS_H
 #define LLVM_ANALYSIS_CMPINSTANALYSIS_H
 
+#include "llvm/ADT/APInt.h"
 #include "llvm/IR/InstrTypes.h"
 
 namespace llvm {
@@ -91,12 +92,18 @@ namespace llvm {
   Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
                                CmpInst::Predicate &Pred);
 
-  /// Decompose an icmp into the form ((X & Mask) pred 0) if possible. The
-  /// returned predicate is either == or !=. Returns false if decomposition
-  /// fails.
-  bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
-                            Value *&X, APInt &Mask,
-                            bool LookThroughTrunc = true);
+  /// Represents the operation icmp (X & Mask) pred 0, where pred can only be
+  /// eq or ne.
+  struct DecomposedBitTest {
+    Value *X;
+    CmpInst::Predicate Pred;
+    APInt Mask;
+  };
+
+  /// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
+  std::optional<DecomposedBitTest>
+  decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
+                       bool LookThroughTrunc = true);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index a1fa7857764d98..36d7aa510545af 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -73,81 +73,84 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
   return nullptr;
 }
 
-bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
-                                CmpInst::Predicate &Pred,
-                                Value *&X, APInt &Mask, bool LookThruTrunc) {
+std::optional<DecomposedBitTest>
+llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
+                           bool LookThruTrunc) {
   using namespace PatternMatch;
 
   const APInt *C;
   if (!match(RHS, m_APIntAllowPoison(C)))
-    return false;
+    return std::nullopt;
 
+  DecomposedBitTest Result;
   switch (Pred) {
   default:
-    return false;
+    return std::nullopt;
   case ICmpInst::ICMP_SLT:
     // X < 0 is equivalent to (X & SignMask) != 0.
     if (!C->isZero())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   case ICmpInst::ICMP_SLE:
     // X <= -1 is equivalent to (X & SignMask) != 0.
     if (!C->isAllOnes())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   case ICmpInst::ICMP_SGT:
     // X > -1 is equivalent to (X & SignMask) == 0.
     if (!C->isAllOnes())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_SGE:
     // X >= 0 is equivalent to (X & SignMask) == 0.
     if (!C->isZero())
-      return false;
-    Mask = APInt::getSignMask(C->getBitWidth());
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = APInt::getSignMask(C->getBitWidth());
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_ULT:
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
     if (!C->isPowerOf2())
-      return false;
-    Mask = -*C;
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = -*C;
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_ULE:
     // X <=u 2^n-1 is equivalent to (X & ~(2^n-1)) == 0.
     if (!(*C + 1).isPowerOf2())
-      return false;
-    Mask = ~*C;
-    Pred = ICmpInst::ICMP_EQ;
+      return std::nullopt;
+    Result.Mask = ~*C;
+    Result.Pred = ICmpInst::ICMP_EQ;
     break;
   case ICmpInst::ICMP_UGT:
     // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0.
     if (!(*C + 1).isPowerOf2())
-      return false;
-    Mask = ~*C;
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = ~*C;
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   case ICmpInst::ICMP_UGE:
     // X >=u 2^n is equivalent to (X & ~(2^n-1)) != 0.
     if (!C->isPowerOf2())
-      return false;
-    Mask = -*C;
-    Pred = ICmpInst::ICMP_NE;
+      return std::nullopt;
+    Result.Mask = -*C;
+    Result.Pred = ICmpInst::ICMP_NE;
     break;
   }
 
+  Value *X;
   if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
-    Mask = Mask.zext(X->getType()->getScalarSizeInBits());
+    Result.X = X;
+    Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
   } else {
-    X = LHS;
+    Result.X = LHS;
   }
 
-  return true;
+  return Result;
 }
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 32a9f1ab34fb3f..90f05d43a2b147 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4624,13 +4624,11 @@ static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
 static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
                                            ICmpInst::Predicate Pred,
                                            Value *TrueVal, Value *FalseVal) {
-  Value *X;
-  APInt Mask;
-  if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask))
-    return nullptr;
+  if (auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred))
+    return simplifySelectBitTest(TrueVal, FalseVal, Res->X, &Res->Mask,
+                                 Res->Pred == ICmpInst::ICMP_EQ);
 
-  return simplifySelectBitTest(TrueVal, FalseVal, X, &Mask,
-                               Pred == ICmpInst::ICMP_EQ);
+  return nullptr;
 }
 
 /// Try to simplify a select instruction when its condition operand is an
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 80d3adedfc89f3..2c2d24d392a938 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -181,11 +181,13 @@ static unsigned conjugateICmpMask(unsigned Mask) {
 // Adapts the external decomposeBitTestICmp for local use.
 static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
                                  Value *&X, Value *&Y, Value *&Z) {
-  APInt Mask;
-  if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask))
+  auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
+  if (!Res)
     return false;
 
-  Y = ConstantInt::get(X->getType(), Mask);
+  Pred = Res->Pred;
+  X = Res->X;
+  Y = ConstantInt::get(X->getType(), Res->Mask);
   Z = ConstantInt::get(X->getType(), 0);
   return true;
 }
@@ -870,11 +872,15 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1,
                            APInt &UnsetBitsMask) -> bool {
     CmpInst::Predicate Pred = ICmp->getPredicate();
     // Can it be decomposed into  icmp eq (X & Mask), 0  ?
-    if (llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
-                                   Pred, X, UnsetBitsMask,
-                                   /*LookThroughTrunc=*/false) &&
-        Pred == ICmpInst::ICMP_EQ)
+    if (auto Res =
+            llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+                                       Pred, /*LookThroughTrunc=*/false);
+        Res && Res->Pred == ICmpInst::ICMP_EQ) {
+      X = Res->X;
+      UnsetBitsMask = Res->Mask;
       return true;
+    }
+
     // Is it  icmp eq (X & Mask), 0  already?
     const APInt *Mask;
     if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) &&
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 698abbb34c18c3..b1215bb4d83b0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5905,11 +5905,10 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
   // This matches patterns corresponding to tests of the signbit as well as:
   // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
   // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
-  APInt Mask;
-  if (decomposeBitTestICmp(Op0, Op1, Pred, X, Mask, true /* WithTrunc */)) {
-    Value *And = Builder.CreateAnd(X, Mask);
-    Constant *Zero = ConstantInt::getNullValue(X->getType());
-    return new ICmpInst(Pred, And, Zero);
+  if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
+    Value *And = Builder.CreateAnd(Res->X, Res->Mask);
+    Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
+    return new ICmpInst(Res->Pred, And, Zero);
   }
 
   unsigned SrcBits = X->getType()->getScalarSizeInBits();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 7476db9ee38f45..3dbe95897d6356 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -145,12 +145,15 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp,
       return nullptr;
 
     AndMask = *AndRHS;
-  } else if (decomposeBitTestICmp(Cmp->getOperand(0), Cmp->getOperand(1),
-                                  Pred, V, AndMask)) {
-    assert(ICmpInst::isEquality(Pred) && "Not equality test?");
-    if (!AndMask.isPowerOf2())
+  } else if (auto Res = decomposeBitTestICmp(Cmp->getOperand(0),
+                                             Cmp->getOperand(1), Pred)) {
+    assert(ICmpInst::isEquality(Res->Pred) && "Not equality test?");
+    if (!Res->Mask.isPowerOf2())
       return nullptr;
 
+    V = Res->X;
+    AndMask = Res->Mask;
+    Pred = Res->Pred;
     CreateAnd = true;
   } else {
     return nullptr;
@@ -747,12 +750,13 @@ static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal,
 
     C1Log = C1->logBase2();
   } else {
-    APInt C1;
-    if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) ||
-        !C1.isPowerOf2())
+    auto Res = decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    if (!Res || !Res->Mask.isPowerOf2())
       return nullptr;
 
-    C1Log = C1.logBase2();
+    CmpLHS = Res->X;
+    Pred = Res->Pred;
+    C1Log = Res->Mask.logBase2();
     NeedAnd = true;
   }
 
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 578d087e470e1e..e3c3984ccb5156 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -2465,10 +2465,16 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX,
   };
   auto MatchDecomposableConstantBitMask = [&]() {
     APInt Mask;
-    return llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CurrX, Mask) &&
-           ICmpInst::isEquality(Pred) && Mask.isPowerOf2() &&
-           (BitMask = ConstantInt::get(CurrX->getType(), Mask)) &&
-           (BitPos = ConstantInt::get(CurrX->getType(), Mask.logBase2()));
+    auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred);
+    if (Res && Res->Mask.isPowerOf2()) {
+      assert(ICmpInst::isEquality(Res->Pred));
+      Pred = Res->Pred;
+      CurrX = Res->X;
+      BitMask = ConstantInt::get(CurrX->getType(), Res->Mask);
+      BitPos = ConstantInt::get(CurrX->getType(), Res->Mask.logBase2());
+      return true;
+    }
+    return false;
   };
 
   if (!MatchVariableBitMask() && !MatchConstantBitMask() &&



More information about the llvm-commits mailing list