[llvm] [PatternMatching] Add generic API for matching constants using custom conditions (PR #85676)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 18 11:15:26 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-ir
Author: None (goldsteinn)
<details>
<summary>Changes</summary>
- **[PatternMatching] Add generic API for matching constants using custom conditions**
- **[InstCombine] Add example usage for new `Checked` matcher API**
---
Patch is 26.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85676.diff
4 Files Affected:
- (modified) llvm/include/llvm/IR/PatternMatch.h (+80-11)
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+30-36)
- (modified) llvm/test/Transforms/InstCombine/signed-truncation-check.ll (+6-24)
- (modified) llvm/unittests/IR/PatternMatch.cpp (+240)
``````````diff
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 382009d9df785d..4333d3e6e8da2a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -346,7 +346,7 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
/// This helper class is used to match constant scalars, vector splats,
/// and fixed width vectors that satisfy a specified predicate.
/// For fixed width vector constants, undefined elements are ignored.
-template <typename Predicate, typename ConstantVal>
+template <typename Predicate, typename ConstantVal, bool AllowUndefs>
struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -369,8 +369,11 @@ struct cstval_pred_ty : public Predicate {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
return false;
- if (isa<UndefValue>(Elt))
+ if (isa<UndefValue>(Elt)) {
+ if (!AllowUndefs)
+ return false;
continue;
+ }
auto *CV = dyn_cast<ConstantVal>(Elt);
if (!CV || !this->isValue(CV->getValue()))
return false;
@@ -384,16 +387,17 @@ struct cstval_pred_ty : public Predicate {
};
/// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowUndefs = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowUndefs>;
/// specialization of cstval_pred_ty for ConstantFP
-template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+template <typename Predicate, bool AllowUndefs = true>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP, AllowUndefs>;
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
-template <typename Predicate> struct api_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct api_pred_ty : public Predicate {
const APInt *&Res;
api_pred_ty(const APInt *&R) : Res(R) {}
@@ -406,7 +410,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -419,7 +424,8 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APFloat.
/// Undefs are allowed in splat vector constants.
-template <typename Predicate> struct apf_pred_ty : public Predicate {
+template <typename Predicate, bool AllowUndefs = true>
+struct apf_pred_ty : public Predicate {
const APFloat *&Res;
apf_pred_ty(const APFloat *&R) : Res(R) {}
@@ -432,8 +438,8 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
}
if (V->getType()->isVectorTy())
if (const auto *C = dyn_cast<Constant>(V))
- if (auto *CI = dyn_cast_or_null<ConstantFP>(
- C->getSplatValue(/* AllowUndef */ true)))
+ if (auto *CI =
+ dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndefs)))
if (this->isValue(CI->getValue())) {
Res = &CI->getValue();
return true;
@@ -452,6 +458,69 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
//
///////////////////////////////////////////////////////////////////////////////
+template <typename APTy> struct custom_checkfn {
+ function_ref<bool(const APTy &)> CheckFn; // Non-owning: must outlive matcher
+ bool isValue(const APTy &C) { return CheckFn(C); } // Called by *_pred_ty.
+};
+
+// Match an integer constant or vector where CheckFn holds for every element.
+// For vectors, undef (including poison) elements cause the match to fail.
+inline cst_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>, false>{CheckFn};
+}
+// As above, but vectors must be splats (no undef lanes); binds V on success.
+inline api_pred_ty<custom_checkfn<APInt>, false>
+m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match an integer constant or vector where CheckFn holds for every element.
+// For vectors, undef (including poison) elements are skipped (they match).
+inline cst_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(function_ref<bool(const APInt &)> CheckFn) {
+ return cst_pred_ty<custom_checkfn<APInt>>{CheckFn};
+}
+// As above, but vectors must be splats (undef lanes allowed); binds V.
+inline api_pred_ty<custom_checkfn<APInt>>
+m_CheckedIntAllowUndef(const APInt *&V,
+ function_ref<bool(const APInt &)> CheckFn) {
+ api_pred_ty<custom_checkfn<APInt>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match a float constant or vector where CheckFn holds for every element.
+// For vectors, undef (including poison) elements cause the match to fail.
+inline cstfp_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>, false>{CheckFn};
+}
+// As above, but vectors must be splats (no undef lanes); binds V on success.
+inline apf_pred_ty<custom_checkfn<APFloat>, false>
+m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>, false> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
+// Match a float constant or vector where CheckFn holds for every element.
+// For vectors, undef (including poison) elements are skipped (they match).
+inline cstfp_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(function_ref<bool(const APFloat &)> CheckFn) {
+ return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn};
+}
+// As above, but vectors must be splats (undef lanes allowed); binds V.
+inline apf_pred_ty<custom_checkfn<APFloat>>
+m_CheckedFpAllowUndef(const APFloat *&V,
+ function_ref<bool(const APFloat &)> CheckFn) {
+ apf_pred_ty<custom_checkfn<APFloat>> P(V);
+ P.CheckFn = CheckFn;
+ return P;
+}
+
struct is_any_apint {
bool isValue(const APInt &C) { return true; }
};
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 0dce0077bf1588..711294e4635579 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6347,57 +6347,51 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
case ICmpInst::ICMP_ULT: {
if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.ceilLogBase2();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_UGT: {
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
+ // A >u C -> A == C+1 if max(a)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (match(Op1, m_CheckedInt([&Op0Known](const APInt &C) {
+ return Op0Known.countMinTrailingZeros() >= C.getActiveBits();
+ })))
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
break;
}
case ICmpInst::ICMP_SLT: {
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
+ // A <s C -> A == C-1 if min(A)+1 == C
+ if (match(Op1, m_SpecificInt(Op0Min + 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Min));
break;
}
case ICmpInst::ICMP_SGT: {
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
+ // A >s C -> A == C+1 if max(A)-1 == C
+ if (match(Op1, m_SpecificInt(Op0Max - 1)))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), Op0Max));
break;
}
}
diff --git a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
index 208e166b2c8760..465235fb08d383 100644
--- a/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
+++ b/llvm/test/Transforms/InstCombine/signed-truncation-check.ll
@@ -212,10 +212,7 @@ define <3 x i1> @positive_vec_undef0(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef1(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -227,10 +224,7 @@ define <3 x i1> @positive_vec_undef1(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef2(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -242,10 +236,7 @@ define <3 x i1> @positive_vec_undef2(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef3(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 256, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -257,10 +248,7 @@ define <3 x i1> @positive_vec_undef3(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef4(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 128, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
@@ -272,10 +260,7 @@ define <3 x i1> @positive_vec_undef4(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef5(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 -1, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 -1, i32 -1>
@@ -287,10 +272,7 @@ define <3 x i1> @positive_vec_undef5(<3 x i32> %arg) {
define <3 x i1> @positive_vec_undef6(<3 x i32> %arg) {
; CHECK-LABEL: @positive_vec_undef6(
-; CHECK-NEXT: [[T1:%.*]] = icmp sgt <3 x i32> [[ARG:%.*]], <i32 -1, i32 undef, i32 -1>
-; CHECK-NEXT: [[T2:%.*]] = add <3 x i32> [[ARG]], <i32 128, i32 undef, i32 128>
-; CHECK-NEXT: [[T3:%.*]] = icmp ult <3 x i32> [[T2]], <i32 256, i32 undef, i32 256>
-; CHECK-NEXT: [[T4:%.*]] = and <3 x i1> [[T1]], [[T3]]
+; CHECK-NEXT: [[T4:%.*]] = icmp ult <3 x i32> [[ARG:%.*]], <i32 128, i32 128, i32 128>
; CHECK-NEXT: ret <3 x i1> [[T4]]
;
%t1 = icmp sgt <3 x i32> %arg, <i32 -1, i32 undef, i32 -1>
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index 533a30bfba45dd..de361c70804c3e 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -572,6 +572,169 @@ TEST_F(PatternMatchTest, BitCast) {
EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32));
}
+TEST_F(PatternMatchTest, CheckedInt) {
+ Type *I8Ty = IRB.getInt8Ty();
+ const APInt *Res = nullptr;
+
+ auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); };
+ auto CheckTrue = [](const APInt &) { return true; };
+ auto CheckFalse = [](const APInt &) { return false; };
+ auto CheckNonZero = [](const APInt &C) { return !C.isZero(); };
+ auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); };
+
+ auto DoScalarCheck = [&](int8_t Val) {
+ APInt APVal(8, Val);
+ Constant *C = ConstantInt::get(I8Ty, Val);
+
+ Res = nullptr;
+ EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C));
+ EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C));
+ EXPECT_EQ(*Res, APVal);
+
+ Res = nullptr;
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C));
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckUgt1(APVal), m_CheckedIntAllowUndef(CheckUgt1).match(C));
+ EXPECT_EQ(CheckUgt1(APVal),
+ m_CheckedIntAllowUndef(Res, CheckUgt1).match(C));
+ if (CheckUgt1(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(CheckNonZero).match(C));
+ EXPECT_EQ(CheckNonZero(APVal),
+ m_CheckedIntAllowUndef(Res, CheckNonZero).match(C));
+ if (CheckNonZero(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+
+ Res = nullptr;
+ EXPECT_EQ(CheckPow2(APVal), m_CheckedIntAllowUndef(CheckPow2).match(C));
+ EXPECT_EQ(CheckPow2(APVal),
+ m_CheckedIntAllowUndef(Res, CheckPow2).match(C));
+ if (CheckPow2(APVal)) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, APVal);
+ }
+ };
+
+ DoScalarCheck(0);
+ DoScalarCheck(1);
+ DoScalarCheck(2);
+ DoScalarCheck(3);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty)));
+ EXPECT_EQ(Res, nullptr);
+
+ auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals,
+ function_ref<bool(const APInt &)> CheckFn,
+ bool UndefAsPoison) {
+ SmallVector<Constant *> VecElems;
+ std::optional<bool> Okay;
+ bool AllSame = true;
+ bool HasUndef = false;
+ std::optional<APInt> First;
+ for (const std::optional<int8_t> &Val : Vals) {
+ if (!Val.has_value()) {
+ VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty)
+ : UndefValue::get(I8Ty));
+ HasUndef = true;
+ } else {
+ if (!Okay.has_value())
+ Okay = true;
+ APInt APVal(8, *Val);
+ if (!First.has_value())
+ First = APVal;
+ else
+ AllSame &= First->eq(APVal);
+ Okay = *Okay && CheckFn(APVal);
+ VecElems.push_back(ConstantInt::get(I8Ty, *Val));
+ }
+ }
+
+ Constant *C = ConstantVector::get(VecElems);
+ EXPECT_EQ(!HasUndef && Okay.value_or(false),
+ m_CheckedInt(CheckFn).match(C));
+ EXPECT_EQ(Okay.value_or(false), m_CheckedIntAllowUndef(CheckFn).match(C));
+
+ Res = nullptr;
+ bool Expec = !HasUndef && AllSame && Okay.value_or(false);
+ EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C));
+ if (Expec) {
+ EXPECT_NE(Res, nullptr);
+ EXPECT_EQ(*Res, *First);
+ }
+
+ Res = nullptr;
+ Expec = AllSame && Okay.value_or(f...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/85676
More information about the llvm-commits mailing list