[llvm] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`. NFC. (PR #80764)

via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 15:42:55 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Yingwei Zheng (dtcxzyw)

<details>
<summary>Changes</summary>

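This patch introduces a matching helper, `m_ElementWiseBitCast`, which matches element-wise bitcasts (e.g. int <-> fp casts): bitcasts that neither change between scalar and vector types nor change the number of vector elements. Unlike `m_BitCast`, it only matches `bitcast` instructions, not bitcast constant expressions.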
The motivation for this patch is to avoid duplicating these checks in https://github.com/llvm/llvm-project/pull/80740 and https://github.com/llvm/llvm-project/pull/80414.

---
Full diff: https://github.com/llvm/llvm-project/pull/80764.diff


5 Files Affected:

- (modified) llvm/include/llvm/IR/PatternMatch.h (+24) 
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+1-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+9-13) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+5-12) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+4-2) 


``````````diff
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8e..7ce42e30639092 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,30 @@ m_BitCast(const OpTy &Op) {
   return CastOperator_match<OpTy, Instruction::BitCast>(Op);
 }
 
+template <typename Op_t> struct ElementWiseBitCast_match {
+  Op_t Op;
+
+  ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    if (auto *I = dyn_cast<BitCastInst>(V)) {
+      Type *SrcType = I->getSrcTy();
+      Type *DstType = I->getType();
+      // Make sure the bitcast doesn't change between scalar and vector and
+      // doesn't change the number of vector elements.
+      if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+          SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits())
+        return Op.match(I->getOperand(0));
+    }
+    return false;
+  }
+};
+
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+  return ElementWiseBitCast_match<OpTy>(Op);
+}
+
 /// Matches PtrToInt.
 template <typename OpTy>
 inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f36..01b017142cfcb0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // floating-point casts:
   // icmp slt (bitcast (uitofp X)),  0 --> false
   // icmp sgt (bitcast (uitofp X)), -1 --> true
-  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+  if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
     if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
       return ConstantInt::getFalse(ITy);
     if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6ca4d6d673068e..0409f33445f0af 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2493,14 +2493,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   // Assumes any IEEE-represented type has the sign bit in the high bit.
   // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
       match(Op1, m_MaxSignedValue()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
-        Attribute::NoImplicitFloat)) {
+          Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       return new BitCastInst(FAbs, I.getType());
     }
@@ -3925,13 +3923,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   // This is generous interpretation of noimplicitfloat, this is not a true
   // floating-point operation.
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+      match(Op1, m_SignMask()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       Value *FNegFAbs = Builder.CreateFNeg(FAbs);
       return new BitCastInst(FNegFAbs, I.getType());
@@ -4701,13 +4698,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
     // Assumes any IEEE-represented type has the sign bit in the high bit.
     // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
     Value *CastOp;
-    if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+    if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+        match(Op1, m_SignMask()) &&
         !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
             Attribute::NoImplicitFloat)) {
       Type *EltTy = CastOp->getType()->getScalarType();
-      if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-          EltTy->getPrimitiveSizeInBits() ==
-          I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+      if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
         Value *FNeg = Builder.CreateFNeg(CastOp);
         return new BitCastInst(FNeg, I.getType());
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d3..cda1061fb35a8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   Value *V;
   if (!Cmp.getParent()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat) &&
-      Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
-    Type *SrcType = V->getType();
-    Type *DstType = X->getType();
-    Type *FPType = SrcType->getScalarType();
-    // Make sure the bitcast doesn't change between scalar and vector and
-    // doesn't change the number of vector elements.
-    if (SrcType->isVectorTy() == DstType->isVectorTy() &&
-        SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
-        FPType->isIEEELikeFPTy() && C1 == *C2) {
+      Cmp.isEquality() &&
+      match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+    Type *FPType = V->getType()->getScalarType();
+    if (FPType->isIEEELikeFPTy() && C1 == *C2) {
       APInt ExponentMask =
           APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
       if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   // Ignore signbit of bitcasted int when comparing equality to FP 0.0:
   // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
   if (match(Op1, m_PosZeroFP()) &&
-      match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
-      X->getType()->isVectorTy() == OpType->isVectorTy() &&
-      X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+      match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
     ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
     if (Pred == FCmpInst::FCMP_OEQ)
       IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e620..d6f447125269c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2382,7 +2382,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
   const APInt *C;
   bool IsTrueIfSignSet;
   ICmpInst::Predicate Pred;
-  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+                                   m_APInt(C)))) ||
       !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
       X->getType() != SelType)
     return nullptr;
@@ -2783,7 +2784,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
     CmpInst::Predicate Pred;
     const APInt *C;
     bool TrueIfSigned;
-    if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+    if (!match(CondVal,
+               m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
         !IC.isSignBitCheck(Pred, *C, TrueIfSigned))
       continue;
     if (!match(TrueVal, m_FNeg(m_Specific(X))))

``````````

</details>


https://github.com/llvm/llvm-project/pull/80764


More information about the llvm-commits mailing list