[llvm] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`. NFC. (PR #80764)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 5 15:42:27 PST 2024


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/80764

This patch introduces a matching helper `m_ElementWiseBitCast`, which is used for matching element-wise int <-> fp casts.
The motivation of this patch is to avoid duplicating checks in https://github.com/llvm/llvm-project/pull/80740 and https://github.com/llvm/llvm-project/pull/80414.

>From 47f574404987726094b2258ecb5d235dd9c4a73f Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 6 Feb 2024 07:23:46 +0800
Subject: [PATCH] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`.
 NFC.

---
 llvm/include/llvm/IR/PatternMatch.h           | 24 +++++++++++++++++++
 llvm/lib/Analysis/InstructionSimplify.cpp     |  2 +-
 .../InstCombine/InstCombineAndOrXor.cpp       | 22 +++++++----------
 .../InstCombine/InstCombineCompares.cpp       | 17 ++++---------
 .../InstCombine/InstCombineSelect.cpp         |  6 +++--
 5 files changed, 43 insertions(+), 28 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8..7ce42e3063909 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,30 @@ m_BitCast(const OpTy &Op) {
   return CastOperator_match<OpTy, Instruction::BitCast>(Op);
 }
 
+template <typename Op_t> struct ElementWiseBitCast_match {
+  Op_t Op;
+
+  ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    if (auto *I = dyn_cast<BitCastInst>(V)) {
+      Type *SrcType = I->getSrcTy();
+      Type *DstType = I->getType();
+      // Element-wise: scalar<->scalar or vector<->vector only, with equal
+      // element bit width, so the number of vector elements is unchanged.
+      if (SrcType->isVectorTy() == DstType->isVectorTy() &&
+          SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits())
+        return Op.match(I->getOperand(0));
+    }
+    return false;
+  }
+};
+
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+  return ElementWiseBitCast_match<OpTy>(Op);
+}
+
 /// Matches PtrToInt.
 template <typename OpTy>
 inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f3..01b017142cfcb 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // floating-point casts:
   // icmp slt (bitcast (uitofp X)),  0 --> false
   // icmp sgt (bitcast (uitofp X)), -1 --> true
-  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+  if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
     if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
       return ConstantInt::getFalse(ITy);
     if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6ca4d6d673068..0409f33445f0a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2493,14 +2493,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   // Assumes any IEEE-represented type has the sign bit in the high bit.
   // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
       match(Op1, m_MaxSignedValue()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
-        Attribute::NoImplicitFloat)) {
+          Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       return new BitCastInst(FAbs, I.getType());
     }
@@ -3925,13 +3923,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   // This is generous interpretation of noimplicitfloat, this is not a true
   // floating-point operation.
   Value *CastOp;
-  if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+  if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+      match(Op1, m_SignMask()) &&
       !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat)) {
     Type *EltTy = CastOp->getType()->getScalarType();
-    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-        EltTy->getPrimitiveSizeInBits() ==
-        I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
       Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
       Value *FNegFAbs = Builder.CreateFNeg(FAbs);
       return new BitCastInst(FNegFAbs, I.getType());
@@ -4701,13 +4698,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
     // Assumes any IEEE-represented type has the sign bit in the high bit.
     // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
     Value *CastOp;
-    if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+    if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+        match(Op1, m_SignMask()) &&
         !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
             Attribute::NoImplicitFloat)) {
       Type *EltTy = CastOp->getType()->getScalarType();
-      if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
-          EltTy->getPrimitiveSizeInBits() ==
-          I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+      if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
         Value *FNeg = Builder.CreateFNeg(CastOp);
         return new BitCastInst(FNeg, I.getType());
       }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d..cda1061fb35a8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   Value *V;
   if (!Cmp.getParent()->getParent()->hasFnAttribute(
           Attribute::NoImplicitFloat) &&
-      Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
-    Type *SrcType = V->getType();
-    Type *DstType = X->getType();
-    Type *FPType = SrcType->getScalarType();
-    // Make sure the bitcast doesn't change between scalar and vector and
-    // doesn't change the number of vector elements.
-    if (SrcType->isVectorTy() == DstType->isVectorTy() &&
-        SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
-        FPType->isIEEELikeFPTy() && C1 == *C2) {
+      Cmp.isEquality() &&
+      match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+    Type *FPType = V->getType()->getScalarType();
+    if (FPType->isIEEELikeFPTy() && C1 == *C2) {
       APInt ExponentMask =
           APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
       if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
   // Ignore signbit of bitcasted int when comparing equality to FP 0.0:
   // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
   if (match(Op1, m_PosZeroFP()) &&
-      match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
-      X->getType()->isVectorTy() == OpType->isVectorTy() &&
-      X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+      match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
     ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
     if (Pred == FCmpInst::FCMP_OEQ)
       IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e62..d6f447125269c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2382,7 +2382,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
   const APInt *C;
   bool IsTrueIfSignSet;
   ICmpInst::Predicate Pred;
-  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+  if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+                                   m_APInt(C)))) ||
       !InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
       X->getType() != SelType)
     return nullptr;
@@ -2783,7 +2784,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
     CmpInst::Predicate Pred;
     const APInt *C;
     bool TrueIfSigned;
-    if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+    if (!match(CondVal,
+               m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
         !IC.isSignBitCheck(Pred, *C, TrueIfSigned))
       continue;
     if (!match(TrueVal, m_FNeg(m_Specific(X))))



More information about the llvm-commits mailing list