[llvm] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`. NFC. (PR #80764)
Yingwei Zheng via llvm-commits
llvm-commits at lists.llvm.org
Tue Feb 6 23:52:22 PST 2024
https://github.com/dtcxzyw updated https://github.com/llvm/llvm-project/pull/80764
>From 62894a2287d4e73fc73f54976665f0346a6df260 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Tue, 6 Feb 2024 07:23:46 +0800
Subject: [PATCH] [PatternMatch] Add a matching helper `m_ElementWiseBitCast`.
NFC.
---
llvm/include/llvm/IR/PatternMatch.h | 28 +++++++++++++
llvm/lib/Analysis/InstructionSimplify.cpp | 2 +-
.../InstCombine/InstCombineAndOrXor.cpp | 22 ++++------
.../InstCombine/InstCombineCasts.cpp | 12 ++++--
.../InstCombine/InstCombineCompares.cpp | 17 +++-----
.../InstCombine/InstCombineSelect.cpp | 11 ++---
.../InstCombine/InstructionCombining.cpp | 15 -------
.../InstSimplify/cast-unsigned-icmp-cmp-0.ll | 13 ++++++
llvm/unittests/IR/PatternMatch.cpp | 41 +++++++++++++++++++
9 files changed, 110 insertions(+), 51 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 878079c4fe4e8e..3155e7dc38b64a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -1711,6 +1711,34 @@ m_BitCast(const OpTy &Op) {
return CastOperator_match<OpTy, Instruction::BitCast>(Op);
}
+template <typename Op_t> struct ElementWiseBitCast_match {
+ Op_t Op;
+
+ ElementWiseBitCast_match(const Op_t &OpMatch) : Op(OpMatch) {}
+
+ template <typename OpTy> bool match(OpTy *V) {
+ BitCastInst *I = dyn_cast<BitCastInst>(V);
+ if (!I)
+ return false;
+ Type *SrcType = I->getSrcTy();
+ Type *DstType = I->getType();
+ // Make sure the bitcast doesn't change between scalar and vector and
+ // doesn't change the number of vector elements.
+ if (SrcType->isVectorTy() != DstType->isVectorTy())
+ return false;
+ if (VectorType *SrcVecTy = dyn_cast<VectorType>(SrcType);
+ SrcVecTy && SrcVecTy->getElementCount() !=
+ cast<VectorType>(DstType)->getElementCount())
+ return false;
+ return Op.match(I->getOperand(0));
+ }
+};
+
+template <typename OpTy>
+inline ElementWiseBitCast_match<OpTy> m_ElementWiseBitCast(const OpTy &Op) {
+ return ElementWiseBitCast_match<OpTy>(Op);
+}
+
/// Matches PtrToInt.
template <typename OpTy>
inline CastOperator_match<OpTy, Instruction::PtrToInt>
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 2793b798f35f36..01b017142cfcb0 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3034,7 +3034,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
// floating-point casts:
// icmp slt (bitcast (uitofp X)), 0 --> false
// icmp sgt (bitcast (uitofp X)), -1 --> true
- if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) {
+ if (match(LHS, m_ElementWiseBitCast(m_UIToFP(m_Value(X))))) {
if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero()))
return ConstantInt::getFalse(ITy);
if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes()))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 6a827e2f3a9637..7b93848eab3517 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -2531,14 +2531,12 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
// Assumes any IEEE-represented type has the sign bit in the high bit.
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
match(Op1, m_MaxSignedValue()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
- Attribute::NoImplicitFloat)) {
+ Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
return new BitCastInst(FAbs, I.getType());
}
@@ -3963,13 +3961,12 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
// This is generous interpretation of noimplicitfloat, this is not a true
// floating-point operation.
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+ match(Op1, m_SignMask()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
Value *FNegFAbs = Builder.CreateFNeg(FAbs);
return new BitCastInst(FNegFAbs, I.getType());
@@ -4739,13 +4736,12 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
// Assumes any IEEE-represented type has the sign bit in the high bit.
// TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
Value *CastOp;
- if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+ if (match(Op0, m_ElementWiseBitCast(m_Value(CastOp))) &&
+ match(Op1, m_SignMask()) &&
!Builder.GetInsertBlock()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat)) {
Type *EltTy = CastOp->getType()->getScalarType();
- if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
- EltTy->getPrimitiveSizeInBits() ==
- I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+ if (EltTy->isFloatingPointTy() && EltTy->isIEEE()) {
Value *FNeg = Builder.CreateFNeg(CastOp);
return new BitCastInst(FNeg, I.getType());
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 58f0763bb0c0cd..ed47de287302ed 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -182,9 +182,15 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() ||
(CI.getOpcode() == Instruction::Trunc &&
shouldChangeType(CI.getSrcTy(), CI.getType()))) {
- if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
- replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
- return NV;
+
+ // If it's a bitcast involving vectors, make sure it has the same number
+ // of elements on both sides.
+ if (CI.getOpcode() != Instruction::BitCast ||
+ match(&CI, m_ElementWiseBitCast(m_Value()))) {
+ if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
+ replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
+ return NV;
+ }
}
}
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 380cb3504209d3..cda1061fb35a8d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1834,15 +1834,10 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
Value *V;
if (!Cmp.getParent()->getParent()->hasFnAttribute(
Attribute::NoImplicitFloat) &&
- Cmp.isEquality() && match(X, m_OneUse(m_BitCast(m_Value(V))))) {
- Type *SrcType = V->getType();
- Type *DstType = X->getType();
- Type *FPType = SrcType->getScalarType();
- // Make sure the bitcast doesn't change between scalar and vector and
- // doesn't change the number of vector elements.
- if (SrcType->isVectorTy() == DstType->isVectorTy() &&
- SrcType->getScalarSizeInBits() == DstType->getScalarSizeInBits() &&
- FPType->isIEEELikeFPTy() && C1 == *C2) {
+ Cmp.isEquality() &&
+ match(X, m_OneUse(m_ElementWiseBitCast(m_Value(V))))) {
+ Type *FPType = V->getType()->getScalarType();
+ if (FPType->isIEEELikeFPTy() && C1 == *C2) {
APInt ExponentMask =
APFloat::getInf(FPType->getFltSemantics()).bitcastToAPInt();
if (C1 == ExponentMask) {
@@ -7754,9 +7749,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
// Ignore signbit of bitcasted int when comparing equality to FP 0.0:
// fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
if (match(Op1, m_PosZeroFP()) &&
- match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
- X->getType()->isVectorTy() == OpType->isVectorTy() &&
- X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+ match(Op0, m_OneUse(m_ElementWiseBitCast(m_Value(X))))) {
ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
if (Pred == FCmpInst::FCMP_OEQ)
IntPred = ICmpInst::ICMP_EQ;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 2756f81ed9e620..527037881edb19 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2365,9 +2365,6 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
Value *FVal = Sel.getFalseValue();
Type *SelType = Sel.getType();
- if (ICmpInst::makeCmpResultType(TVal->getType()) != Cond->getType())
- return nullptr;
-
// Match select ?, TC, FC where the constants are equal but negated.
// TODO: Generalize to handle a negated variable operand?
const APFloat *TC, *FC;
@@ -2382,7 +2379,8 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
const APInt *C;
bool IsTrueIfSignSet;
ICmpInst::Predicate Pred;
- if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) ||
+ if (!match(Cond, m_OneUse(m_ICmp(Pred, m_ElementWiseBitCast(m_Value(X)),
+ m_APInt(C)))) ||
!InstCombiner::isSignBitCheck(Pred, *C, IsTrueIfSignSet) ||
X->getType() != SelType)
return nullptr;
@@ -2770,8 +2768,6 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
// Match select with (icmp slt (bitcast X to int), 0)
// or (icmp sgt (bitcast X to int), -1)
- if (ICmpInst::makeCmpResultType(SI.getType()) != CondVal->getType())
- return ChangedFMF ? &SI : nullptr;
for (bool Swap : {false, true}) {
Value *TrueVal = SI.getTrueValue();
@@ -2783,7 +2779,8 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
CmpInst::Predicate Pred;
const APInt *C;
bool TrueIfSigned;
- if (!match(CondVal, m_ICmp(Pred, m_BitCast(m_Specific(X)), m_APInt(C))) ||
+ if (!match(CondVal,
+ m_ICmp(Pred, m_ElementWiseBitCast(m_Specific(X)), m_APInt(C))) ||
!IC.isSignBitCheck(Pred, *C, TrueIfSigned))
continue;
if (!match(TrueVal, m_FNeg(m_Specific(X))))
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 4e88a5cc535b11..651e852bf6ed02 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1474,21 +1474,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
if (SI->getType()->isIntOrIntVectorTy(1))
return nullptr;
- // If it's a bitcast involving vectors, make sure it has the same number of
- // elements on both sides.
- if (auto *BC = dyn_cast<BitCastInst>(&Op)) {
- VectorType *DestTy = dyn_cast<VectorType>(BC->getDestTy());
- VectorType *SrcTy = dyn_cast<VectorType>(BC->getSrcTy());
-
- // Verify that either both or neither are vectors.
- if ((SrcTy == nullptr) != (DestTy == nullptr))
- return nullptr;
-
- // If vectors, verify that they have the same number of elements.
- if (SrcTy && SrcTy->getElementCount() != DestTy->getElementCount())
- return nullptr;
- }
-
// Test if a FCmpInst instruction is used exclusively by a select as
// part of a minimum or maximum operation. If so, refrain from doing
// any other folding. This helps out other analyses which understand
diff --git a/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll b/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll
index 8014133c5d3739..5a61a060785ff4 100644
--- a/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll
+++ b/llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll
@@ -57,6 +57,19 @@ define <2 x i1> @i32_cast_cmp_sgt_int_m1_uitofp_float_vec(<2 x i32> %i) {
ret <2 x i1> %cmp
}
+define i1 @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_mismatch(<2 x i32> %i) {
+; CHECK-LABEL: @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_mismatch(
+; CHECK-NEXT: [[F:%.*]] = uitofp <2 x i32> [[I:%.*]] to <2 x float>
+; CHECK-NEXT: [[B:%.*]] = bitcast <2 x float> [[F]] to i64
+; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i64 [[B]], -1
+; CHECK-NEXT: ret i1 [[CMP]]
+;
+ %f = uitofp <2 x i32> %i to <2 x float>
+ %b = bitcast <2 x float> %f to i64
+ %cmp = icmp sgt i64 %b, -1
+ ret i1 %cmp
+}
+
define <3 x i1> @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_undef(<3 x i32> %i) {
; CHECK-LABEL: @i32_cast_cmp_sgt_int_m1_uitofp_float_vec_undef(
; CHECK-NEXT: ret <3 x i1> <i1 true, i1 true, i1 true>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 9e9e41b8fbad0d..2aa948eb3a0fe7 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -530,6 +530,47 @@ TEST_F(PatternMatchTest, ZExtSExtSelf) {
EXPECT_TRUE(m_ZExtOrSExtOrSelf(m_One()).match(One64S));
}
+TEST_F(PatternMatchTest, BitCast) {
+ Value *OneDouble = ConstantFP::get(IRB.getDoubleTy(), APFloat(1.0));
+ Value *ScalableDouble = ConstantFP::get(
+ VectorType::get(IRB.getDoubleTy(), 2, /*Scalable=*/true), APFloat(1.0));
+ // scalar -> scalar
+ Value *DoubleToI64 = IRB.CreateBitCast(OneDouble, IRB.getInt64Ty());
+ // scalar -> vector
+ Value *DoubleToV2I32 = IRB.CreateBitCast(
+ OneDouble, VectorType::get(IRB.getInt32Ty(), 2, /*Scalable=*/false));
+ // vector -> scalar
+ Value *V2I32ToDouble = IRB.CreateBitCast(DoubleToV2I32, IRB.getDoubleTy());
+ // vector -> vector (same count)
+ Value *V2I32ToV2Float = IRB.CreateBitCast(
+ DoubleToV2I32, VectorType::get(IRB.getFloatTy(), 2, /*Scalable=*/false));
+ // vector -> vector (different count)
+ Value *V2I32TOV4I16 = IRB.CreateBitCast(
+ DoubleToV2I32, VectorType::get(IRB.getInt16Ty(), 4, /*Scalable=*/false));
+ // scalable vector -> scalable vector (same count)
+ Value *NXV2DoubleToNXV2I64 = IRB.CreateBitCast(
+ ScalableDouble, VectorType::get(IRB.getInt64Ty(), 2, /*Scalable=*/true));
+ // scalable vector -> scalable vector (different count)
+ Value *NXV2I64ToNXV4I32 = IRB.CreateBitCast(
+ NXV2DoubleToNXV2I64, VectorType::get(IRB.getInt32Ty(), 4, /*Scalable=*/true));
+
+ EXPECT_TRUE(m_BitCast(m_Value()).match(DoubleToI64));
+ EXPECT_TRUE(m_BitCast(m_Value()).match(DoubleToV2I32));
+ EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToDouble));
+ EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32ToV2Float));
+ EXPECT_TRUE(m_BitCast(m_Value()).match(V2I32TOV4I16));
+ EXPECT_TRUE(m_BitCast(m_Value()).match(NXV2DoubleToNXV2I64));
+ EXPECT_TRUE(m_BitCast(m_Value()).match(NXV2I64ToNXV4I32));
+
+ EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(DoubleToI64));
+ EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(DoubleToV2I32));
+ EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32ToDouble));
+ EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(V2I32ToV2Float));
+ EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(V2I32TOV4I16));
+ EXPECT_TRUE(m_ElementWiseBitCast(m_Value()).match(NXV2DoubleToNXV2I64));
+ EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
+}
+
TEST_F(PatternMatchTest, Power2) {
Value *C128 = IRB.getInt32(128);
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
More information about the llvm-commits mailing list