[llvm] [PatternMatching] Add generic API for matching constants using custom conditions (PR #85676)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 29 18:52:47 PDT 2024
https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/85676
>From d52b966acb772fbd02dafe19735d64a1e1b86c00 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:14 -0500
Subject: [PATCH 1/2] [PatternMatching] Add generic API for matching constants
using custom conditions
The new API is:
`m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
- Matches non-undef constants such that `Lambda(ele)` is true for all
elements.
`m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
- Matches constants/undef such that `Lambda(ele)` is true for all
elements.
The goal with these is to be able to replace the common usage of:
```
match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
match(X, m_CheckedInt(C, CustomCheck));
```
The rationale is that we often ignore non-splat vectors because there
are no good APIs to handle them with, and it's not worth increasing code
complexity for such cases.
The hope is that the API provides a common method for handling
scalars/splat-vecs/non-splat-vecs, essentially making this a
non-issue.
---
llvm/include/llvm/IR/PatternMatch.h | 91 +++++++-
llvm/unittests/IR/PatternMatch.cpp | 318 ++++++++++++++++++++++++++++
2 files changed, 398 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 46372c78263a1d..9bc4cc74edf361 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
- if (isa<UndefValue>(Elt))
+ if (isa<UndefValue>(Elt)) {
+ if (!AllowUndefs)
+ return false;
continue;
+ }
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
};
/// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
/// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
const APInt *&Res;
api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
const APFloat *&Res;
apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantFP>(
- C->getSplatValue(/* AllowUndef */ true)))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
//
///////////////////////////////////////////////////////////////////////////////
+template <typename APTy> struct custom_checkfn {
+ function_ref<bool(const APTy &)> CheckFn;
+ bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+ function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+ function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
struct is_any_apint {
bool isValue(const APInt &C) { return true; }
};
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index a0b873de2d5860..fdb4684f82548f 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -611,6 +611,247 @@ TEST_F(PatternMatchTest, BitCast) {
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
}
+TEST_F(PatternMatchTest, CustomCheckFn) {
+ APInt I0(64, 0);
+ APInt I1(64, 0);
+
+ auto CheckIsZeroI = [](const APInt &C) { return C.isZero(); };
+ auto CheckIsEqI1 = [&I1](const APInt &C) { return C.eq(I1); };
+ auto CheckIsNeI1 = [&I1](const APInt &C) { return !C.eq(I1); };
+
+ custom_checkfn<APInt> CustomCheckZeroI;
+ CustomCheckZeroI.CheckFn = CheckIsZeroI;
+ custom_checkfn<APInt> CustomCheckEqI1;
+ CustomCheckEqI1.CheckFn = CheckIsEqI1;
+ custom_checkfn<APInt> CustomCheckNeI1;
+ CustomCheckNeI1.CheckFn = CheckIsNeI1;
+
+ EXPECT_TRUE(CustomCheckZeroI.isValue(I0));
+ EXPECT_TRUE(CustomCheckEqI1.isValue(I0));
+ EXPECT_FALSE(CustomCheckNeI1.isValue(I0));
+
+ I0.setBit(0);
+
+ EXPECT_FALSE(CustomCheckZeroI.isValue(I0));
+ EXPECT_FALSE(CustomCheckEqI1.isValue(I0));
+ EXPECT_TRUE(CustomCheckNeI1.isValue(I0));
+
+ I1.setBit(0);
+
+ EXPECT_FALSE(CustomCheckZeroI.isValue(I0));
+ EXPECT_TRUE(CustomCheckEqI1.isValue(I0));
+ EXPECT_FALSE(CustomCheckNeI1.isValue(I0));
+
+ APFloat F0(0.0);
+ APFloat F1(0.0);
+
+ auto CheckIsZeroF = [](const APFloat &C) { return C.isZero(); };
+ auto CheckIsEqF1 = [&F1](const APFloat &C) {
+ return C.bitcastToAPInt().eq(F1.bitcastToAPInt());
+ };
+ auto CheckIsNeF1 = [&F1](const APFloat &C) {
+ return !C.bitcastToAPInt().eq(F1.bitcastToAPInt());
+ };
+
+ custom_checkfn<APFloat> CustomCheckZeroF;
+ CustomCheckZeroF.CheckFn = CheckIsZeroF;
+ custom_checkfn<APFloat> CustomCheckEqF1;
+ CustomCheckEqF1.CheckFn = CheckIsEqF1;
+ custom_checkfn<APFloat> CustomCheckNeF1;
+ CustomCheckNeF1.CheckFn = CheckIsNeF1;
+
+ EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
+ EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
+ EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
+
+ F0 = -F0;
+
+ EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
+ EXPECT_FALSE(CustomCheckEqF1.isValue(F0));
+ EXPECT_TRUE(CustomCheckNeF1.isValue(F0));
+
+ F0 = -F0;
+
+ EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
+ EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
+ EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
+
+ F0 = F0 + APFloat(1.0);
+
+ EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
+ EXPECT_FALSE(CustomCheckEqF1.isValue(F0));
+ EXPECT_TRUE(CustomCheckNeF1.isValue(F0));
+
+ F1 = F1 + APFloat(1.0);
+
+ EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
+ EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
+ EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
+}
+
+TEST_F(PatternMatchTest, CheckedInt) {
+ Type *I8Ty = IRB.getInt8Ty();
+ const APInt *Res = nullptr;
+
+ auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+ auto CheckTrue = [](const APInt &) { return true; };
+ auto CheckFalse = [](const APInt &) { return false; };
+ auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+ auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+ auto DoScalarCheck = [&](int8_t Val) {
+ APInt APVal(8, Val);
+ Constant *C = ConstantInt::get(I8Ty, Val);
+
+ Res = nullptr;
+ EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+ EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+ EXPECT_EQ(*Res, APVal);
+
+ Res = nullptr;
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal),
+ m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal),
+ m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+ };
+
+ DoScalarCheck(0);
+ DoScalarCheck(1);
+ DoScalarCheck(2);
+ DoScalarCheck(3);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+ function_ref<bool(const APInt &)> CheckFn,
+ bool UndefAsPoison) {
+ SmallVector<Constant *> VecElems;
+ std::optional<bool> Okay;
+ bool AllSame = true;
+ bool HasUndef = false;
+ std::optional<APInt> First;
+ for (const std::optional<int8_t> &Val : Vals) {
+ if (!Val.has_value()) {
+ VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+ : UndefValue::get(I8Ty));
+ HasUndef = true;
+ } else {
+ if (!Okay.has_value())
+ Okay = true;
+ APInt APVal(8, *Val);
+ if (!First.has_value())
+ First = APVal;
+ else
+ AllSame &= First->eq(APVal);
+ Okay = *Okay && CheckFn(APVal);
+ VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+ }
+ }
+
+ Constant *C = ConstantVector::get(VecElems);
+ EXPECT_EQ(!HasUndef && Okay.value_or(false),
+ m_CheckedInt(CheckFn).match(C));
+ EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+ Res = nullptr;
+ bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+
+ Res = nullptr;
+ Expec = AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedIntAllowUndef(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+ };
+ auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
+ DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
+ DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
+ DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
+ };
+
+ DoVecCheck({0, 1});
+ DoVecCheck({1, 1});
+ DoVecCheck({1, 2});
+ DoVecCheck({1, std::nullopt});
+ DoVecCheck({1, std::nullopt, 1});
+ DoVecCheck({1, std::nullopt, 2});
+ DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
+}
+
TEST_F(PatternMatchTest, Power2) {
Value *C128 = IRB.getInt32(128);
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1315,6 +1556,63 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorInfUndef, m_Finite()));
EXPECT_FALSE(match(VectorNaNUndef, m_Finite()));
+ auto CheckTrue = [](const APFloat &) { return true; };
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
+
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckTrue)));
+
+ auto CheckFalse = [](const APFloat &) { return false; };
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
+
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckFalse)));
+
+ auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
+
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+
const APFloat *C;
// Regardless of whether undefs are allowed,
// a fully undef constant does not match.
@@ -1324,6 +1622,7 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidUndef(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowUndef(C)));
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
// We can always match simple constants and simple splats.
C = nullptr;
@@ -1344,14 +1643,33 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_APFloatAllowUndef(C)));
EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckTrue)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+ EXPECT_TRUE(C->isZero());
// Whether splats with undef can be matched depends on the matcher.
EXPECT_FALSE(match(VectorZeroUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorZeroUndef, m_APFloatForbidUndef(C)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(C, CheckTrue)));
C = nullptr;
EXPECT_TRUE(match(VectorZeroUndef, m_APFloatAllowUndef(C)));
EXPECT_TRUE(C->isZero());
C = nullptr;
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckTrue)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
EXPECT_TRUE(match(VectorZeroUndef, m_Finite(C)));
EXPECT_TRUE(C->isZero());
}
>From 0f5619afede72c42b81424488fa7a6dd23f49f4c Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:18 -0500
Subject: [PATCH 2/2] [InstCombine] Add example usage for new `Checked` matcher
API
There is no real motivation for this change other than to highlight a
case where the new `Checked` matcher API can handle non-splat-vecs
without increasing code complexity.
---
llvm/include/llvm/IR/PatternMatch.h | 207 ++++++++++--------
.../InstCombine/InstCombineCompares.cpp | 66 +++---
.../InstCombine/signed-truncation-check.ll | 30 +--
llvm/unittests/IR/PatternMatch.cpp | 169 ++++++++------
4 files changed, 260 insertions(+), 212 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 9bc4cc74edf361..ca0bbf4d1279a0 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -387,16 +387,22 @@ struct cstval_pred_ty : public Predicate {
};
/// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
+template <typename Predicate>
+using cst_or_undef_pred_ty = cst_pred_ty<Predicate, true>;
+
/// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
+template <typename Predicate>
+using cstfp_or_undef_pred_ty = cstfp_pred_ty<Predicate, true>;
+
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
struct api_pred_ty : public Predicate {
const APInt *&Res;
@@ -421,10 +427,13 @@ struct api_pred_ty : public Predicate {
}
};
+template <typename Predicate>
+using api_or_undef_pred_ty = api_pred_ty<Predicate, true>;
+
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
struct apf_pred_ty : public Predicate {
const APFloat *&Res;
@@ -449,6 +458,9 @@ struct apf_pred_ty : public Predicate {
}
};
+template <typename Predicate>
+using apf_or_undef_pred_ty = apf_pred_ty<Predicate, true>;
+
///////////////////////////////////////////////////////////////////////////////
//
// Encapsulate constant value queries for use in templated predicate matchers.
@@ -479,15 +491,15 @@ m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
// Match and integer or vector where CheckFn(ele) for each element is true.
// For vectors, undefined elements are assumed to match.
-inline cst_pred_ty<custom_checkfn<APInt>>
+inline cst_or_undef_pred_ty<custom_checkfn<APInt>>
m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
- return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+ return cst_or_undef_pred_ty<custom_checkfn<APInt>>{CheckFn};
}
-inline api_pred_ty<custom_checkfn<APInt>>
+inline api_or_undef_pred_ty<custom_checkfn<APInt>>
m_CheckedIntAllowUndef(const APInt *&V,
function_ref<bool(const APInt &)> CheckFn) {
- api_pred_ty<custom_checkfn<APInt>> P(V);
+ api_or_undef_pred_ty<custom_checkfn<APInt>> P(V);
P.CheckFn = CheckFn;
return P;
}
@@ -508,15 +520,15 @@ m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
// Match and float or vector where CheckFn(ele) for each element is true.
// For vectors, undefined elements are assumed to match.
-inline cstfp_pred_ty<custom_checkfn<APFloat>>
+inline cstfp_or_undef_pred_ty<custom_checkfn<APFloat>>
m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
- return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+ return cstfp_or_undef_pred_ty<custom_checkfn<APFloat>>{CheckFn};
}
-inline apf_pred_ty<custom_checkfn<APFloat>>
+inline apf_or_undef_pred_ty<custom_checkfn<APFloat>>
m_CheckedFpAllowUndef(const APFloat *&V,
function_ref<bool(const APFloat &)> CheckFn) {
- apf_pred_ty<custom_checkfn<APFloat>> P(V);
+ apf_or_undef_pred_ty<custom_checkfn<APFloat>> P(V);
P.CheckFn = CheckFn;
return P;
}
@@ -526,16 +538,16 @@ struct is_any_apint {
};
/// Match an integer or vector with any integral constant.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_any_apint> m_AnyIntegralConstant() {
- return cst_pred_ty<is_any_apint>();
+inline cst_or_undef_pred_ty<is_any_apint> m_AnyIntegralConstant() {
+ return cst_or_undef_pred_ty<is_any_apint>();
}
struct is_shifted_mask {
bool isValue(const APInt &C) { return C.isShiftedMask(); }
};
-inline cst_pred_ty<is_shifted_mask> m_ShiftedMask() {
- return cst_pred_ty<is_shifted_mask>();
+inline cst_or_undef_pred_ty<is_shifted_mask> m_ShiftedMask() {
+ return cst_or_undef_pred_ty<is_shifted_mask>();
}
struct is_all_ones {
@@ -543,8 +555,8 @@ struct is_all_ones {
};
/// Match an integer or vector with all bits set.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_all_ones> m_AllOnes() {
- return cst_pred_ty<is_all_ones>();
+inline cst_or_undef_pred_ty<is_all_ones> m_AllOnes() {
+ return cst_or_undef_pred_ty<is_all_ones>();
}
struct is_maxsignedvalue {
@@ -553,10 +565,11 @@ struct is_maxsignedvalue {
/// Match an integer or vector with values having all bits except for the high
/// bit set (0x7f...).
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_maxsignedvalue> m_MaxSignedValue() {
- return cst_pred_ty<is_maxsignedvalue>();
+inline cst_or_undef_pred_ty<is_maxsignedvalue> m_MaxSignedValue() {
+ return cst_or_undef_pred_ty<is_maxsignedvalue>();
}
-inline api_pred_ty<is_maxsignedvalue> m_MaxSignedValue(const APInt *&V) {
+inline api_or_undef_pred_ty<is_maxsignedvalue>
+m_MaxSignedValue(const APInt *&V) {
return V;
}
@@ -565,30 +578,35 @@ struct is_negative {
};
/// Match an integer or vector of negative values.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_negative> m_Negative() {
- return cst_pred_ty<is_negative>();
+inline cst_or_undef_pred_ty<is_negative> m_Negative() {
+ return cst_or_undef_pred_ty<is_negative>();
+}
+inline api_or_undef_pred_ty<is_negative> m_Negative(const APInt *&V) {
+ return V;
}
-inline api_pred_ty<is_negative> m_Negative(const APInt *&V) { return V; }
struct is_nonnegative {
bool isValue(const APInt &C) { return C.isNonNegative(); }
};
/// Match an integer or vector of non-negative values.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_nonnegative> m_NonNegative() {
- return cst_pred_ty<is_nonnegative>();
+inline cst_or_undef_pred_ty<is_nonnegative> m_NonNegative() {
+ return cst_or_undef_pred_ty<is_nonnegative>();
+}
+inline api_or_undef_pred_ty<is_nonnegative> m_NonNegative(const APInt *&V) {
+ return V;
}
-inline api_pred_ty<is_nonnegative> m_NonNegative(const APInt *&V) { return V; }
struct is_strictlypositive {
bool isValue(const APInt &C) { return C.isStrictlyPositive(); }
};
/// Match an integer or vector of strictly positive values.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_strictlypositive> m_StrictlyPositive() {
- return cst_pred_ty<is_strictlypositive>();
+inline cst_or_undef_pred_ty<is_strictlypositive> m_StrictlyPositive() {
+ return cst_or_undef_pred_ty<is_strictlypositive>();
}
-inline api_pred_ty<is_strictlypositive> m_StrictlyPositive(const APInt *&V) {
+inline api_or_undef_pred_ty<is_strictlypositive>
+m_StrictlyPositive(const APInt *&V) {
return V;
}
@@ -597,32 +615,36 @@ struct is_nonpositive {
};
/// Match an integer or vector of non-positive values.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_nonpositive> m_NonPositive() {
- return cst_pred_ty<is_nonpositive>();
+inline cst_or_undef_pred_ty<is_nonpositive> m_NonPositive() {
+ return cst_or_undef_pred_ty<is_nonpositive>();
+}
+inline api_or_undef_pred_ty<is_nonpositive> m_NonPositive(const APInt *&V) {
+ return V;
}
-inline api_pred_ty<is_nonpositive> m_NonPositive(const APInt *&V) { return V; }
struct is_one {
bool isValue(const APInt &C) { return C.isOne(); }
};
/// Match an integer 1 or a vector with all elements equal to 1.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_one> m_One() { return cst_pred_ty<is_one>(); }
+inline cst_or_undef_pred_ty<is_one> m_One() {
+ return cst_or_undef_pred_ty<is_one>();
+}
struct is_zero_int {
bool isValue(const APInt &C) { return C.isZero(); }
};
/// Match an integer 0 or a vector with all elements equal to 0.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_zero_int> m_ZeroInt() {
- return cst_pred_ty<is_zero_int>();
+inline cst_or_undef_pred_ty<is_zero_int> m_ZeroInt() {
+ return cst_or_undef_pred_ty<is_zero_int>();
}
struct is_zero {
template <typename ITy> bool match(ITy *V) {
auto *C = dyn_cast<Constant>(V);
// FIXME: this should be able to do something for scalable vectors
- return C && (C->isNullValue() || cst_pred_ty<is_zero_int>().match(C));
+ return C && (C->isNullValue() || m_ZeroInt().match(C));
}
};
/// Match any null constant or a vector with all elements equal to 0.
@@ -634,18 +656,21 @@ struct is_power2 {
};
/// Match an integer or vector power-of-2.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_power2> m_Power2() { return cst_pred_ty<is_power2>(); }
-inline api_pred_ty<is_power2> m_Power2(const APInt *&V) { return V; }
+inline cst_or_undef_pred_ty<is_power2> m_Power2() {
+ return cst_or_undef_pred_ty<is_power2>();
+}
+inline api_or_undef_pred_ty<is_power2> m_Power2(const APInt *&V) { return V; }
struct is_negated_power2 {
bool isValue(const APInt &C) { return C.isNegatedPowerOf2(); }
};
/// Match a integer or vector negated power-of-2.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_negated_power2> m_NegatedPower2() {
- return cst_pred_ty<is_negated_power2>();
+inline cst_or_undef_pred_ty<is_negated_power2> m_NegatedPower2() {
+ return cst_or_undef_pred_ty<is_negated_power2>();
}
-inline api_pred_ty<is_negated_power2> m_NegatedPower2(const APInt *&V) {
+inline api_or_undef_pred_ty<is_negated_power2>
+m_NegatedPower2(const APInt *&V) {
return V;
}
@@ -654,10 +679,10 @@ struct is_negated_power2_or_zero {
};
/// Match a integer or vector negated power-of-2.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
- return cst_pred_ty<is_negated_power2_or_zero>();
+inline cst_or_undef_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
+ return cst_or_undef_pred_ty<is_negated_power2_or_zero>();
}
-inline api_pred_ty<is_negated_power2_or_zero>
+inline api_or_undef_pred_ty<is_negated_power2_or_zero>
m_NegatedPower2OrZero(const APInt *&V) {
return V;
}
@@ -667,10 +692,10 @@ struct is_power2_or_zero {
};
/// Match an integer or vector of 0 or power-of-2 values.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_power2_or_zero> m_Power2OrZero() {
- return cst_pred_ty<is_power2_or_zero>();
+inline cst_or_undef_pred_ty<is_power2_or_zero> m_Power2OrZero() {
+ return cst_or_undef_pred_ty<is_power2_or_zero>();
}
-inline api_pred_ty<is_power2_or_zero> m_Power2OrZero(const APInt *&V) {
+inline api_or_undef_pred_ty<is_power2_or_zero> m_Power2OrZero(const APInt *&V) {
return V;
}
@@ -679,8 +704,8 @@ struct is_sign_mask {
};
/// Match an integer or vector with only the sign bit(s) set.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_sign_mask> m_SignMask() {
- return cst_pred_ty<is_sign_mask>();
+inline cst_or_undef_pred_ty<is_sign_mask> m_SignMask() {
+ return cst_or_undef_pred_ty<is_sign_mask>();
}
struct is_lowbit_mask {
@@ -688,20 +713,23 @@ struct is_lowbit_mask {
};
/// Match an integer or vector with only the low bit(s) set.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
- return cst_pred_ty<is_lowbit_mask>();
+inline cst_or_undef_pred_ty<is_lowbit_mask> m_LowBitMask() {
+ return cst_or_undef_pred_ty<is_lowbit_mask>();
+}
+inline api_or_undef_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) {
+ return V;
}
-inline api_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) { return V; }
struct is_lowbit_mask_or_zero {
bool isValue(const APInt &C) { return !C || C.isMask(); }
};
/// Match an integer or vector with only the low bit(s) set, or zero.
/// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
- return cst_pred_ty<is_lowbit_mask_or_zero>();
+inline cst_or_undef_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
+ return cst_or_undef_pred_ty<is_lowbit_mask_or_zero>();
}
-inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
+inline api_or_undef_pred_ty<is_lowbit_mask_or_zero>
+m_LowBitMaskOrZero(const APInt *&V) {
return V;
}
@@ -712,9 +740,9 @@ struct icmp_pred_with_threshold {
};
/// Match an integer or vector with every element comparing 'pred' (eq/ne/...)
/// to Threshold. For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<icmp_pred_with_threshold>
+inline cst_or_undef_pred_ty<icmp_pred_with_threshold>
m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) {
- cst_pred_ty<icmp_pred_with_threshold> P;
+ cst_or_undef_pred_ty<icmp_pred_with_threshold> P;
P.Pred = Predicate;
P.Thr = &Threshold;
return P;
@@ -725,15 +753,17 @@ struct is_nan {
};
/// Match an arbitrary NaN constant. This includes quiet and signalling nans.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_nan> m_NaN() { return cstfp_pred_ty<is_nan>(); }
+inline cstfp_or_undef_pred_ty<is_nan> m_NaN() {
+ return cstfp_or_undef_pred_ty<is_nan>();
+}
struct is_nonnan {
bool isValue(const APFloat &C) { return !C.isNaN(); }
};
/// Match a non-NaN FP constant.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_nonnan> m_NonNaN() {
- return cstfp_pred_ty<is_nonnan>();
+inline cstfp_or_undef_pred_ty<is_nonnan> m_NonNaN() {
+ return cstfp_or_undef_pred_ty<is_nonnan>();
}
struct is_inf {
@@ -741,15 +771,17 @@ struct is_inf {
};
/// Match a positive or negative infinity FP constant.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_inf> m_Inf() { return cstfp_pred_ty<is_inf>(); }
+inline cstfp_or_undef_pred_ty<is_inf> m_Inf() {
+ return cstfp_or_undef_pred_ty<is_inf>();
+}
struct is_noninf {
bool isValue(const APFloat &C) { return !C.isInfinity(); }
};
/// Match a non-infinity FP constant, i.e. finite or NaN.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_noninf> m_NonInf() {
- return cstfp_pred_ty<is_noninf>();
+inline cstfp_or_undef_pred_ty<is_noninf> m_NonInf() {
+ return cstfp_or_undef_pred_ty<is_noninf>();
}
struct is_finite {
@@ -757,20 +789,21 @@ struct is_finite {
};
/// Match a finite FP constant, i.e. not infinity or NaN.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_finite> m_Finite() {
- return cstfp_pred_ty<is_finite>();
+inline cstfp_or_undef_pred_ty<is_finite> m_Finite() {
+ return cstfp_or_undef_pred_ty<is_finite>();
}
-inline apf_pred_ty<is_finite> m_Finite(const APFloat *&V) { return V; }
+inline apf_or_undef_pred_ty<is_finite> m_Finite(const APFloat *&V) { return V; }
struct is_finitenonzero {
bool isValue(const APFloat &C) { return C.isFiniteNonZero(); }
};
/// Match a finite non-zero FP constant.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_finitenonzero> m_FiniteNonZero() {
- return cstfp_pred_ty<is_finitenonzero>();
+inline cstfp_or_undef_pred_ty<is_finitenonzero> m_FiniteNonZero() {
+ return cstfp_or_undef_pred_ty<is_finitenonzero>();
}
-inline apf_pred_ty<is_finitenonzero> m_FiniteNonZero(const APFloat *&V) {
+inline apf_or_undef_pred_ty<is_finitenonzero>
+m_FiniteNonZero(const APFloat *&V) {
return V;
}
@@ -779,8 +812,8 @@ struct is_any_zero_fp {
};
/// Match a floating-point negative zero or positive zero.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_any_zero_fp> m_AnyZeroFP() {
- return cstfp_pred_ty<is_any_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_any_zero_fp> m_AnyZeroFP() {
+ return cstfp_or_undef_pred_ty<is_any_zero_fp>();
}
struct is_pos_zero_fp {
@@ -788,8 +821,8 @@ struct is_pos_zero_fp {
};
/// Match a floating-point positive zero.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_pos_zero_fp> m_PosZeroFP() {
- return cstfp_pred_ty<is_pos_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_pos_zero_fp> m_PosZeroFP() {
+ return cstfp_or_undef_pred_ty<is_pos_zero_fp>();
}
struct is_neg_zero_fp {
@@ -797,8 +830,8 @@ struct is_neg_zero_fp {
};
/// Match a floating-point negative zero.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_neg_zero_fp> m_NegZeroFP() {
- return cstfp_pred_ty<is_neg_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_neg_zero_fp> m_NegZeroFP() {
+ return cstfp_or_undef_pred_ty<is_neg_zero_fp>();
}
struct is_non_zero_fp {
@@ -806,8 +839,8 @@ struct is_non_zero_fp {
};
/// Match a floating-point non-zero.
/// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_non_zero_fp> m_NonZeroFP() {
- return cstfp_pred_ty<is_non_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_non_zero_fp> m_NonZeroFP() {
+ return cstfp_or_undef_pred_ty<is_non_zero_fp>();
}
///////////////////////////////////////////////////////////////////////////////
@@ -871,8 +904,7 @@ m_ImmConstant() {
}
/// Match an immediate Constant, capturing the value if we match.
-inline match_combine_and<bind_ty<Constant>,
- match_unless<constantexpr_match>>
+inline match_combine_and<bind_ty<Constant>, match_unless<constantexpr_match>>
m_ImmConstant(Constant *&C) {
return m_CombineAnd(m_Constant(C), m_Unless(m_ConstantExpr()));
}
@@ -1142,11 +1174,11 @@ template <typename Op_t> struct FNeg_match {
if (FPMO->getOpcode() == Instruction::FSub) {
if (FPMO->hasNoSignedZeros()) {
// With 'nsz', any zero goes.
- if (!cstfp_pred_ty<is_any_zero_fp>().match(FPMO->getOperand(0)))
+ if (!m_AnyZeroFP().match(FPMO->getOperand(0)))
return false;
} else {
// Without 'nsz', we need fsub -0.0, X exactly.
- if (!cstfp_pred_ty<is_neg_zero_fp>().match(FPMO->getOperand(0)))
+ if (!m_NegZeroFP().match(FPMO->getOperand(0)))
return false;
}
@@ -1164,7 +1196,8 @@ template <typename OpTy> inline FNeg_match<OpTy> m_FNeg(const OpTy &X) {
/// Match 'fneg X' as 'fsub +-0.0, X'.
template <typename RHS>
-inline BinaryOp_match<cstfp_pred_ty<is_any_zero_fp>, RHS, Instruction::FSub>
+inline BinaryOp_match<cstfp_or_undef_pred_ty<is_any_zero_fp>, RHS,
+ Instruction::FSub>
m_FNegNSZ(const RHS &X) {
return m_FSub(m_AnyZeroFP(), X);
}
@@ -2621,14 +2654,15 @@ inline BinaryOp_match<LHS, RHS, Instruction::Xor, true> m_c_Xor(const LHS &L,
/// Matches a 'Neg' as 'sub 0, V'.
template <typename ValTy>
-inline BinaryOp_match<cst_pred_ty<is_zero_int>, ValTy, Instruction::Sub>
+inline BinaryOp_match<cst_or_undef_pred_ty<is_zero_int>, ValTy,
+ Instruction::Sub>
m_Neg(const ValTy &V) {
return m_Sub(m_ZeroInt(), V);
}
/// Matches a 'Neg' as 'sub nsw 0, V'.
template <typename ValTy>
-inline OverflowingBinaryOp_match<cst_pred_ty<is_zero_int>, ValTy,
+inline OverflowingBinaryOp_match<cst_or_undef_pred_ty<is_zero_int>, ValTy,
Instruction::Sub,
OverflowingBinaryOperator::NoSignedWrap>
m_NSWNeg(const ValTy &V) {
@@ -2639,7 +2673,8 @@ m_NSWNeg(const ValTy &V) {
/// NOTE: we first match the 'Not' (by matching '-1'),
/// and only then match the inner matcher!
template <typename ValTy>
-inline BinaryOp_match<cst_pred_ty<is_all_ones>, ValTy, Instruction::Xor, true>
+inline BinaryOp_match<cst_or_undef_pred_ty<is_all_ones>, ValTy,
+ Instruction::Xor, true>
m_Not(const ValTy &V) {
return m_c_Xor(m_AllOnes(), V);
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index db302d7e526844..3754672f828a1b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6391,57 +6391,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A >u C -> A == C+1 if max(A)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
+ // A <s C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
+ // A >s C -> A == C+1 if max(A)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
break;
}
}
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index fdb4684f82548f..8b25cfd9cd4de0 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1913,20 +1913,26 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
Constant *CU32Zero = Constant::getIntegerValue(U32Ty, U32Zero);
Constant *CU32DeadBeef = Constant::getIntegerValue(U32Ty, U32DeadBeef);
- EXPECT_TRUE(match(CU32Max, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_FALSE(match(CU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CU32Max, cst_pred_ty<always_true_pred<APInt>>()));
- EXPECT_FALSE(match(CU32Max, cst_pred_ty<always_false_pred<APInt>>()));
+ EXPECT_TRUE(match(CU32Max, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CU32Max, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CU32Max, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CU32Max, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
- EXPECT_FALSE(match(CU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_TRUE(match(CU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
- EXPECT_FALSE(match(CU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+ EXPECT_FALSE(match(CU32Zero, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_TRUE(match(CU32Zero, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CU32Zero, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CU32Zero, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
- EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
- EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CU32DeadBeef, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(
+ match(CU32DeadBeef, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(
+ match(CU32DeadBeef, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CU32DeadBeef, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
// Scalar float
APFloat F32NaN = APFloat::getNaN(APFloat::IEEEsingle());
@@ -1939,20 +1945,26 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
Constant *CF32Zero = ConstantFP::get(F32Ty, F32Zero);
Constant *CF32Pi = ConstantFP::get(F32Ty, F32Pi);
- EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
- EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
- EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ EXPECT_TRUE(match(CF32NaN, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CF32NaN, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CF32NaN, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CF32NaN, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
- EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
- EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
- EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ EXPECT_FALSE(match(CF32Zero, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+ EXPECT_TRUE(match(CF32Zero, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CF32Zero, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CF32Zero, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
- EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
- EXPECT_TRUE(match(CF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
- EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ EXPECT_FALSE(match(CF32Pi, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CF32Pi, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CF32Pi, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CF32Pi, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
auto FixedEC = ElementCount::getFixed(4);
auto ScalableEC = ElementCount::getScalable(4);
@@ -1966,23 +1978,32 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
Constant *CSplatU32Zero = ConstantVector::getSplat(EC, CU32Zero);
Constant *CSplatU32DeadBeef = ConstantVector::getSplat(EC, CU32DeadBeef);
- EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<always_true_pred<APInt>>()));
- EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<always_false_pred<APInt>>()));
-
- EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
- EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+ EXPECT_TRUE(
+ match(CSplatU32Max, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(
+ match(CSplatU32Max, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(
+ match(CSplatU32Max, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CSplatU32Max, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
- EXPECT_FALSE(match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
EXPECT_FALSE(
- match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
+ match(CSplatU32Zero, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
EXPECT_TRUE(
- match(CSplatU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
+ match(CSplatU32Zero, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(
+ match(CSplatU32Zero, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
EXPECT_FALSE(
- match(CSplatU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+ match(CSplatU32Zero, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
+
+ EXPECT_FALSE(
+ match(CSplatU32DeadBeef, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CSplatU32DeadBeef,
+ cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CSplatU32DeadBeef,
+ cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CSplatU32DeadBeef,
+ cst_or_undef_pred_ty<always_false_pred<APInt>>()));
// float
@@ -1990,25 +2011,32 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
Constant *CSplatF32Zero = ConstantVector::getSplat(EC, CF32Zero);
Constant *CSplatF32Pi = ConstantVector::getSplat(EC, CF32Pi);
- EXPECT_TRUE(match(CSplatF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_FALSE(match(CSplatF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
EXPECT_TRUE(
- match(CSplatF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ match(CSplatF32NaN, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
EXPECT_FALSE(
- match(CSplatF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ match(CSplatF32NaN, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CSplatF32NaN,
+ cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CSplatF32NaN,
+ cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
- EXPECT_FALSE(match(CSplatF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_TRUE(match(CSplatF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
- EXPECT_TRUE(
- match(CSplatF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
EXPECT_FALSE(
- match(CSplatF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ match(CSplatF32Zero, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+ EXPECT_TRUE(
+ match(CSplatF32Zero, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CSplatF32Zero,
+ cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CSplatF32Zero,
+ cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
- EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
- EXPECT_TRUE(match(CSplatF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
EXPECT_FALSE(
- match(CSplatF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ match(CSplatF32Pi, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(
+ match(CSplatF32Pi, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CSplatF32Pi,
+ cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CSplatF32Pi,
+ cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
}
// Int arbitrary vector
@@ -2018,16 +2046,21 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
Constant *CU32MaxWithUndef =
ConstantVector::get({CU32Undef, CU32Max, CU32Undef});
- EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CMixedU32, cst_pred_ty<always_true_pred<APInt>>()));
- EXPECT_FALSE(match(CMixedU32, cst_pred_ty<always_false_pred<APInt>>()));
+ EXPECT_FALSE(match(CMixedU32, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CMixedU32, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(
+ match(CMixedU32, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CMixedU32, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
- EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_max_pred>()));
- EXPECT_FALSE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_zero_pred>()));
- EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_TRUE(
+ match(CU32MaxWithUndef, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
EXPECT_FALSE(
- match(CU32MaxWithUndef, cst_pred_ty<always_false_pred<APInt>>()));
+ match(CU32MaxWithUndef, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(
+ match(CU32MaxWithUndef, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CU32MaxWithUndef,
+ cst_or_undef_pred_ty<always_false_pred<APInt>>()));
// Float arbitrary vector
@@ -2036,17 +2069,21 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
Constant *CF32NaNWithUndef =
ConstantVector::get({CF32Undef, CF32NaN, CF32Undef});
- EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_zero_pred>()));
- EXPECT_TRUE(match(CMixedF32, cstfp_pred_ty<always_true_pred<APFloat>>()));
- EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ EXPECT_FALSE(match(CMixedF32, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CMixedF32, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CMixedF32, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CMixedF32, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
- EXPECT_TRUE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_nan_pred>()));
- EXPECT_FALSE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_zero_pred>()));
EXPECT_TRUE(
- match(CF32NaNWithUndef, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ match(CF32NaNWithUndef, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
EXPECT_FALSE(
- match(CF32NaNWithUndef, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ match(CF32NaNWithUndef, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CF32NaNWithUndef,
+ cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CF32NaNWithUndef,
+ cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
}
TEST_F(PatternMatchTest, InsertValue) {
More information about the llvm-commits
mailing list