[llvm] goldsteinn/pattern match api (PR #85676)

via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 18 11:14:54 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/85676

- **[PatternMatching] Add generic API for matching constants using custom conditions**
- **[InstCombine] Add example usage for new `Checked` matcher API**


>From 71c8840221b2a953b72a8f6680b5939bde22ff75 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:14 -0500
Subject: [PATCH 1/2] [PatternMatching] Add generic API for matching constants
 using custom conditions

The new API is:
    `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
        - Matches non-undef constants s.t `Lambda(ele)` is true for all
          elements.
    `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
        - Matches constants/undef s.t `Lambda(ele)` is true for all
          elements.

The goal with these is to be able to replace the common usage of:
```
    match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
    match(X, m_CheckedInt(C, CustomCheck));
```

The rationale is that we often ignore non-splat vectors because there
are no good APIs to handle them with, and it's not worth increasing
code complexity for such cases.

The hope is that the API provides a common method for handling
scalars/splat-vecs/non-splat-vecs, essentially making this a
non-issue.
---
 llvm/include/llvm/IR/PatternMatch.h |  91 +++++++++--
 llvm/unittests/IR/PatternMatch.cpp  | 240 ++++++++++++++++++++++++++++
 2 files changed, 320 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..4333d3e6e8da2a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
 /// This helper class is used to match constant scalars, vector splats,
 /// and fixed width vectors that satisfy a specified predicate.
 /// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
 struct cstval_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
           Constant *Elt = C->getAggregateElement(i);
           if (!Elt)
             return false;
-          if (isa<UndefValue>(Elt))
+          if (isa<UndefValue>(Elt)) {
+            if (!AllowUndefs)
+              return false;
             continue;
+          }
           auto *CV = dyn_cast<ConstantVal>(Elt);
           if (!CV || !this->isValue(CV->getValue()))
             return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
 
 /// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
 
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
   const APInt *&Res;
 
   api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+        if (auto *CI =
+                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APFloat.
 /// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
   const APFloat *&Res;
 
   apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
       }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantFP>(
-                C->getSplatValue(/* AllowUndef */ true)))
+        if (auto *CI =
+                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
 //
 ///////////////////////////////////////////////////////////////////////////////
 
+template <typename APTy> struct custom_checkfn {
+  function_ref<bool(const APTy &)> CheckFn;
+  bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+// Match an integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>, false> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match an integer or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+                       function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match a float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed NOT to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+// Match a float or vector where CheckFn(ele) for each element is true.
+// For vectors, undefined elements are assumed to match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+                    function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
 struct is_any_apint {
   bool isValue(const APInt &C) { return true; }
 };
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
   EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
 }
 
+TEST_F(PatternMatchTest, CheckedInt) {
+  Type *I8Ty = IRB.getInt8Ty();
+  const APInt *Res = nullptr;
+
+  auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+  auto CheckTrue = [](const APInt &) { return true; };
+  auto CheckFalse = [](const APInt &) { return false; };
+  auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+  auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+  auto DoScalarCheck = [&](int8_t Val) {
+    APInt APVal(8, Val);
+    Constant *C = ConstantInt::get(I8Ty, Val);
+
+    Res = nullptr;
+    EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+    EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+    EXPECT_EQ(*Res, APVal);
+
+    Res = nullptr;
+    EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+    EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal),
+              m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal),
+              m_CheckedIntAllowUndef(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal),
+              m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal),
+              m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+  };
+
+  DoScalarCheck(0);
+  DoScalarCheck(1);
+  DoScalarCheck(2);
+  DoScalarCheck(3);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+                            function_ref<bool(const APInt &)> CheckFn,
+                            bool UndefAsPoison) {
+    SmallVector<Constant *> VecElems;
+    std::optional<bool> Okay;
+    bool AllSame = true;
+    bool HasUndef = false;
+    std::optional<APInt> First;
+    for (const std::optional<int8_t> &Val : Vals) {
+      if (!Val.has_value()) {
+        VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+                                         : UndefValue::get(I8Ty));
+        HasUndef = true;
+      } else {
+        if (!Okay.has_value())
+          Okay = true;
+        APInt APVal(8, *Val);
+        if (!First.has_value())
+          First = APVal;
+        else
+          AllSame &= First->eq(APVal);
+        Okay = *Okay && CheckFn(APVal);
+        VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+      }
+    }
+
+    Constant *C = ConstantVector::get(VecElems);
+    EXPECT_EQ(!HasUndef && Okay.value_or(false),
+              m_CheckedInt(CheckFn).match(C));
+    EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+    Res = nullptr;
+    bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+    if (Expec) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, *First);
+    }
+
+    Res = nullptr;
+    Expec = AllSame && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedIntAllowUndef(Res, CheckFn).match(C));
+    if (Expec) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, *First);
+    }
+  };
+  auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
+    DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
+    DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
+    DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
+  };
+
+  DoVecCheck({0, 1});
+  DoVecCheck({1, 1});
+  DoVecCheck({1, 2});
+  DoVecCheck({1, std::nullopt});
+  DoVecCheck({1, std::nullopt, 1});
+  DoVecCheck({1, std::nullopt, 2});
+  DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
+}
+
 TEST_F(PatternMatchTest, Power2) {
   Value *C128 = IRB.getInt32(128);
   Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1276,6 +1439,63 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorInfUndef, m_Finite()));
   EXPECT_FALSE(match(VectorNaNUndef, m_Finite()));
 
