[llvm] [IR][PatternMatch] Make `m_Checked{Int,Fp}` accept `Constant *` output instead of `APInt *` (PR #91377)

via llvm-commits llvm-commits at lists.llvm.org
Tue May 7 11:33:07 PDT 2024


https://github.com/goldsteinn created https://github.com/llvm/llvm-project/pull/91377

The `APInt *` version is of little use: in any case where one needs an
`APInt *` out, one could simply replace the `m_Checked...` lambda with
direct checks on the `APInt`.

Other helpers such as `m_Negative`, `m_Power2`, etc. are left
unchanged, as their `APInt` out versions are used mostly for
convenience and would rarely change functionality if converted to
output a `Constant *`.


>From 98fe785cbccd0b2feae55bfd3211b3582afe42e8 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Tue, 7 May 2024 11:58:14 -0500
Subject: [PATCH] [IR][PatternMatch] Make `m_Checked{Int,Fp}` accept `Constant
 *` output instead of `APInt *`

The `APInt *` version is of little use: in any case where one needs an
`APInt *` out, one could simply replace the `m_Checked...` lambda with
direct checks on the `APInt`.

Other helpers such as `m_Negative`, `m_Power2`, etc. are left
unchanged, as their `APInt` out versions are used mostly for
convenience and would rarely change functionality if converted to
output a `Constant *`.
---
 llvm/include/llvm/IR/PatternMatch.h | 32 ++++++-----
 llvm/unittests/IR/PatternMatch.cpp  | 89 ++++++++++++++++++-----------
 2 files changed, 74 insertions(+), 47 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 5d8f5c134bb5b5..171ddab977dea2 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -354,7 +354,8 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
 /// is true.
 template <typename Predicate, typename ConstantVal, bool AllowPoison>
 struct cstval_pred_ty : public Predicate {
-  template <typename ITy> bool match(ITy *V) {
+  const Constant **Res = nullptr;
+  template <typename ITy> bool match_impl(ITy *V) {
     if (const auto *CV = dyn_cast<ConstantVal>(V))
       return this->isValue(CV->getValue());
     if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
@@ -387,6 +388,15 @@ struct cstval_pred_ty : public Predicate {
     }
     return false;
   }
+
+  template <typename ITy> bool match(ITy *V) {
+    if (this->match_impl(V)) {
+      if (Res)
+        *Res = cast<Constant>(V);
+      return true;
+    }
+    return false;
+  }
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
@@ -469,28 +479,24 @@ template <typename APTy> struct custom_checkfn {
 /// For vectors, poison elements are assumed to match.
 inline cst_pred_ty<custom_checkfn<APInt>>
 m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
-  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+  return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}};
 }
 
-inline api_pred_ty<custom_checkfn<APInt>>
-m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
-  api_pred_ty<custom_checkfn<APInt>> P(V);
-  P.CheckFn = CheckFn;
-  return P;
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedInt(const Constant *&V, function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>>{{CheckFn}, &V};
 }
 
 /// Match a float or vector where CheckFn(ele) for each element is true.
 /// For vectors, poison elements are assumed to match.
 inline cstfp_pred_ty<custom_checkfn<APFloat>>
 m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
-  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}};
 }
 
