[llvm] [PatternMatching] Add generic API for matching constants using custom conditions (PR #85676)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 25 14:29:34 PDT 2024


https://github.com/goldsteinn updated https://github.com/llvm/llvm-project/pull/85676

>From ffdcb0d1c09140e85070172e02ca438f485adf50 Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 18 Mar 2024 13:00:14 -0500
Subject: [PATCH 1/3] [PatternMatching] Add generic API for matching constants
 using custom conditions

The new API is:
    `m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
        - Matches non-undef constants such that `Lambda(ele)` is true for all
          elements.
    `m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
        - Matches constants/undef such that `Lambda(ele)` is true for all
          elements.

The goal with these is to be able to replace the common usage of:
```
    match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
    match(X, m_CheckedInt(C, CustomChecks));
```

The rationale is that we often ignore non-splat vectors because there are
no good APIs to handle them, and it's not worth increasing code
complexity for such cases.

The hope is that the API provides a common way of handling scalars,
splat vectors, and non-splat vectors, essentially making this a
non-issue.
---
 llvm/include/llvm/IR/PatternMatch.h |  33 ++++
 llvm/unittests/IR/PatternMatch.cpp  | 255 ++++++++++++++++++++++++++++
 2 files changed, 288 insertions(+)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 0b13b4aad9c326..6c9fbba4755947 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -460,6 +460,49 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
 //
 ///////////////////////////////////////////////////////////////////////////////
 
+/// Predicate adaptor that forwards isValue() to a caller-supplied callback so
+/// the generic constant matchers can test an arbitrary condition. CheckFn is a
+/// non-owning function_ref: the callable must outlive every use of a matcher
+/// built from it (e.g. create the matcher inline in the match() call).
+template <typename APTy> struct custom_checkfn {
+  function_ref<bool(const APTy &)> CheckFn;
+  bool isValue(const APTy &C) { return CheckFn(C); }
+};
+
+/// Match an integer or integer vector where CheckFn(ele) is true for every
+/// element. For vectors, poison elements are assumed to match, undef elements
+/// never match, and a vector whose elements are all poison does not match.
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+  return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+
+/// Like m_CheckedInt(CheckFn), but only matches scalars and splat vectors
+/// (poison lanes allowed) and binds V to the scalar/splat value on success.
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+  api_pred_ty<custom_checkfn<APInt>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
+/// Match a float or float vector where CheckFn(ele) is true for every
+/// element. For vectors, poison elements are assumed to match, undef elements
+/// never match, and a vector whose elements are all poison does not match.
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+  return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+
+/// Like m_CheckedFp(CheckFn), but only matches scalars and splat vectors
+/// (poison lanes allowed) and binds V to the scalar/splat value on success.
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+  apf_pred_ty<custom_checkfn<APFloat>> P(V);
+  P.CheckFn = CheckFn;
+  return P;
+}
+
 struct is_any_apint {
   bool isValue(const APInt &C) { return true; }
 };
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index a25885faa3a442..6b863f10a64482 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -611,6 +611,212 @@ TEST_F(PatternMatchTest, BitCast) {
   EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
 }
 
