[llvm] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`. NFC. (PR #80764)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 6 18:14:54 PST 2024


https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/80764

From 0bcdd86bfc216e225f0ef5de130ca7eb19211859 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 6 Feb 2024 07:23:46 +0800
Subject: [PATCH] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`.
 NFC.

---
 llvm/include/llvm/IR/PatternMatch.h           | 28 ++++++++++++++++++
 llvm/lib/Analysis/InstructionSimplify.cpp     |  2 +-
 .../InstCombine/InstCombineAndOrXor.cpp       | 22 ++++++--------
 .../InstCombine/InstCombineCasts.cpp          | 12 ++++++--
 .../InstCombine/InstCombineCompares.cpp       | 17 ++++-------
 .../InstCombine/InstCombineSelect.cpp         | 11 +++----
 .../InstCombine/InstructionCombining.cpp      | 15 ----------
 llvm/unittests/IR/PatternMatch.cpp            | 29 +++++++++++++++++++
 8 files changed, 85 insertions(+), 51 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8e..f34fa7f661262e 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,34 @@ m_BitCast(const OpTy &Op) {
   return CastOperator_match<OpTy, Instruction::BitCast>(Op);
 }
 
+/// Matches a BitCast that is element-wise: the source and destination are
+/// either both scalars or both vectors with the same element count.
+template <typename Op_t> struct ElementWiseBitCast_match {
+  Op_t Op;
+
+  ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    auto *Cast = dyn_cast<BitCastInst>(V);
+    if (!Cast)
+      return false;
+    // Only scalar->scalar and same-element-count vector->vector casts match.
+    auto *SrcVecTy = dyn_cast<VectorType>(Cast->getSrcTy());
+    auto *DstVecTy = dyn_cast<VectorType>(Cast->getType());
+    if ((SrcVecTy == nullptr) != (DstVecTy == nullptr))
+      return false;
+    if (SrcVecTy && SrcVecTy->getElementCount() != DstVecTy->getElementCount())
+      return false;
+    return Op.match(Cast->getOperand(0));
+  }
+};
+
+/// Matches an element-wise BitCast whose operand matches \p Op.
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+  return ElementWiseBitCast_match<OpTy>(Op);
+}
+
 /// Matches PtrToInt.
 template <typename OpTy>
 inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f36..01b017142cfcb0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // floating-point casts:
   // icmp slt (bitcast (uitofp X)),  0 --> false
   // icmp sgt (bitcast (uitofp X)), -1 --> true
-  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+  if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
     if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
       return ConstantInt::getFalse(ITy);
     if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6ca4d6d673068e..0409f33445f0af 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2493,14 +2493,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   // Assumes any IEEE-represented type has the sign bit in the high bit.
   // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
       match(Op1, m_MaxSignedValue()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
-        Attribute::NoImplicitFloat)) {
+          Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       return new BitCastInst(FAbs, I.getType());
     }
@@ -3925,13 +3923,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   // This is generous interpretation of noimplicitfloat, this is not a true
   // floating-point operation.
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+      match(Op1, m_SignMask()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       Value *FNegFAbs = Builder.CreateFNeg(FAbs);
       return new BitCastInst(FNegFAbs, I.getType());
@@ -4701,13 +4698,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
     // Assumes any IEEE-represented type has the sign bit in the high bit.
     // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
     Value *CastOp;
-    if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+    if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+        match(Op1, m_SignMask()) &&
         !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
             Attribute::NoImplicitFloat)) {
       Type *EltTy = CastOp->getType()->getScalarType();
-      if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-          EltTy->getPrimitiveSizeInBits() ==
-          I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+      if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
         Value *FNeg = Builder.CreateFNeg(CastOp);
         return new BitCastInst(FNeg, I.getType());
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 58f0763bb0c0cd..ed47de287302ed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -182,9 +182,15 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
     if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() ||
         (CI.getOpcode() == Instruction::Trunc &&
          shouldChangeType(CI.getSrcTy(), CI.getType()))) {
-      if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
-        replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
-        return NV;
+
+      // If it's a bitcast involving vectors, make sure it has the same number
+      // of elements on both sides.
+      if (CI.getOpcode() != Instruction::BitCast ||
+          match(&CI, m_ElementWiseBitCast(m_Value()))) {
+        if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
+          replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
+          return NV;
+        }
       }
     }
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d3..cda1061fb35a8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   Value *V;
   if (!Cmp.getParent()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat) &&
