[llvm] [PatternMatching] Add generic API for matching constants using custom conditions (PR #85676)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 19 08:57:57 PDT 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/85676

>From a78b248e339c50a1d441d2cbd2ed37012acd959b Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:14 -0500
Subject: [PATCH 1/2] [PatternMatching] Add generic API for matching constants
 using custom conditions

The new API is:
    `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
        - Matches constants with no undef elements such that `Lambda(ele)`
          is true for all elements.
    `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
        - Matches constants whose vector elements may be undef, such that
          `Lambda(ele)` is true for all non-undef elements.

The goal with these is to be able to replace the common usage of:
```
    match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
    match(X, m_CheckedInt(C, CustomCheck));
```

The rationale is that we often ignore non-splat vectors because there are
no good APIs to handle them with, and it's not worth increasing code
complexity for such cases.

The hope is that this API provides a common way of handling
scalars, splat vectors, and non-splat vectors, essentially making this
a non-issue.
---
 llvm/include/llvm/IR/PatternMatch.h |  91 +++++++++--
 llvm/unittests/IR/PatternMatch.cpp  | 240 ++++++++++++++++++++++++++++
 2 files changed, 320 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..82bd8469bee48b 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
 /// This helper class is used to match constant scalars, vector splats,
 /// and fixed width vectors that satisfy a specified predicate.
 /// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
 struct cstval_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
           Constant *Elt = C->getAggregateElement(i);
           if (!Elt)
             return false;
-          if (isa<UndefValue>(Elt))
+          if (isa<UndefValue>(Elt)) {
+            if (!AllowUndefs)
+              return false;
             continue;
+          }
           auto *CV = dyn_cast<ConstantVal>(Elt);
           if (!CV || !this->isValue(CV->getValue()))
             return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
 
 /// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
 
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
   const APInt *&Res;
 
   api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+        if (auto *CI =
+                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APFloat.
 /// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
   const APFloat *&Res;
 
   apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantFP>(
-                C->getSplatValue(/* AllowUndef */ true)))
+        if (auto *CI =
+                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
 //
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename APTy> struct custom_checkfn {
+  function_ref<bool(const APTy &)> CheckFn;
+  bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>, false> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match and integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+                       function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match and float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+                      function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
 struct is_any_apint {
   bool isValue(const APInt &C) { return true; }
 };
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
   EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
 }
 
+TEST_F(PatternMatchTest, CheckedInt) {
+  Type *I8Ty = IRB.getInt8Ty();
+  const APInt *Res = nullptr;
+
+  auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+  auto CheckTrue = [](const APInt &) { return true; };
+  auto CheckFalse = [](const APInt &) { return false; };
+  auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+  auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+  auto DoScalarCheck = [&](int8_t Val) {
+    APInt APVal(8, Val);
+    Constant *C = ConstantInt::get(I8Ty, Val);
+
+    Res = nullptr;
+    EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+    EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+    EXPECT_EQ(*Res, APVal);
+
+    Res = nullptr;
+    EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+    EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal),
+              m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal),
+              m_CheckedIntAllowUndef(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal),
+              m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal),
+              m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+  };
+
+  DoScalarCheck(0);
+  DoScalarCheck(1);
+  DoScalarCheck(2);
+  DoScalarCheck(3);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+                            function_ref<bool(const APInt &)> CheckFn,
+                            bool UndefAsPoison) {
+    SmallVector<Constant *> VecElems;
+    std::optional<bool> Okay;
+    bool AllSame = true;
+    bool HasUndef = false;
+    std::optional<APInt> First;
+    for (const std::optional<int8_t> &Val : Vals) {
+      if (!Val.has_value()) {
+        VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+                                         : UndefValue::get(I8Ty));
+        HasUndef = true;
+      } else {
+        if (!Okay.has_value())
+          Okay = true;
+        APInt APVal(8, *Val);
+        if (!First.has_value())
+          First = APVal;
+        else
+          AllSame &= First->eq(APVal);
+        Okay = *Okay && CheckFn(APVal);
+        VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+      }
+    }
+
+    Constant *C = ConstantVector::get(VecElems);
+    EXPECT_EQ(!HasUndef && Okay.value_or(false),
+              m_CheckedInt(CheckFn).match(C));
+    EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+    Res = nullptr;
+    bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+    if (Expec) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, *First);
+    }
+
+    Res = nullptr;
+    Expec = AllSame && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedIntAllowUndef(Res, CheckFn).match(C));
+    if (Expec) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, *First);
+    }
+  };
+  auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
+    DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
+    DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
+    DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
+  };
+
+  DoVecCheck({0, 1});
+  DoVecCheck({1, 1});
+  DoVecCheck({1, 2});
+  DoVecCheck({1, std::nullopt});
+  DoVecCheck({1, std::nullopt, 1});
+  DoVecCheck({1, std::nullopt, 2});
+  DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
+}
+
 TEST_F(PatternMatchTest, Power2) {
   Value *C128 = IRB.getInt32(128);
   Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1276,6 +1439,63 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorInfUndef, m_Finite()));
   EXPECT_FALSE(match(VectorNaNUndef, m_Finite()));
 