+TEST_F(PatternMatchTest, CustomCheckFn) {
+  APInt I0(64, 0);
+  APInt I1(64, 0);
+  // The I1-based predicates capture I1 by reference and see later updates.
+  auto CheckIsZeroI = [](const APInt &C) { return C.isZero(); };
+  auto CheckIsEqI1 = [&I1](const APInt &C) { return C.eq(I1); };
+  auto CheckIsNeI1 = [&I1](const APInt &C) { return !C.eq(I1); };
+  // Wrap each predicate in a custom_checkfn and query isValue() directly.
+  custom_checkfn<APInt> CustomCheckZeroI;
+  CustomCheckZeroI.CheckFn = CheckIsZeroI;
+  custom_checkfn<APInt> CustomCheckEqI1;
+  CustomCheckEqI1.CheckFn = CheckIsEqI1;
+  custom_checkfn<APInt> CustomCheckNeI1;
+  CustomCheckNeI1.CheckFn = CheckIsNeI1;
+  // I0 == 0, I1 == 0.
+  EXPECT_TRUE(CustomCheckZeroI.isValue(I0));
+  EXPECT_TRUE(CustomCheckEqI1.isValue(I0));
+  EXPECT_FALSE(CustomCheckNeI1.isValue(I0));
+
+  I0.setBit(0);
+  // I0 == 1, I1 == 0.
+  EXPECT_FALSE(CustomCheckZeroI.isValue(I0));
+  EXPECT_FALSE(CustomCheckEqI1.isValue(I0));
+  EXPECT_TRUE(CustomCheckNeI1.isValue(I0));
+
+  I1.setBit(0);
+  // I0 == 1, I1 == 1.
+  EXPECT_FALSE(CustomCheckZeroI.isValue(I0));
+  EXPECT_TRUE(CustomCheckEqI1.isValue(I0));
+  EXPECT_FALSE(CustomCheckNeI1.isValue(I0));
+
+  APFloat F0(0.0);
+  APFloat F1(0.0);
+  // The F1-based predicates compare bit patterns, so -0.0 != +0.0 here.
+  auto CheckIsZeroF = [](const APFloat &C) { return C.isZero(); };
+  auto CheckIsEqF1 = [&F1](const APFloat &C) {
+    return C.bitcastToAPInt().eq(F1.bitcastToAPInt());
+  };
+  auto CheckIsNeF1 = [&F1](const APFloat &C) {
+    return !C.bitcastToAPInt().eq(F1.bitcastToAPInt());
+  };
+  // Wrap each predicate in a custom_checkfn<APFloat>.
+  custom_checkfn<APFloat> CustomCheckZeroF;
+  CustomCheckZeroF.CheckFn = CheckIsZeroF;
+  custom_checkfn<APFloat> CustomCheckEqF1;
+  CustomCheckEqF1.CheckFn = CheckIsEqF1;
+  custom_checkfn<APFloat> CustomCheckNeF1;
+  CustomCheckNeF1.CheckFn = CheckIsNeF1;
+  // F0 == +0.0, F1 == +0.0.
+  EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
+  EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
+  EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
+
+  F0 = -F0;
+  // F0 == -0.0: still zero, but its bit pattern differs from F1 (+0.0).
+  EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
+  EXPECT_FALSE(CustomCheckEqF1.isValue(F0));
+  EXPECT_TRUE(CustomCheckNeF1.isValue(F0));
+
+  F0 = -F0;
+  // F0 == +0.0 again.
+  EXPECT_TRUE(CustomCheckZeroF.isValue(F0));
+  EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
+  EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
+
+  F0 = F0 + APFloat(1.0);
+  // F0 == 1.0, F1 == +0.0.
+  EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
+  EXPECT_FALSE(CustomCheckEqF1.isValue(F0));
+  EXPECT_TRUE(CustomCheckNeF1.isValue(F0));
+
+  F1 = F1 + APFloat(1.0);
+  // F0 == 1.0, F1 == 1.0.
+  EXPECT_FALSE(CustomCheckZeroF.isValue(F0));
+  EXPECT_TRUE(CustomCheckEqF1.isValue(F0));
+  EXPECT_FALSE(CustomCheckNeF1.isValue(F0));
+}
+
+TEST_F(PatternMatchTest, CheckedInt) {
+  Type *I8Ty = IRB.getInt8Ty();
+  const APInt *Res = nullptr;
+  // Predicates exercised against both scalar and vector constants below.
+  auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+  auto CheckTrue = [](const APInt &) { return true; };
+  auto CheckFalse = [](const APInt &) { return false; };
+  auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+  auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+  // Scalars: both matcher forms must agree with calling the predicate.
+  auto DoScalarCheck = [&](int8_t Val) {
+    APInt APVal(8, Val);
+    Constant *C = ConstantInt::get(I8Ty, Val);
+
+    Res = nullptr;
+    EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+    EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+    EXPECT_EQ(*Res, APVal);
+
+    Res = nullptr;
+    EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+    EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+    Res = nullptr;
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+    EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+    if (CheckUgt1(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+    EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+    if (CheckNonZero(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+
+    Res = nullptr;
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+    EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+    if (CheckPow2(APVal)) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, APVal);
+    }
+  };
+
+  DoScalarCheck(0);
+  DoScalarCheck(1);
+  DoScalarCheck(2);
+  DoScalarCheck(3);
+  // Whole-value undef/poison never matches, whatever the predicate says.
+  Res = nullptr; // Don't depend on what DoScalarCheck left behind.
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+
+  EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+  EXPECT_EQ(Res, nullptr);
+  // Vectors: each nullopt in Vals becomes an undef or poison lane.
+  auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+                            function_ref<bool(const APInt &)> CheckFn,
+                            bool UndefAsPoison) {
+    SmallVector<Constant *> VecElems;
+    std::optional<bool> Okay;
+    bool AllSame = true;
+    bool HasUndef = false;
+    std::optional<APInt> First;
+    for (const std::optional<int8_t> &Val : Vals) {
+      if (!Val.has_value()) {
+        VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+                                         : UndefValue::get(I8Ty));
+        HasUndef = true;
+      } else {
+        if (!Okay.has_value())
+          Okay = true;
+        APInt APVal(8, *Val);
+        if (!First.has_value())
+          First = APVal;
+        else
+          AllSame &= First->eq(APVal);
+        Okay = *Okay && CheckFn(APVal);
+        VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+      }
+    }
+    // Undef lanes prevent a match; poison lanes are ignored.
+    Constant *C = ConstantVector::get(VecElems);
+    EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false),
+              m_CheckedInt(CheckFn).match(C));
+    // The binding form additionally requires a splat (ignoring poison lanes).
+    Res = nullptr;
+    bool ShouldMatch =
+        !(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false);
+    EXPECT_EQ(ShouldMatch, m_CheckedInt(Res, CheckFn).match(C));
+    if (ShouldMatch) {
+      EXPECT_NE(Res, nullptr);
+      EXPECT_EQ(*Res, *First);
+    }
+  };
+  auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) {
+    DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true);
+    DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true);
+    DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false);
+    DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false);
+  };
+  // Splat, non-splat, partially undef/poison and fully undef/poison vectors.
+  DoVecCheck({0, 1});
+  DoVecCheck({1, 1});
+  DoVecCheck({1, 2});
+  DoVecCheck({1, std::nullopt});
+  DoVecCheck({1, std::nullopt, 1});
+  DoVecCheck({1, std::nullopt, 2});
+  DoVecCheck({std::nullopt, std::nullopt, std::nullopt});
+}
+
 TEST_F(PatternMatchTest, Power2) {
   Value *C128 = IRB.getInt32(128);
   Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128));