+  auto CheckTrue = [](const APFloat &) { return true; };
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
+
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckTrue)));
+  EXPECT_TRUE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckTrue)));
+
+  auto CheckFalse = [](const APFloat &) { return false; };
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
+
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckFalse)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckFalse)));
+
+  auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
+
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_TRUE(match(VectorInfUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFpAllowUndef(CheckNonNaN)));
+
   const APFloat *C;
   // Regardless of whether undefs are allowed,
   // a fully undef constant does not match.
@@ -1285,6 +1505,7 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatForbidUndef(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatAllowUndef(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
 
   // We can always match simple constants and simple splats.
   C = nullptr;
@@ -1305,14 +1526,33 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZero, m_APFloatAllowUndef(C)));
   EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
 
   // Whether splats with undef can be matched depends on the matcher.
   EXPECT_FALSE(match(VectorZeroUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorZeroUndef, m_APFloatForbidUndef(C)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(C, CheckTrue)));
   C = nullptr;
   EXPECT_TRUE(match(VectorZeroUndef, m_APFloatAllowUndef(C)));
   EXPECT_TRUE(C->isZero());
   C = nullptr;
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZeroUndef, m_CheckedFpAllowUndef(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
   EXPECT_TRUE(match(VectorZeroUndef, m_Finite(C)));
   EXPECT_TRUE(C->isZero());
 }

>From 6a6c35f20a9c748c58eb8129a2f950155ac33a26 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:18 -0500
Subject: [PATCH 2/2] [InstCombine] Add example usage for new `Checked` matcher
 API

There is no real motivation for this change other than to highlight a
case where the new `Checked` matcher API can handle non-splat-vecs
without increasing code complexity.
---
 .../InstCombine/InstCombineCompares.cpp       | 66 +++++++++----------
 .../InstCombine/signed-truncation-check.ll    | 30 ++-------
 2 files changed, 36 insertions(+), 60 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
     case ICmpInst::ICMP_ULT: {
       if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A <u C -> A == C-1 if min(A)+1 == C
-        if (*CmpC == Op0Min + 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC - 1));
-        // X <u C --> X == 0, if the number of zero bits in the bottom of X
-        // exceeds the log2 of C.
-        if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              Constant::getNullValue(Op1->getType()));
-      }
+      // A <u C -> A == C-1 if min(A)+1 == C
+      if (match(Op1, m_SpecificInt(Op0Min + 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Min));
+      // X <u C --> X == 0, if the number of zero bits in the bottom of X
+      // exceeds the log2 of C.
+      if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+                  return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+                })))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            Constant::getNullValue(Op1->getType()));
       break;
     }
     case ICmpInst::ICMP_UGT: {
       if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        // A >u C -> A == C+1 if max(a)-1 == C
-        if (*CmpC == Op0Max - 1)
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-        // X >u C --> X != 0, if the number of zero bits in the bottom of X
-        // exceeds the log2 of C.
-        if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
-          return new ICmpInst(ICmpInst::ICMP_NE, Op0,
-                              Constant::getNullValue(Op1->getType()));
-      }
+      // A >u C -> A == C+1 if max(a)-1 == C
+      if (match(Op1, m_SpecificInt(Op0Max - 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Max));
+      // X >u C --> X != 0, if the number of zero bits in the bottom of X
+      // exceeds the log2 of C.
+      if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+                  return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+                })))
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+                            Constant::getNullValue(Op1->getType()));
       break;
     }
     case ICmpInst::ICMP_SLT: {
       if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC - 1));
-      }
+      // A <s C -> A == C-1 if min(A)+1 == C
+      if (match(Op1, m_SpecificInt(Op0Min + 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Min));
       break;
     }
     case ICmpInst::ICMP_SGT: {
       if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
         return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-      const APInt *CmpC;
-      if (match(Op1, m_APInt(CmpC))) {
-        if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
-          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                              ConstantInt::get(Op1->getType(), *CmpC + 1));
-      }
+      // A >s C -> A == C+1 if max(A)-1 == C
+      if (match(Op1, m_SpecificInt(Op0Max - 1)))
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), Op0Max));
       break;
     }
     }
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
 
 define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
 ; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT:    [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT:    [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT:    [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT:    [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT:    [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
 ; CHECK-NEXT:    ret <3 x i1> [[T4]]
 ;
   %t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>



More information about the llvm-commits mailing list