+  auto CheckTrue = [](const APFloat &) { return true; };
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
+
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckTrue)));
+
+  auto CheckFalse = [](const APFloat &) { return false; };
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
+
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckFalse)));
+
+  auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
+
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+
   const APFloat *C;
   // Regardless of whether undefs are allowed,
   // a fully undef constant does not match.
@@ -1285,6 +1505,7 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatForbidUndef(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatAllowUndef(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
 
   // We can always match simple constants and simple splats.
   C = nullptr;
@@ -1305,14 +1526,33 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZero, m_APFloatAllowUndef(C)));
   EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
 
   // Whether splats with undef can be matched depends on the matcher.
   EXPECT_FALSE(match(VectorZeroUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorZeroUndef, m_APFloatForbidUndef(C)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(C, CheckTrue)));
   C = nullptr;
   EXPECT_TRUE(match(VectorZeroUndef, m_APFloatAllowUndef(C)));
   EXPECT_TRUE(C->isZero());
   C = nullptr;
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
   EXPECT_TRUE(match(VectorZeroUndef, m_Finite(C)));
   EXPECT_TRUE(C->isZero());
 }

>From 4d7774dfdae01dbc8ebb9b30748cfb8638240ca3 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:18 -0500
Subject: [PATCH 2/2] [InstCombine] Add example usage for new `Checked` matcher
 API

There is no real motivation for this change other than to highlight a
case where the new `Checked` matcher API can handle non-splat vectors
without increasing code complexity.
---
 llvm/include/llvm/IR/PatternMatch.h           | 207 ++++++++++--------
 .../InstCombine/InstCombineCompares.cpp       |  66 +++---
 .../InstCombine/signed-truncation-check.ll    |  30 +--
 llvm/unittests/IR/PatternMatch.cpp            | 169 ++++++++------
 4 files changed, 260 insertions(+), 212 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 82bd8469bee48b..ea6cb1fb72bed6 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -387,16 +387,22 @@ struct cstval_pred_ty : public Predicate {
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
 using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
 
+template <typename Predicate>
+using cst_or_undef_pred_ty = cst_pred_ty<Predicate, true>;
+
 /// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
 using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
 
+template <typename Predicate>
+using cstfp_or_undef_pred_ty = cstfp_pred_ty<Predicate, true>;
+
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
 struct api_pred_ty : public Predicate {
   const APInt *&Res;
 
@@ -421,10 +427,13 @@ struct api_pred_ty : public Predicate {
   }
 };
 
+template <typename Predicate>
+using api_or_undef_pred_ty = api_pred_ty<Predicate, true>;
+
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APFloat.
 /// Undefs are allowed in splat vector constants.
-template <typename Predicate, bool AllowUndefs = true>
+template <typename Predicate, bool AllowUndefs>
 struct apf_pred_ty : public Predicate {
   const APFloat *&Res;
 
@@ -449,6 +458,9 @@ struct apf_pred_ty : public Predicate {
   }
 };
 
+template <typename Predicate>
+using apf_or_undef_pred_ty = apf_pred_ty<Predicate, true>;
+
 ///////////////////////////////////////////////////////////////////////////////
 //
 // Encapsulate constant value queries for use in templated predicate matchers.
@@ -479,15 +491,15 @@ m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
 
 // Match and integer or vector where CheckFn(ele) for each element is true.
 // For vectors, undefined elements are assumed to match.
-inline cst_pred_ty<custom_checkfn<APInt>>
+inline cst_or_undef_pred_ty<custom_checkfn<APInt>>
 m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
-  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+  return cst_or_undef_pred_ty<custom_checkfn<APInt>>{CheckFn};
 }
 
-inline api_pred_ty<custom_checkfn<APInt>>
+inline api_or_undef_pred_ty<custom_checkfn<APInt>>
 m_CheckedIntAllowUndef(const APInt *&V,
                        function_ref<bool(const APInt &)> CheckFn) {
-  api_pred_ty<custom_checkfn<APInt>> P(V);
+  api_or_undef_pred_ty<custom_checkfn<APInt>> P(V);
   P.CheckFn = CheckFn;
   return P;
 }
@@ -508,15 +520,15 @@ m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
 
 // Match and float or vector where CheckFn(ele) for each element is true.
 // For vectors, undefined elements are assumed to match.
-inline cstfp_pred_ty<custom_checkfn<APFloat>>
+inline cstfp_or_undef_pred_ty<custom_checkfn<APFloat>>
 m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
-  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+  return cstfp_or_undef_pred_ty<custom_checkfn<APFloat>>{CheckFn};
 }
 