-      Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
-    Type *SrcType = V->getType();
-    Type *DstType = X->getType();
-    Type *FPType = SrcType->getScalarType();
-    // Make sure the bitcast doesn't change between scalar and vector and
-    // doesn't change the number of vector elements.
-    if (SrcType->isVectorTy() == DstType->isVectorTy() &&
-        SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
-        FPType->isIEEELikeFPTy() && C1 == *C2) {
+      Cmp.isEquality() &&
+      match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+    Type *FPType = V->getType()->getScalarType();
+    if (FPType->isIEEELikeFPTy() && C1 == *C2) {
       APInt ExponentMask =
           APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
       if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   // Ignore signbit of bitcasted int when comparing equality to FP 0.0:
   // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
   if (match(Op1, m_PosZeroFP()) &&
-      match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
-      X->getType()->isVectorTy() == OpType->isVectorTy() &&
-      X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+      match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
     ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
     if (Pred == FCmpInst::FCMP_OEQ)
       IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e620..527037881edb19 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2365,9 +2365,6 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
   Value *FVal = Sel.getFalseValue();
   Type *SelType = Sel.getType();
 
-  if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType())
-    return nullptr;
-
   // Match select ?, TC, FC where the constants are equal but negated.
   // TODO: Generalize to handle a negated variable operand?
   const APFloat *TC, *FC;
@@ -2382,7 +2379,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
   const APInt *C;
   bool IsTrueIfSignSet;
   ICmpInst::Predicate Pred;
-  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+                                   m_APInt(C)))) ||
       !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
       X->getType() != SelType)
     return nullptr;
@@ -2770,8 +2768,6 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
 
   // Match select with (icmp slt (bitcast X to int), 0)
   //                or (icmp sgt (bitcast X to int), -1)
-  if (ICmpInst::makeCmpResultType(SI.getType()) != CondVal->getType())
-    return ChangedFMF ? &SI : nullptr;
 
   for (bool Swap : {false, true}) {
     Value *TrueVal = SI.getTrueValue();
@@ -2783,7 +2779,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
     CmpInst::Predicate Pred;
     const APInt *C;
     bool TrueIfSigned;
-    if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+    if (!match(CondVal,
+               m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
         !IC.isSignBitCheck(Pred, *C, TrueIfSigned))
       continue;
     if (!match(TrueVal, m_FNeg(m_Specific(X))))
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index dd168917f4dc9c..6b8bc438e6dc3a 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1474,21 +1474,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
   if (SI->getType()->isIntOrIntVectorTy(1))
     return nullptr;
 
-  // If it's a bitcast involving vectors, make sure it has the same number of
-  // elements on both sides.
-  if (auto *BC = dyn_cast<BitCastInst>(&Op)) {
-    VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy());
-    VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy());
-
-    // Verify that either both or neither are vectors.
-    if ((SrcTy == nullptr) != (DestTy == nullptr))
-      return nullptr;
-
-    // If vectors, verify that they have the same number of elements.
-    if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
-      return nullptr;
-  }
-
   // Test if a FCmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 9e9e41b8fbad0d..8ee6fc0eb2112d 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -530,6 +530,35 @@ TEST_F(PatternMatchTest, ZExtSExtSelf) {
   EXPECT_TRUE(m_ZExtOrSExtOrSelf(m_One()).match(One64S));
 }
 
+TEST_F(PatternMatchTest, BitCast) {
+  // Fixed-width vector types used as bitcast sources/destinations below.
+  auto *V2I32Ty = VectorType::get(IRB.getInt32Ty(), 2, /*Scalable=*/false);
+  auto *V2F32Ty = VectorType::get(IRB.getFloatTy(), 2, /*Scalable=*/false);
+  auto *V4I16Ty = VectorType::get(IRB.getInt16Ty(), 4, /*Scalable=*/false);
+  Value *FPOne = ConstantFP::get(IRB.getDoubleTy(), APFloat(1.0));
+
+  // Scalar -> scalar: element-wise.
+  Value *F64ToI64 = IRB.CreateBitCast(FPOne, IRB.getInt64Ty());
+  EXPECT_TRUE(m_BitCast(m_Value()).match(F64ToI64));
+  EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(F64ToI64));
+  // Scalar -> vector: not element-wise.
+  Value *F64ToV2I32 = IRB.CreateBitCast(FPOne, V2I32Ty);
+  EXPECT_TRUE(m_BitCast(m_Value()).match(F64ToV2I32));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(F64ToV2I32));
+  // Vector -> scalar: not element-wise.
+  Value *V2I32ToF64 = IRB.CreateBitCast(F64ToV2I32, IRB.getDoubleTy());
+  EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToF64));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32ToF64));
+  // Vector -> vector, same element count: element-wise.
+  Value *V2I32ToV2F32 = IRB.CreateBitCast(F64ToV2I32, V2F32Ty);
+  EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToV2F32));
+  EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(V2I32ToV2F32));
+  // Vector -> vector, different element count: not element-wise.
+  Value *V2I32ToV4I16 = IRB.CreateBitCast(F64ToV2I32, V4I16Ty);
+  EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToV4I16));
+  EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32ToV4I16));
+}
+
 TEST_F(PatternMatchTest, Power2) {
   Value *C128 = IRB.getInt32(128);
   Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));



More information about the llvm-commits mailing list