@@ -1397,21 +1603,58 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   EXPECT_FALSE(match(VectorInfPoison, m_Finite()));
   EXPECT_FALSE(match(VectorNaNPoison, m_Finite()));
 
+  auto CheckTrue = [](const APFloat &) { return true; };
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckTrue)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue)));
+  EXPECT_TRUE(match(VectorNaNPoison, m_CheckedFp(CheckTrue)));
+
+  auto CheckFalse = [](const APFloat &) { return false; };
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorZeroPoison, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorInfPoison, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse)));
+  EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckFalse)));
+
+  auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); };
+  EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN)));
+  EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN)));
+
   const APFloat *C;
   // Regardless of whether poison is allowed,
   // a fully undef/poison constant does not match.
   EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
   EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue)));
   EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C)));
+  EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloat(C)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C)));
+  EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue)));
   EXPECT_FALSE(match(VectorPoison, m_APFloat(C)));
   EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C)));
   EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C)));
+  EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue)));
 
   // We can always match simple constants and simple splats.
   C = nullptr;
@@ -1432,6 +1675,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C)));
   EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
 
   // Splats with undef are never allowed.
   // Whether splats with poison can be matched depends on the matcher.
@@ -1456,6 +1705,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) {
   C = nullptr;
   EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C)));
   EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN)));