-inline apf_pred_ty<custom_checkfn<APFloat>>
+inline apf_or_undef_pred_ty<custom_checkfn<APFloat>>
 m_CheckedFpAllowUndef(const APFloat *&V,
                       function_ref<bool(const APFloat &)> CheckFn) {
-  apf_pred_ty<custom_checkfn<APFloat>> P(V);
+  apf_or_undef_pred_ty<custom_checkfn<APFloat>> P(V);
   P.CheckFn = CheckFn;
   return P;
 }
@@ -526,16 +538,16 @@ struct is_any_apint {
 };
 /// Match an integer or vector with any integral constant.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_any_apint> m_AnyIntegralConstant() {
-  return cst_pred_ty<is_any_apint>();
+inline cst_or_undef_pred_ty<is_any_apint> m_AnyIntegralConstant() {
+  return cst_or_undef_pred_ty<is_any_apint>();
 }
 
 struct is_shifted_mask {
   bool isValue(const APInt &C) { return C.isShiftedMask(); }
 };
 
-inline cst_pred_ty<is_shifted_mask> m_ShiftedMask() {
-  return cst_pred_ty<is_shifted_mask>();
+inline cst_or_undef_pred_ty<is_shifted_mask> m_ShiftedMask() {
+  return cst_or_undef_pred_ty<is_shifted_mask>();
 }
 
 struct is_all_ones {
@@ -543,8 +555,8 @@ struct is_all_ones {
 };
 /// Match an integer or vector with all bits set.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_all_ones> m_AllOnes() {
-  return cst_pred_ty<is_all_ones>();
+inline cst_or_undef_pred_ty<is_all_ones> m_AllOnes() {
+  return cst_or_undef_pred_ty<is_all_ones>();
 }
 
 struct is_maxsignedvalue {
@@ -553,10 +565,11 @@ struct is_maxsignedvalue {
 /// Match an integer or vector with values having all bits except for the high
 /// bit set (0x7f...).
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_maxsignedvalue> m_MaxSignedValue() {
-  return cst_pred_ty<is_maxsignedvalue>();
+inline cst_or_undef_pred_ty<is_maxsignedvalue> m_MaxSignedValue() {
+  return cst_or_undef_pred_ty<is_maxsignedvalue>();
 }
-inline api_pred_ty<is_maxsignedvalue> m_MaxSignedValue(const APInt *&V) {
+inline api_or_undef_pred_ty<is_maxsignedvalue>
+m_MaxSignedValue(const APInt *&V) {
   return V;
 }
 
@@ -565,30 +578,35 @@ struct is_negative {
 };
 /// Match an integer or vector of negative values.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_negative> m_Negative() {
-  return cst_pred_ty<is_negative>();
+inline cst_or_undef_pred_ty<is_negative> m_Negative() {
+  return cst_or_undef_pred_ty<is_negative>();
+}
+inline api_or_undef_pred_ty<is_negative> m_Negative(const APInt *&V) {
+  return V;
 }
-inline api_pred_ty<is_negative> m_Negative(const APInt *&V) { return V; }
 
 struct is_nonnegative {
   bool isValue(const APInt &C) { return C.isNonNegative(); }
 };
 /// Match an integer or vector of non-negative values.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_nonnegative> m_NonNegative() {
-  return cst_pred_ty<is_nonnegative>();
+inline cst_or_undef_pred_ty<is_nonnegative> m_NonNegative() {
+  return cst_or_undef_pred_ty<is_nonnegative>();
+}
+inline api_or_undef_pred_ty<is_nonnegative> m_NonNegative(const APInt *&V) {
+  return V;
 }
-inline api_pred_ty<is_nonnegative> m_NonNegative(const APInt *&V) { return V; }
 
 struct is_strictlypositive {
   bool isValue(const APInt &C) { return C.isStrictlyPositive(); }
 };
 /// Match an integer or vector of strictly positive values.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_strictlypositive> m_StrictlyPositive() {
-  return cst_pred_ty<is_strictlypositive>();
+inline cst_or_undef_pred_ty<is_strictlypositive> m_StrictlyPositive() {
+  return cst_or_undef_pred_ty<is_strictlypositive>();
 }
-inline api_pred_ty<is_strictlypositive> m_StrictlyPositive(const APInt *&V) {
+inline api_or_undef_pred_ty<is_strictlypositive>
+m_StrictlyPositive(const APInt *&V) {
   return V;
 }
 
@@ -597,32 +615,36 @@ struct is_nonpositive {
 };
 /// Match an integer or vector of non-positive values.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_nonpositive> m_NonPositive() {
-  return cst_pred_ty<is_nonpositive>();
+inline cst_or_undef_pred_ty<is_nonpositive> m_NonPositive() {
+  return cst_or_undef_pred_ty<is_nonpositive>();
+}
+inline api_or_undef_pred_ty<is_nonpositive> m_NonPositive(const APInt *&V) {
+  return V;
 }
-inline api_pred_ty<is_nonpositive> m_NonPositive(const APInt *&V) { return V; }
 
 struct is_one {
   bool isValue(const APInt &C) { return C.isOne(); }
 };
 /// Match an integer 1 or a vector with all elements equal to 1.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_one> m_One() { return cst_pred_ty<is_one>(); }
+inline cst_or_undef_pred_ty<is_one> m_One() {
+  return cst_or_undef_pred_ty<is_one>();
+}
 
 struct is_zero_int {
   bool isValue(const APInt &C) { return C.isZero(); }
 };
 /// Match an integer 0 or a vector with all elements equal to 0.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_zero_int> m_ZeroInt() {
-  return cst_pred_ty<is_zero_int>();
+inline cst_or_undef_pred_ty<is_zero_int> m_ZeroInt() {
+  return cst_or_undef_pred_ty<is_zero_int>();
 }
 
 struct is_zero {
   template <typename ITy> bool match(ITy *V) {
     auto *C = dyn_cast<Constant>(V);
     // FIXME: this should be able to do something for scalable vectors
-    return C && (C->isNullValue() || cst_pred_ty<is_zero_int>().match(C));
+    return C && (C->isNullValue() || m_ZeroInt().match(C));
   }
 };
 /// Match any null constant or a vector with all elements equal to 0.
@@ -634,18 +656,21 @@ struct is_power2 {
 };
 /// Match an integer or vector power-of-2.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_power2> m_Power2() { return cst_pred_ty<is_power2>(); }
-inline api_pred_ty<is_power2> m_Power2(const APInt *&V) { return V; }
+inline cst_or_undef_pred_ty<is_power2> m_Power2() {
+  return cst_or_undef_pred_ty<is_power2>();
+}
+inline api_or_undef_pred_ty<is_power2> m_Power2(const APInt *&V) { return V; }
 
 struct is_negated_power2 {
   bool isValue(const APInt &C) { return C.isNegatedPowerOf2(); }
 };
 /// Match a integer or vector negated power-of-2.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_negated_power2> m_NegatedPower2() {
-  return cst_pred_ty<is_negated_power2>();
+inline cst_or_undef_pred_ty<is_negated_power2> m_NegatedPower2() {
+  return cst_or_undef_pred_ty<is_negated_power2>();
 }
-inline api_pred_ty<is_negated_power2> m_NegatedPower2(const APInt *&V) {
+inline api_or_undef_pred_ty<is_negated_power2>
+m_NegatedPower2(const APInt *&V) {
   return V;
 }
 
@@ -654,10 +679,10 @@ struct is_negated_power2_or_zero {
 };
 /// Match a integer or vector negated power-of-2.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
-  return cst_pred_ty<is_negated_power2_or_zero>();
+inline cst_or_undef_pred_ty<is_negated_power2_or_zero> m_NegatedPower2OrZero() {
+  return cst_or_undef_pred_ty<is_negated_power2_or_zero>();
 }
-inline api_pred_ty<is_negated_power2_or_zero>
+inline api_or_undef_pred_ty<is_negated_power2_or_zero>
 m_NegatedPower2OrZero(const APInt *&V) {
   return V;
 }
@@ -667,10 +692,10 @@ struct is_power2_or_zero {
 };
 /// Match an integer or vector of 0 or power-of-2 values.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_power2_or_zero> m_Power2OrZero() {
-  return cst_pred_ty<is_power2_or_zero>();
+inline cst_or_undef_pred_ty<is_power2_or_zero> m_Power2OrZero() {
+  return cst_or_undef_pred_ty<is_power2_or_zero>();
 }
-inline api_pred_ty<is_power2_or_zero> m_Power2OrZero(const APInt *&V) {
+inline api_or_undef_pred_ty<is_power2_or_zero> m_Power2OrZero(const APInt *&V) {
   return V;
 }
 
@@ -679,8 +704,8 @@ struct is_sign_mask {
 };
 /// Match an integer or vector with only the sign bit(s) set.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_sign_mask> m_SignMask() {
-  return cst_pred_ty<is_sign_mask>();
+inline cst_or_undef_pred_ty<is_sign_mask> m_SignMask() {
+  return cst_or_undef_pred_ty<is_sign_mask>();
 }
 
 struct is_lowbit_mask {
@@ -688,20 +713,23 @@ struct is_lowbit_mask {
 };
 /// Match an integer or vector with only the low bit(s) set.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_lowbit_mask> m_LowBitMask() {
-  return cst_pred_ty<is_lowbit_mask>();
+inline cst_or_undef_pred_ty<is_lowbit_mask> m_LowBitMask() {
+  return cst_or_undef_pred_ty<is_lowbit_mask>();
+}
+inline api_or_undef_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) {
+  return V;
 }
-inline api_pred_ty<is_lowbit_mask> m_LowBitMask(const APInt *&V) { return V; }
 
 struct is_lowbit_mask_or_zero {
   bool isValue(const APInt &C) { return !C || C.isMask(); }
 };
 /// Match an integer or vector with only the low bit(s) set.
 /// For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
-  return cst_pred_ty<is_lowbit_mask_or_zero>();
+inline cst_or_undef_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero() {
+  return cst_or_undef_pred_ty<is_lowbit_mask_or_zero>();
 }
-inline api_pred_ty<is_lowbit_mask_or_zero> m_LowBitMaskOrZero(const APInt *&V) {
+inline api_or_undef_pred_ty<is_lowbit_mask_or_zero>
+m_LowBitMaskOrZero(const APInt *&V) {
   return V;
 }
 
@@ -712,9 +740,9 @@ struct icmp_pred_with_threshold {
 };
 /// Match an integer or vector with every element comparing 'pred' (eg/ne/...)
 /// to Threshold. For vectors, this includes constants with undefined elements.
-inline cst_pred_ty<icmp_pred_with_threshold>
+inline cst_or_undef_pred_ty<icmp_pred_with_threshold>
 m_SpecificInt_ICMP(ICmpInst::Predicate Predicate, const APInt &Threshold) {
-  cst_pred_ty<icmp_pred_with_threshold> P;
+  cst_or_undef_pred_ty<icmp_pred_with_threshold> P;
   P.Pred = Predicate;
   P.Thr = &Threshold;
   return P;
@@ -725,15 +753,17 @@ struct is_nan {
 };
 /// Match an arbitrary NaN constant. This includes quiet and signalling nans.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_nan> m_NaN() { return cstfp_pred_ty<is_nan>(); }
+inline cstfp_or_undef_pred_ty<is_nan> m_NaN() {
+  return cstfp_or_undef_pred_ty<is_nan>();
+}
 
 struct is_nonnan {
   bool isValue(const APFloat &C) { return !C.isNaN(); }
 };
 /// Match a non-NaN FP constant.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_nonnan> m_NonNaN() {
-  return cstfp_pred_ty<is_nonnan>();
+inline cstfp_or_undef_pred_ty<is_nonnan> m_NonNaN() {
+  return cstfp_or_undef_pred_ty<is_nonnan>();
 }
 
 struct is_inf {
@@ -741,15 +771,17 @@ struct is_inf {
 };
 /// Match a positive or negative infinity FP constant.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_inf> m_Inf() { return cstfp_pred_ty<is_inf>(); }
+inline cstfp_or_undef_pred_ty<is_inf> m_Inf() {
+  return cstfp_or_undef_pred_ty<is_inf>();
+}
 
 struct is_noninf {
   bool isValue(const APFloat &C) { return !C.isInfinity(); }
 };
 /// Match a non-infinity FP constant, i.e. finite or NaN.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_noninf> m_NonInf() {
-  return cstfp_pred_ty<is_noninf>();
+inline cstfp_or_undef_pred_ty<is_noninf> m_NonInf() {
+  return cstfp_or_undef_pred_ty<is_noninf>();
 }
 
 struct is_finite {
@@ -757,20 +789,21 @@ struct is_finite {
 };
 /// Match a finite FP constant, i.e. not infinity or NaN.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_finite> m_Finite() {
-  return cstfp_pred_ty<is_finite>();
+inline cstfp_or_undef_pred_ty<is_finite> m_Finite() {
+  return cstfp_or_undef_pred_ty<is_finite>();
 }
-inline apf_pred_ty<is_finite> m_Finite(const APFloat *&V) { return V; }
+inline apf_or_undef_pred_ty<is_finite> m_Finite(const APFloat *&V) { return V; }
 
 struct is_finitenonzero {
   bool isValue(const APFloat &C) { return C.isFiniteNonZero(); }
 };
 /// Match a finite non-zero FP constant.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_finitenonzero> m_FiniteNonZero() {
-  return cstfp_pred_ty<is_finitenonzero>();
+inline cstfp_or_undef_pred_ty<is_finitenonzero> m_FiniteNonZero() {
+  return cstfp_or_undef_pred_ty<is_finitenonzero>();
 }
-inline apf_pred_ty<is_finitenonzero> m_FiniteNonZero(const APFloat *&V) {
+inline apf_or_undef_pred_ty<is_finitenonzero>
+m_FiniteNonZero(const APFloat *&V) {
   return V;
 }
 
@@ -779,8 +812,8 @@ struct is_any_zero_fp {
 };
 /// Match a floating-point negative zero or positive zero.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_any_zero_fp> m_AnyZeroFP() {
-  return cstfp_pred_ty<is_any_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_any_zero_fp> m_AnyZeroFP() {
+  return cstfp_or_undef_pred_ty<is_any_zero_fp>();
 }
 
 struct is_pos_zero_fp {
@@ -788,8 +821,8 @@ struct is_pos_zero_fp {
 };
 /// Match a floating-point positive zero.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_pos_zero_fp> m_PosZeroFP() {
-  return cstfp_pred_ty<is_pos_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_pos_zero_fp> m_PosZeroFP() {
+  return cstfp_or_undef_pred_ty<is_pos_zero_fp>();
 }
 
 struct is_neg_zero_fp {
@@ -797,8 +830,8 @@ struct is_neg_zero_fp {
 };
 /// Match a floating-point negative zero.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_neg_zero_fp> m_NegZeroFP() {
-  return cstfp_pred_ty<is_neg_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_neg_zero_fp> m_NegZeroFP() {
+  return cstfp_or_undef_pred_ty<is_neg_zero_fp>();
 }
 
 struct is_non_zero_fp {
@@ -806,8 +839,8 @@ struct is_non_zero_fp {
 };
 /// Match a floating-point non-zero.
 /// For vectors, this includes constants with undefined elements.
-inline cstfp_pred_ty<is_non_zero_fp> m_NonZeroFP() {
-  return cstfp_pred_ty<is_non_zero_fp>();
+inline cstfp_or_undef_pred_ty<is_non_zero_fp> m_NonZeroFP() {
+  return cstfp_or_undef_pred_ty<is_non_zero_fp>();
 }
 
 ///////////////////////////////////////////////////////////////////////////////
@@ -871,8 +904,7 @@ m_ImmConstant() {
 }
 
 /// Match an immediate Constant, capturing the value if we match.
-inline match_combine_and<bind_ty<Constant>,
-                         match_unless<constantexpr_match>>
+inline match_combine_and<bind_ty<Constant>, match_unless<constantexpr_match>>
 m_ImmConstant(Constant *&C) {
   return m_CombineAnd(m_Constant(C), m_Unless(m_ConstantExpr()));
 }
@@ -1127,11 +1159,11 @@ template <typename Op_t> struct FNeg_match {
     if (FPMO->getOpcode() == Instruction::FSub) {
       if (FPMO->hasNoSignedZeros()) {
         // With 'nsz', any zero goes.
-        if (!cstfp_pred_ty<is_any_zero_fp>().match(FPMO->getOperand(0)))
+        if (!m_AnyZeroFP().match(FPMO->getOperand(0)))
           return false;
       } else {
         // Without 'nsz', we need fsub -0.0, X exactly.
-        if (!cstfp_pred_ty<is_neg_zero_fp>().match(FPMO->getOperand(0)))
+        if (!m_NegZeroFP().match(FPMO->getOperand(0)))
           return false;
       }
 
@@ -1149,7 +1181,8 @@ template <typename OpTy> inline FNeg_match<OpTy> m_FNeg(const OpTy &X) {
 
 /// Match 'fneg X' as 'fsub +-0.0, X'.
 template <typename RHS>
-inline BinaryOp_match<cstfp_pred_ty<is_any_zero_fp>, RHS, Instruction::FSub>
+inline BinaryOp_match<cstfp_or_undef_pred_ty<is_any_zero_fp>, RHS,
+                      Instruction::FSub>
 m_FNegNSZ(const RHS &X) {
   return m_FSub(m_AnyZeroFP(), X);
 }
@@ -2576,14 +2609,15 @@ inline BinaryOp_match<LHS, RHS, Instruction::Xor, true> m_c_Xor(const LHS &L,
 
 /// Matches a 'Neg' as 'sub 0, V'.
 template <typename ValTy>
-inline BinaryOp_match<cst_pred_ty<is_zero_int>, ValTy, Instruction::Sub>
+inline BinaryOp_match<cst_or_undef_pred_ty<is_zero_int>, ValTy,
+                      Instruction::Sub>
 m_Neg(const ValTy &V) {
   return m_Sub(m_ZeroInt(), V);
 }
 
 /// Matches a 'Neg' as 'sub nsw 0, V'.
 template <typename ValTy>
-inline OverflowingBinaryOp_match<cst_pred_ty<is_zero_int>, ValTy,
+inline OverflowingBinaryOp_match<cst_or_undef_pred_ty<is_zero_int>, ValTy,
                                  Instruction::Sub,
                                  OverflowingBinaryOperator::NoSignedWrap>
 m_NSWNeg(const ValTy &V) {
@@ -2594,7 +2628,8 @@ m_NSWNeg(const ValTy &V) {
 /// NOTE: we first match the 'Not' (by matching '-1'),
 /// and only then match the inner matcher!
 template <typename ValTy>
-inline BinaryOp_match<cst_pred_ty<is_all_ones>, ValTy, Instruction::Xor, true>
+inline BinaryOp_match<cst_or_undef_pred_ty<is_all_ones>, ValTy,
+                      Instruction::Xor, true>
 m_Not(const ValTy &V) {
   return m_c_Xor(m_AllOnes(), V);
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
     case ICmpInst::ICMP_ULT: {
       if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A <u C -> A == C-1 if min(A)+1 == C
-        if (*CmpC == Op0Min + 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC - 1));
-        // X <u C --> X == 0, if the number of zero bits in the bottom of X
-        // exceeds the log2 of C.
-        if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              Constant::getNullValue(Op1->getType()));
-      }
+      // A <u C -> A == C-1 if min(A)+1 == C
+      if (match(Op1, m_SpecificInt(Op0Min + 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Min));
+      // X <u C --> X == 0, if the number of known zero bits in the bottom of X
+      // is at least ceil(log2(C)).
+      if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+                  return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+                })))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            Constant::getNullValue(Op1->getType()));
       break;
     }
     case ICmpInst::ICMP_UGT: {
       if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A >u C -> A == C+1 if max(a)-1 == C
-        if (*CmpC == Op0Max - 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-        // X >u C --> X != 0, if the number of zero bits in the bottom of X
-        // exceeds the log2 of C.
-        if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
-          return new ICmpInst(ICmpInst::ICMP_NE, Op0,
-                              Constant::getNullValue(Op1->getType()));
-      }
+      // A >u C -> A == C+1 if max(A)-1 == C
+      if (match(Op1, m_SpecificInt(Op0Max - 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Max));
+      // X >u C --> X != 0, if the number of known zero bits in the bottom of X
+      // is at least the number of active bits in C.
+      if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+                  return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+                })))
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+                            Constant::getNullValue(Op1->getType()));
       break;
     }
     case ICmpInst::ICMP_SLT: {
       if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC - 1));
-      }
+      // A <s C -> A == C-1 if min(A)+1 == C
+      if (match(Op1, m_SpecificInt(Op0Min + 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Min));
       break;
     }
     case ICmpInst::ICMP_SGT: {
       if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-      }
+      // A >s C -> A == C+1 if max(A)-1 == C
+      if (match(Op1, m_SpecificInt(Op0Max - 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Max));
       break;
     }
     }
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index de361c70804c3e..c3614912c7e23c 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1796,20 +1796,26 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
   Constant *CU32Zero = Constant::getIntegerValue(U32Ty, U32Zero);
   Constant *CU32DeadBeef = Constant::getIntegerValue(U32Ty, U32DeadBeef);
 
-  EXPECT_TRUE(match(CU32Max, cst_pred_ty<is_unsigned_max_pred>()));
-  EXPECT_FALSE(match(CU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
-  EXPECT_TRUE(match(CU32Max, cst_pred_ty<always_true_pred<APInt>>()));
-  EXPECT_FALSE(match(CU32Max, cst_pred_ty<always_false_pred<APInt>>()));
+  EXPECT_TRUE(match(CU32Max, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(match(CU32Max, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CU32Max, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(
+      match(CU32Max, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
-  EXPECT_FALSE(match(CU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
-  EXPECT_TRUE(match(CU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
-  EXPECT_TRUE(match(CU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
-  EXPECT_FALSE(match(CU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+  EXPECT_FALSE(match(CU32Zero, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_TRUE(match(CU32Zero, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CU32Zero, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(
+      match(CU32Zero, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
-  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
-  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
-  EXPECT_TRUE(match(CU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
-  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+  EXPECT_FALSE(
+      match(CU32DeadBeef, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(
+      match(CU32DeadBeef, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(
+      match(CU32DeadBeef, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(
+      match(CU32DeadBeef, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
   // Scalar float
   APFloat F32NaN = APFloat::getNaN(APFloat::IEEEsingle());
@@ -1822,20 +1828,26 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
   Constant *CF32Zero = ConstantFP::get(F32Ty, F32Zero);
   Constant *CF32Pi = ConstantFP::get(F32Ty, F32Pi);
 
-  EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
-  EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
-  EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
-  EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+  EXPECT_TRUE(match(CF32NaN, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CF32NaN, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(
+      match(CF32NaN, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(
+      match(CF32NaN, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 
-  EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
-  EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
-  EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
-  EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+  EXPECT_FALSE(match(CF32Zero, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+  EXPECT_TRUE(match(CF32Zero, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(
+      match(CF32Zero, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(
+      match(CF32Zero, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 
-  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
-  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
-  EXPECT_TRUE(match(CF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
-  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+  EXPECT_FALSE(match(CF32Pi, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CF32Pi, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(
+      match(CF32Pi, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(
+      match(CF32Pi, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 
   auto FixedEC = ElementCount::getFixed(4);
   auto ScalableEC = ElementCount::getScalable(4);
@@ -1849,23 +1861,32 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
     Constant *CSplatU32Zero = ConstantVector::getSplat(EC, CU32Zero);
     Constant *CSplatU32DeadBeef = ConstantVector::getSplat(EC, CU32DeadBeef);
 
-    EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<is_unsigned_max_pred>()));
-    EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
-    EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<always_true_pred<APInt>>()));
-    EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<always_false_pred<APInt>>()));
-
-    EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
-    EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
-    EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
-    EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+    EXPECT_TRUE(
+        match(CSplatU32Max, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+    EXPECT_FALSE(
+        match(CSplatU32Max, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+    EXPECT_TRUE(
+        match(CSplatU32Max, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+    EXPECT_FALSE(
+        match(CSplatU32Max, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
-    EXPECT_FALSE(match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
     EXPECT_FALSE(
-        match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
+        match(CSplatU32Zero, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
     EXPECT_TRUE(
-        match(CSplatU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
+        match(CSplatU32Zero, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+    EXPECT_TRUE(
+        match(CSplatU32Zero, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
     EXPECT_FALSE(
-        match(CSplatU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+        match(CSplatU32Zero, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
+
+    EXPECT_FALSE(
+        match(CSplatU32DeadBeef, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+    EXPECT_FALSE(match(CSplatU32DeadBeef,
+                       cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+    EXPECT_TRUE(match(CSplatU32DeadBeef,
+                      cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+    EXPECT_FALSE(match(CSplatU32DeadBeef,
+                       cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
     // float
 
@@ -1873,25 +1894,32 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
     Constant *CSplatF32Zero = ConstantVector::getSplat(EC, CF32Zero);
     Constant *CSplatF32Pi = ConstantVector::getSplat(EC, CF32Pi);
 
-    EXPECT_TRUE(match(CSplatF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
-    EXPECT_FALSE(match(CSplatF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
     EXPECT_TRUE(
-        match(CSplatF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
+        match(CSplatF32NaN, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
     EXPECT_FALSE(
-        match(CSplatF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+        match(CSplatF32NaN, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+    EXPECT_TRUE(match(CSplatF32NaN,
+                      cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+    EXPECT_FALSE(match(CSplatF32NaN,
+                       cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 
-    EXPECT_FALSE(match(CSplatF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
-    EXPECT_TRUE(match(CSplatF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
-    EXPECT_TRUE(
-        match(CSplatF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
     EXPECT_FALSE(
-        match(CSplatF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+        match(CSplatF32Zero, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+    EXPECT_TRUE(
+        match(CSplatF32Zero, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+    EXPECT_TRUE(match(CSplatF32Zero,
+                      cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+    EXPECT_FALSE(match(CSplatF32Zero,
+                       cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 
-    EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
-    EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
-    EXPECT_TRUE(match(CSplatF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
     EXPECT_FALSE(
-        match(CSplatF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+        match(CSplatF32Pi, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+    EXPECT_FALSE(
+        match(CSplatF32Pi, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+    EXPECT_TRUE(match(CSplatF32Pi,
+                      cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+    EXPECT_FALSE(match(CSplatF32Pi,
+                       cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
   }
 
   // Int arbitrary vector
@@ -1901,16 +1929,21 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
   Constant *CU32MaxWithUndef =
       ConstantVector::get({CU32Undef, CU32Max, CU32Undef});
 
-  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_max_pred>()));
-  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_zero_pred>()));
-  EXPECT_TRUE(match(CMixedU32, cst_pred_ty<always_true_pred<APInt>>()));
-  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<always_false_pred<APInt>>()));
+  EXPECT_FALSE(match(CMixedU32, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(match(CMixedU32, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(
+      match(CMixedU32, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(
+      match(CMixedU32, cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
-  EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_max_pred>()));
-  EXPECT_FALSE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_zero_pred>()));
-  EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_TRUE(
+      match(CU32MaxWithUndef, cst_or_undef_pred_ty<is_unsigned_max_pred>()));
   EXPECT_FALSE(
-      match(CU32MaxWithUndef, cst_pred_ty<always_false_pred<APInt>>()));
+      match(CU32MaxWithUndef, cst_or_undef_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(
+      match(CU32MaxWithUndef, cst_or_undef_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(match(CU32MaxWithUndef,
+                     cst_or_undef_pred_ty<always_false_pred<APInt>>()));
 
   // Float arbitrary vector
 
@@ -1919,17 +1952,21 @@ TEST_F(PatternMatchTest, ConstantPredicateType) {
   Constant *CF32NaNWithUndef =
       ConstantVector::get({CF32Undef, CF32NaN, CF32Undef});
 
-  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_nan_pred>()));
-  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_zero_pred>()));
-  EXPECT_TRUE(match(CMixedF32, cstfp_pred_ty<always_true_pred<APFloat>>()));
-  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<always_false_pred<APFloat>>()));
+  EXPECT_FALSE(match(CMixedF32, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CMixedF32, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(
+      match(CMixedF32, cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(
+      match(CMixedF32, cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 
-  EXPECT_TRUE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_nan_pred>()));
-  EXPECT_FALSE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_zero_pred>()));
   EXPECT_TRUE(
-      match(CF32NaNWithUndef, cstfp_pred_ty<always_true_pred<APFloat>>()));
+      match(CF32NaNWithUndef, cstfp_or_undef_pred_ty<is_float_nan_pred>()));
   EXPECT_FALSE(
-      match(CF32NaNWithUndef, cstfp_pred_ty<always_false_pred<APFloat>>()));
+      match(CF32NaNWithUndef, cstfp_or_undef_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(match(CF32NaNWithUndef,
+                    cstfp_or_undef_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(match(CF32NaNWithUndef,
+                     cstfp_or_undef_pred_ty<always_false_pred<APFloat>>()));
 }
 
 TEST_F(PatternMatchTest, InsertValue) {



More information about the llvm-commits mailing list