[llvm] goldsteinn/pattern match api (PR #85676)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 18 11:14:54 PDT 2024
https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/85676
- **[PatternMatching] Add generic API for matching constants using custom conditions**
- **[InstCombine] Add example usage for new `Checked` matcher API**
>From 71c8840221b2a953b72a8f6680b5939bde22ff75 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:14 -0500
Subject: [PATCH 1/2] [PatternMatching] Add generic API for matching constants
using custom conditions
The new API is:
`m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
- Matches non-undef constants s.t `Lambda(ele)` is true for all
elements.
`m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
- Matches constants/undef s.t `Lambda(ele)` is true for all
elements.
The goal with these is to be able to replace the common usage of:
```
match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
match(X, m_CheckedInt(C, CustomCheck));
```
The rationale is that we often ignore non-splat vectors because there are
no good APIs to handle them with and it's not worth increasing code
complexity for such cases.
The hope is that the API provides a common method for handling
scalars/splat-vecs/non-splat-vecs, essentially making this a
non-issue.
---
llvm/include/llvm/IR/PatternMatch.h | 91 +++++++++--
llvm/unittests/IR/PatternMatch.cpp | 240 ++++++++++++++++++++++++++++
2 files changed, 320 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..4333d3e6e8da2a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
- if (isa<UndefValue>(Elt))
+ if (isa<UndefValue>(Elt)) {
+ if (!AllowUndefs)
+ return false;
continue;
+ }
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
};
/// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
/// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
const APInt *&Res;
api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
const APFloat *&Res;
apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantFP>(
- C->getSplatValue(/* AllowUndef */ true)))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
//
///////////////////////////////////////////////////////////////////////////////
+template <typename APTy> struct custom_checkfn {
+ function_ref<bool(const APTy &)> CheckFn;
+ bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match an integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match an integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+ function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match a float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match a float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+ function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
struct is_any_apint {
bool isValue(const APInt &C) { return true; }
};
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
}
+TEST_F(PatternMatchTest, CheckedInt) {
+ Type *I8Ty = IRB.getInt8Ty();
+ const APInt *Res = nullptr;
+
+ auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+ auto CheckTrue = [](const APInt &) { return true; };
+ auto CheckFalse = [](const APInt &) { return false; };
+ auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+ auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+ auto DoScalarCheck = [&](int8_t Val) {
+ APInt APVal(8, Val);
+ Constant *C = ConstantInt::get(I8Ty, Val);
+
+ Res = nullptr;
+ EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+ EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+ EXPECT_EQ(*Res, APVal);
+
+ Res = nullptr;
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal),
+ m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal),
+ m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+ };
+
+ DoScalarCheck(0);
+ DoScalarCheck(1);
+ DoScalarCheck(2);
+ DoScalarCheck(3);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+ function_ref<bool(const APInt &)> CheckFn,
+ bool UndefAsPoison) {
+ SmallVector<Constant *> VecElems;
+ std::optional<bool> Okay;
+ bool AllSame = true;
+ bool HasUndef = false;
+ std::optional<APInt> First;
+ for (const std::optional<int8_t> &Val : Vals) {
+ if (!Val.has_value()) {
+ VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+ : UndefValue::get(I8Ty));
+ HasUndef = true;
+ } else {
+ if (!Okay.has_value())
+ Okay = true;
+ APInt APVal(8, *Val);
+ if (!First.has_value())
+ First = APVal;
+ else
+ AllSame &= First->eq(APVal);
+ Okay = *Okay && CheckFn(APVal);
+ VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+ }
+ }
+
+ Constant *C = ConstantVector::get(VecElems);
+ EXPECT_EQ(!HasUndef && Okay.value_or(false),
+ m_CheckedInt(CheckFn).match(C));
+ EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+ Res = nullptr;
+ bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+
+ Res = nullptr;
+ Expec = AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedIntAllowUndef(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+ };
+ auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
+ DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
+ DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
+ DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
+ DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
+ };
+
+ DoVecCheck({0, 1});
+ DoVecCheck({1, 1});
+ DoVecCheck({1, 2});
+ DoVecCheck({1, std::nullopt});
+ DoVecCheck({1, std::nullopt, 1});
+ DoVecCheck({1, std::nullopt, 2});
+ DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
+}
+
TEST_F(PatternMatchTest, Power2) {
Value *C128 = IRB.getInt32(128);
Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1276,6 +1439,63 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorInfUndef, m_Finite()));
EXPECT_FALSE(match(VectorNaNUndef, m_Finite()));
+ auto CheckTrue = [](const APFloat &) { return true; };
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
+
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckTrue)));
+ EXPECT_TRUE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckTrue)));
+
+ auto CheckFalse = [](const APFloat &) { return false; };
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
+
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckFalse)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckFalse)));
+
+ auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
+
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+ EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+
const APFloat *C;
// Regardless of whether undefs are allowed,
// a fully undef constant does not match.
@@ -1285,6 +1505,7 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatForbidUndef(C)));
EXPECT_FALSE(match(VectorUndef, m_APFloatAllowUndef(C)));
+ EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
// We can always match simple constants and simple splats.
C = nullptr;
@@ -1305,14 +1526,33 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
C = nullptr;
EXPECT_TRUE(match(VectorZero, m_APFloatAllowUndef(C)));
EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckTrue)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+ EXPECT_TRUE(C->isZero());
// Whether splats with undef can be matched depends on the matcher.
EXPECT_FALSE(match(VectorZeroUndef, m_APFloat(C)));
EXPECT_FALSE(match(VectorZeroUndef, m_APFloatForbidUndef(C)));
+ EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(C, CheckTrue)));
C = nullptr;
EXPECT_TRUE(match(VectorZeroUndef, m_APFloatAllowUndef(C)));
EXPECT_TRUE(C->isZero());
C = nullptr;
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckTrue)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
+ EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+ EXPECT_TRUE(C->isZero());
+ C = nullptr;
EXPECT_TRUE(match(VectorZeroUndef, m_Finite(C)));
EXPECT_TRUE(C->isZero());
}
>From 6a6c35f20a9c748c58eb8129a2f950155ac33a26 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:18 -0500
Subject: [PATCH 2/2] [InstCombine] Add example usage for new `Checked` matcher
API
There is no real motivation for this change other than to highlight a
case where the new `Checked` matcher API can handle non-splat-vecs
without increasing code complexity.
---
.../InstCombine/InstCombineCompares.cpp | 66 +++++++++----------
.../InstCombine/signed-truncation-check.ll | 30 ++-------
2 files changed, 36 insertions(+), 60 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A >u C -> A == C+1 if max(a)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
+ // A <s C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
+ // A >s C -> A == C+1 if max(A)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
break;
}
}
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
More information about the llvm-commits
mailing list