-inline apf_pred_ty<custom_checkfn<APFloat>>
-m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
-  apf_pred_ty<custom_checkfn<APFloat>> P(V);
-  P.CheckFn = CheckFn;
-  return P;
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFp(const Constant *&V, function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{{CheckFn}, &V};
 }
 
 struct is_any_apint {
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index d5a4a6a05687d0..6e79d5cd8ed13b 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -614,7 +614,7 @@ TEST_F(PatternMatchTest, BitCast) {
 TEST_F(PatternMatchTest, CheckedInt) {
   Type *I8Ty = IRB.getInt8Ty();
   const APInt *Res = nullptr;
-
+  const Constant *CRes = nullptr;
   auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
   auto CheckTrue = [](const APInt &) { return true; };
   auto CheckFalse = [](const APInt &) { return false; };
@@ -625,39 +625,49 @@ TEST_F(PatternMatchTest, CheckedInt) {
     APInt APVal(8, Val);
     Constant *C = ConstantInt::get(I8Ty, Val);
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
-    EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+    EXPECT_TRUE(m_CheckedInt(CRes, CheckTrue).match(C));
+    EXPECT_NE(CRes, nullptr);
+    EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
     EXPECT_EQ(*Res, APVal);
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
-    EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+    EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(C));
+    EXPECT_EQ(CRes, nullptr);
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
-    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CRes, CheckUgt1).match(C));
     if (CheckUgt1(APVal)) {
-      EXPECT_NE(Res, nullptr);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
       EXPECT_EQ(*Res, APVal);
     }
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
-    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CRes, CheckNonZero).match(C));
     if (CheckNonZero(APVal)) {
-      EXPECT_NE(Res, nullptr);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
       EXPECT_EQ(*Res, APVal);
     }
 
+    CRes = nullptr;
     Res = nullptr;
     EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
-    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CRes, CheckPow2).match(C));
     if (CheckPow2(APVal)) {
-      EXPECT_NE(Res, nullptr);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_TRUE(match(CRes, m_APIntAllowPoison(Res)));
       EXPECT_EQ(*Res, APVal);
     }
-
   };
 
   DoScalarCheck(0);
@@ -666,20 +676,20 @@ TEST_F(PatternMatchTest, CheckedInt) {
   DoScalarCheck(3);
 
   EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
-  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
-  EXPECT_EQ(Res, nullptr);
+  EXPECT_FALSE(m_CheckedInt(CRes, CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(CRes, nullptr);
 
   auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
                             function_ref<bool(const APInt &)> CheckFn,
@@ -711,13 +721,16 @@ TEST_F(PatternMatchTest, CheckedInt) {
     EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
               m_CheckedInt(CheckFn).match(C));
 
+    CRes = nullptr;
     Res = nullptr;
     bool Expec =
-        !(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
-    EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+        !(HasUndef && !UndefAsPoison) && Okay.value_or(false);
+    EXPECT_EQ(Expec, m_CheckedInt(CRes, CheckFn).match(C));
     if (Expec) {
-      EXPECT_NE(Res, nullptr);
-      EXPECT_EQ(*Res, *First);
+      EXPECT_NE(CRes, nullptr);
+      EXPECT_EQ(match(CRes, m_APIntAllowPoison(Res)), AllSame);
+      if (AllSame)
+        EXPECT_EQ(*Res, *First);
     }
   };
   auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
@@ -1559,24 +1572,25 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
 
   const APFloat *C;
+  const Constant *CC;
   // Regardless of whether poison is allowed,
   // a fully undef/poison constant does not match.
   EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
   EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(CC, CheckTrue)));
   EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(CC, CheckTrue)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(CC, CheckTrue)));
   EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
   EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
-  EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_FALSE(match(VectorPoison, m_CheckedFp(CC, CheckTrue)));
 
   // We can always match simple constants and simple splats.
   C = nullptr;
@@ -1597,12 +1611,13 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
-  C = nullptr;
-  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
-  EXPECT_TRUE(C->isZero());
-  C = nullptr;
-  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
-  EXPECT_TRUE(C->isZero());
+
+  CC = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckTrue)));
+  EXPECT_TRUE(CC->isNullValue());
+  CC = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(CC, CheckNonNaN)));
+  EXPECT_TRUE(CC->isNullValue());
 
   // Splats with undef are never allowed.
   // Whether splats with poison can be matched depends on the matcher.
@@ -1627,11 +1642,17 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
   EXPECT_TRUE(C->isZero());
+  CC = nullptr;
   C = nullptr;
-  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckTrue)));
+  EXPECT_NE(CC, nullptr);
+  EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
+  CC = nullptr;
   C = nullptr;
-  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CC, CheckNonNaN)));
+  EXPECT_NE(CC, nullptr);
+  EXPECT_TRUE(match(CC, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
 }
 



More information about the llvm-commits mailing list