+  EXPECT_TRUE(C->isZero());
 }
 
 TEST_F(PatternMatchTest, FloatingPointFNeg) {

>From c0c1049df4868cc0c289e6aeb926dc6c34c4ad5d Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 22 Apr 2024 16:01:03 -0500
Subject: [PATCH 2/3] [InstCombine] Add non-splat test for `(icmp (lshr x, y),
 x)`; NFC

---
 .../InstCombine/icmp-div-constant.ll          | 27 +++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
index b047715432d779..cf1746d85be8a1 100644
--- a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
@@ -118,8 +118,8 @@ define i32 @icmp_div(i16 %a, i16 %c) {
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
 ; CHECK-NEXT:    br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]
@@ -173,8 +173,8 @@ define i32 @icmp_div3(i16 %a, i16 %c) {
 ; CHECK-NEXT:    [[TOBOOL:%.*]] = icmp eq i16 [[A:%.*]], 0
 ; CHECK-NEXT:    br i1 [[TOBOOL]], label [[THEN:%.*]], label [[EXIT:%.*]]
 ; CHECK:       then:
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i16 [[C:%.*]], 0
-; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP]] to i32
+; CHECK-NEXT:    [[CMP_NOT:%.*]] = icmp eq i16 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i1 [[CMP_NOT]] to i32
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ -1, [[ENTRY:%.*]] ], [ [[TMP0]], [[THEN]] ]
@@ -381,8 +381,8 @@ define i1 @sdiv_eq_smin_use(i32 %x, i32 %y) {
 
 define i1 @sdiv_x_by_const_cmp_x(i32 %x) {
 ; CHECK-LABEL: @sdiv_x_by_const_cmp_x(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT:    ret i1 [[R]]
 ;
   %v = sdiv i32 %x, 13
   %r = icmp eq i32 %v, %x
@@ -403,8 +403,8 @@ define i1 @udiv_x_by_const_cmp_x(i32 %x) {
 
 define i1 @lshr_x_by_const_cmp_x(i32 %x) {
 ; CHECK-LABEL: @lshr_x_by_const_cmp_x(
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT:    ret i1 [[TMP1]]
+; CHECK-NEXT:    [[R:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT:    ret i1 [[R]]
 ;
   %v = lshr i32 %x, 1
   %r = icmp eq i32 %v, %x
@@ -421,6 +421,17 @@ define <4 x i1> @lshr_by_const_cmp_sle_value(<4 x i32> %x) {
   ret <4 x i1> %r
 }
 
+define <4 x i1> @lshr_by_const_cmp_sle_value_non_splat(<4 x i32> %x) {
+; CHECK-LABEL: @lshr_by_const_cmp_sle_value_non_splat(
+; CHECK-NEXT:    [[V:%.*]] = lshr <4 x i32> [[X:%.*]], <i32 3, i32 3, i32 3, i32 5>
+; CHECK-NEXT:    [[R:%.*]] = icmp sle <4 x i32> [[V]], [[X]]
+; CHECK-NEXT:    ret <4 x i1> [[R]]
+;
+  %v = lshr <4 x i32> %x, <i32 3, i32 3, i32 3, i32 5>
+  %r = icmp sle <4 x i32> %v, %x
+  ret <4 x i1> %r
+}
+
 define i1 @lshr_by_const_cmp_sge_value(i32 %x) {
 ; CHECK-LABEL: @lshr_by_const_cmp_sge_value(
 ; CHECK-NEXT:    [[R:%.*]] = icmp slt i32 [[X:%.*]], 1

>From e4b9b1ed82fefb80d9dba7e1998a4c031449599f Mon Sep 17 00:00:00 2001
From: Noah Goldstein <goldstein.w.n at gmail.com>
Date: Mon, 22 Apr 2024 16:01:11 -0500
Subject: [PATCH 3/3] [InstCombine] Add example usage for new Checked matcher
 API

There is no real motivation for this change other than to highlight a
case where the new `Checked` matcher API can handle non-splat vectors
without increasing code complexity.
---
 .../InstCombine/InstCombineCompares.cpp          | 16 ++++++----------
 .../Transforms/InstCombine/icmp-div-constant.ll  |  3 +--
 2 files changed, 7 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 1064340cb53661..794d29a992a76a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7123,34 +7123,30 @@ Instruction *InstCombinerImpl::foldICmpCommutative(ICmpInst::Predicate Pred,
     return replaceInstUsesWith(CxtI, V);
 
   // Folding (X / Y) pred X => X swap(pred) 0 for constant Y other than 0 or 1
+  auto CheckUGT1 = [](const APInt &Divisor) { return Divisor.ugt(1); };
   {
-    const APInt *Divisor;
-    if (match(Op0, m_UDiv(m_Specific(Op1), m_APInt(Divisor))) &&
-        Divisor->ugt(1)) {
+    if (match(Op0, m_UDiv(m_Specific(Op1), m_CheckedInt(CheckUGT1)))) {
       return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1,
                           Constant::getNullValue(Op1->getType()));
     }
 
     if (!ICmpInst::isUnsigned(Pred) &&
-        match(Op0, m_SDiv(m_Specific(Op1), m_APInt(Divisor))) &&
-        Divisor->ugt(1)) {
+        match(Op0, m_SDiv(m_Specific(Op1), m_CheckedInt(CheckUGT1)))) {
       return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1,
                           Constant::getNullValue(Op1->getType()));
     }
   }
 
   // Another case of this fold is (X >> Y) pred X => X swap(pred) 0 if Y != 0
+  auto CheckNE0 = [](const APInt &Shift) { return !Shift.isZero(); };
   {
-    const APInt *Shift;
-    if (match(Op0, m_LShr(m_Specific(Op1), m_APInt(Shift))) &&
-        !Shift->isZero()) {
+    if (match(Op0, m_LShr(m_Specific(Op1), m_CheckedInt(CheckNE0)))) {
       return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1,
                           Constant::getNullValue(Op1->getType()));
     }
 
     if ((Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_SGE) &&
-        match(Op0, m_AShr(m_Specific(Op1), m_APInt(Shift))) &&
-        !Shift->isZero()) {
+        match(Op0, m_AShr(m_Specific(Op1), m_CheckedInt(CheckNE0)))) {
       return new ICmpInst(ICmpInst::getSwappedPredicate(Pred), Op1,
                           Constant::getNullValue(Op1->getType()));
     }
diff --git a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
index cf1746d85be8a1..9d6cf164cbf78c 100644
--- a/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
+++ b/llvm/test/Transforms/InstCombine/icmp-div-constant.ll
@@ -423,8 +423,7 @@ define <4 x i1> @lshr_by_const_cmp_sle_value(<4 x i32> %x) {
 
 define <4 x i1> @lshr_by_const_cmp_sle_value_non_splat(<4 x i32> %x) {
 ; CHECK-LABEL: @lshr_by_const_cmp_sle_value_non_splat(
-; CHECK-NEXT:    [[V:%.*]] = lshr <4 x i32> [[X:%.*]], <i32 3, i32 3, i32 3, i32 5>
-; CHECK-NEXT:    [[R:%.*]] = icmp sle <4 x i32> [[V]], [[X]]
+; CHECK-NEXT:    [[R:%.*]] = icmp sgt <4 x i32> [[X:%.*]], <i32 -1, i32 -1, i32 -1, i32 -1>
 ; CHECK-NEXT:    ret <4 x i1> [[R]]
 ;
   %v = lshr <4 x i32> %x, <i32 3, i32 3, i32 3, i32 5>



More information about the llvm-commits mailing list