[llvm] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`. NFC. (PR #80764)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 5 15:42:27 PST 2024
https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/80764
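This patch introduces a pattern-matching helper, `m_ElementWiseBitCast`. It matches a `bitcast` instruction that neither changes between scalar and vector types nor changes the number of vector elements, i.e. an element-wise int <-> fp cast.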
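The motivation of this patch is to avoid duplicating these checks in https://github.com/llvm/llvm-project/pull/80740 and https://github.com/llvm/llvm-project/pull/80414. Note that the new matcher is stricter at some call sites:
- In InstructionSimplify and the two folds in InstCombineSelect.cpp, which previously used a plain `m_BitCast` without an equivalent shape check, bitcasts that change the element count or switch between scalar and vector no longer match.
- The and/or/xor sign-bit folds previously compared only element sizes, so they no longer accept a bitcast between a scalar and a single-element vector.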
>From 47f574404987726094b2258ecb5d235dd9c4a73f Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 6 Feb 2024 07:23:46 +0800
Subject: [PATCH] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`.
NFC.
---
llvm/include/llvm/IR/PatternMatch.h | 24 +++++++++++++++++++
llvm/lib/Analysis/InstructionSimplify.cpp | 2 +-
.../InstCombine/InstCombineAndOrXor.cpp | 22 +++++++----------
.../InstCombine/InstCombineCompares.cpp | 17 ++++---------
.../InstCombine/InstCombineSelect.cpp | 6 +++--
5 files changed, 43 insertions(+), 28 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8..7ce42e3063909 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,30 @@ m_BitCast(const OpTy &Op) {
return CastOperator_match<OpTy, Instruction::BitCast>(Op);
}
+/// Matcher for a bitcast that is element-wise: it does not change between
+/// scalar and vector types and does not change the number of vector elements
+/// (e.g. `bitcast <4 x float> %v to <4 x i32>` or `bitcast float %f to i32`).
+/// Only bitcast instructions are matched; bitcast constant expressions are not.
+template <typename Op_t> struct ElementWiseBitCast_match {
+  Op_t Op;
+
+  ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    auto *I = dyn_cast<BitCastInst>(V);
+    if (!I)
+      return false;
+    Type *SrcType = I->getSrcTy();
+    Type *DstType = I->getType();
+    // Reject bitcasts that change between scalar and vector.
+    if (SrcType->isVectorTy() != DstType->isVectorTy())
+      return false;
+    // Reject vector bitcasts that change the number of elements. Since a
+    // bitcast preserves the total size, for int/FP vectors this is the same as
+    // requiring equal element sizes.
+    if (auto *SrcVecTy = dyn_cast<VectorType>(SrcType);
+        SrcVecTy && SrcVecTy->getElementCount() !=
+                        cast<VectorType>(DstType)->getElementCount())
+      return false;
+    return Op.match(I->getOperand(0));
+  }
+};
+
+/// Matches an element-wise bitcast (scalar to scalar, or vector to vector with
+/// the same element count) and applies \p Op to the bitcast's source operand.
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+  return ElementWiseBitCast_match<OpTy>(Op);
+}
+
/// Matches PtrToInt.
template <typename OpTy>
inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f3..01b017142cfcb 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
// floating-point casts:
// icmp slt (bitcast (uitofp X)), 0 --> false
// icmp sgt (bitcast (uitofp X)), -1 --> true
- if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+ if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
return ConstantInt::getFalse(ITy);
if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6ca4d6d673068..0409f33445f0a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2493,14 +2493,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
// Assumes any IEEE-represented type has the sign bit in the high bit.
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
match(Op1, m_MaxSignedValue()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
- Attribute::NoImplicitFloat)) {
+ Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
return new BitCastInst(FAbs, I.getType());
}
@@ -3925,13 +3923,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// This is generous interpretation of noimplicitfloat, this is not a true
// floating-point operation.
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+ match(Op1, m_SignMask()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
Value *FNegFAbs = Builder.CreateFNeg(FAbs);
return new BitCastInst(FNegFAbs, I.getType());
@@ -4701,13 +4698,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
// Assumes any IEEE-represented type has the sign bit in the high bit.
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+ match(Op1, m_SignMask()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FNeg = Builder.CreateFNeg(CastOp);
return new BitCastInst(FNeg, I.getType());
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d..cda1061fb35a8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
Value *V;
if (!Cmp.getParent()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat) &&
- Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
- Type *SrcType = V->getType();
- Type *DstType = X->getType();
- Type *FPType = SrcType->getScalarType();
- // Make sure the bitcast doesn't change between scalar and vector and
- // doesn't change the number of vector elements.
- if (SrcType->isVectorTy() == DstType->isVectorTy() &&
- SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
- FPType->isIEEELikeFPTy() && C1 == *C2) {
+ Cmp.isEquality() &&
+ match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+ Type *FPType = V->getType()->getScalarType();
+ if (FPType->isIEEELikeFPTy() && C1 == *C2) {
APInt ExponentMask =
APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
// Ignore signbit of bitcasted int when comparing equality to FP 0.0:
// fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
if (match(Op1, m_PosZeroFP()) &&
- match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
- X->getType()->isVectorTy() == OpType->isVectorTy() &&
- X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+ match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
if (Pred == FCmpInst::FCMP_OEQ)
IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e62..d6f447125269c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2382,7 +2382,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
const APInt *C;
bool IsTrueIfSignSet;
ICmpInst::Predicate Pred;
- if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+ if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+ m_APInt(C)))) ||
!InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
X->getType() != SelType)
return nullptr;
@@ -2783,7 +2784,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
CmpInst::Predicate Pred;
const APInt *C;
bool TrueIfSigned;
- if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+ if (!match(CondVal,
+ m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
!IC.isSignBitCheck(Pred, *C, TrueIfSigned))
continue;
if (!match(TrueVal, m_FNeg(m_Specific(X))))
More information about the llvm-commits
mailing list