[llvm] 9c87c55 - [SVE] Make cstfp_pred_ty and cst_pred_ty work with scalable splats
Christopher Tetreault via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 14 14:20:54 PDT 2020
Author: Christopher Tetreault
Date: 2020-07-14T14:20:39-07:00
New Revision: 9c87c5580575cefdebb02cc6685fb6b66fb375c9
URL: https://github.com/llvm/llvm-project/commit/9c87c5580575cefdebb02cc6685fb6b66fb375c9
DIFF: https://github.com/llvm/llvm-project/commit/9c87c5580575cefdebb02cc6685fb6b66fb375c9.diff
LOG: [SVE] Make cstfp_pred_ty and cst_pred_ty work with scalable splats
Reviewers: efriedma, lebedev.ri, fhahn, c-rhodes, david-arm
Reviewed By: efriedma, david-arm
Subscribers: tschuett, rkruppe, psnobl, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D83001
Added:
Modified:
llvm/include/llvm/IR/Constants.h
llvm/include/llvm/IR/PatternMatch.h
llvm/test/Transforms/InstCombine/fmul.ll
llvm/test/Transforms/InstCombine/mul.ll
llvm/unittests/IR/PatternMatch.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 3579c9f1ee33..8e2dba9b2417 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -308,6 +308,7 @@ class ConstantFP final : public ConstantData {
/// Return true if Ty is big enough to represent V.
static bool isValueValidForType(Type *Ty, const APFloat &V);
inline const APFloat &getValueAPF() const { return Val; }
+ inline const APFloat &getValue() const { return Val; }
/// Return true if the value is positive or negative zero.
bool isZero() const { return Val.isZero(); }
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 98182bc3d85d..4c11bc82510b 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -262,17 +262,23 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
return constantint_match<Val>();
}
-/// This helper class is used to match scalar and fixed width vector integer
-/// constants that satisfy a specified predicate.
-/// For vector constants, undefined elements are ignored.
-template <typename Predicate> struct cst_pred_ty : public Predicate {
+/// This helper class is used to match constant scalars, vector splats,
+/// and fixed width vectors that satisfy a specified predicate.
+/// For fixed width vector constants, undefined elements are ignored.
+template <typename Predicate, typename ConstantVal>
+struct cstval_pred_ty : public Predicate {
template <typename ITy> bool match(ITy *V) {
- if (const auto *CI = dyn_cast<ConstantInt>(V))
- return this->isValue(CI->getValue());
- if (const auto *FVTy = dyn_cast<FixedVectorType>(V->getType())) {
+ if (const auto *CV = dyn_cast<ConstantVal>(V))
+ return this->isValue(CV->getValue());
+ if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
if (const auto *C = dyn_cast<Constant>(V)) {
- if (const auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
- return this->isValue(CI->getValue());
+ if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue()))
+ return this->isValue(CV->getValue());
+
+ // Number of elements of a scalable vector is unknown at compile time.
+ auto *FVTy = dyn_cast<FixedVectorType>(VTy);
+ if (!FVTy)
+ return false;
// Non-splat vector constant: check each element for a match.
unsigned NumElts = FVTy->getNumElements();
@@ -284,8 +290,8 @@ template <typename Predicate> struct cst_pred_ty : public Predicate {
return false;
if (isa<UndefValue>(Elt))
continue;
- auto *CI = dyn_cast<ConstantInt>(Elt);
- if (!CI || !this->isValue(CI->getValue()))
+ auto *CV = dyn_cast<ConstantVal>(Elt);
+ if (!CV || !this->isValue(CV->getValue()))
return false;
HasNonUndefElements = true;
}
@@ -296,6 +302,14 @@ template <typename Predicate> struct cst_pred_ty : public Predicate {
}
};
+/// Alias template that fixes cstval_pred_ty's constant type to ConstantInt.
+template <typename Predicate>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+
+/// Alias template that fixes cstval_pred_ty's constant type to ConstantFP.
+template <typename Predicate>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+
/// This helper class is used to match scalar and vector constants that
/// satisfy a specified predicate, and bind them to an APInt.
template <typename Predicate> struct api_pred_ty : public Predicate {
@@ -321,44 +335,6 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
}
};
-/// This helper class is used to match scalar and vector floating-point
-/// constants that satisfy a specified predicate.
-/// For vector constants, undefined elements are ignored.
-template <typename Predicate> struct cstfp_pred_ty : public Predicate {
- template <typename ITy> bool match(ITy *V) {
- if (const auto *CF = dyn_cast<ConstantFP>(V))
- return this->isValue(CF->getValueAPF());
- if (V->getType()->isVectorTy()) {
- if (const auto *C = dyn_cast<Constant>(V)) {
- if (const auto *CF = dyn_cast_or_null<ConstantFP>(C->getSplatValue()))
- return this->isValue(CF->getValueAPF());
-
- // Number of elements of a scalable vector unknown at compile time
- if (isa<ScalableVectorType>(V->getType()))
- return false;
-
- // Non-splat vector constant: check each element for a match.
- unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
- assert(NumElts != 0 && "Constant vector with no elements?");
- bool HasNonUndefElements = false;
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *Elt = C->getAggregateElement(i);
- if (!Elt)
- return false;
- if (isa<UndefValue>(Elt))
- continue;
- auto *CF = dyn_cast<ConstantFP>(Elt);
- if (!CF || !this->isValue(CF->getValueAPF()))
- return false;
- HasNonUndefElements = true;
- }
- return HasNonUndefElements;
- }
- }
- return false;
- }
-};
-
///////////////////////////////////////////////////////////////////////////////
//
// Encapsulate constant value queries for use in templated predicate matchers.
diff --git a/llvm/test/Transforms/InstCombine/fmul.ll b/llvm/test/Transforms/InstCombine/fmul.ll
index 8e168f252978..4162973f0bed 100644
--- a/llvm/test/Transforms/InstCombine/fmul.ll
+++ b/llvm/test/Transforms/InstCombine/fmul.ll
@@ -1164,3 +1164,12 @@ define double @fmul_sqrt_select(double %x, i1 %c) {
%mul = fmul fast double %sqr, %sel
ret double %mul
}
+
+; fastmath => z * splat(0) = splat(0), even for scalable vectors
+define <vscale x 2 x float> @mul_scalable_splat_zero(<vscale x 2 x float> %z) {
+; CHECK-LABEL: @mul_scalable_splat_zero(
+; CHECK-NEXT: ret <vscale x 2 x float> zeroinitializer
+ %shuf = shufflevector <vscale x 2 x float> insertelement (<vscale x 2 x float> undef, float 0.0, i32 0), <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
+ %t3 = fmul fast <vscale x 2 x float> %shuf, %z
+ ret <vscale x 2 x float> %t3
+}
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index 059b18d30b90..d2844561ca7a 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -857,3 +857,12 @@ define <4 x i32> @combine_mul_nabs_v4i32(<4 x i32> %0) {
%m = mul <4 x i32> %r, %r
ret <4 x i32> %m
}
+
+; z * splat(0) = splat(0), even for scalable vectors
+define <vscale x 2 x i64> @mul_scalable_splat_zero(<vscale x 2 x i64> %z) {
+; CHECK-LABEL: @mul_scalable_splat_zero(
+; CHECK-NEXT: ret <vscale x 2 x i64> zeroinitializer
+ %shuf = shufflevector <vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 0, i32 0), <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+ %t3 = mul <vscale x 2 x i64> %shuf, %z
+ ret <vscale x 2 x i64> %t3
+}
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index bbcbd91c8f1f..34dc2767021a 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1325,6 +1325,183 @@ TEST_F(PatternMatchTest, IntrinsicMatcher) {
m_SpecificInt(10))));
}
+namespace {
+
+struct is_unsigned_zero_pred {
+ bool isValue(const APInt &C) { return C.isNullValue(); }
+};
+
+struct is_float_zero_pred {
+ bool isValue(const APFloat &C) { return C.isZero(); }
+};
+
+template <typename T> struct always_true_pred {
+ bool isValue(const T &) { return true; }
+};
+
+template <typename T> struct always_false_pred {
+ bool isValue(const T &) { return false; }
+};
+
+struct is_unsigned_max_pred {
+ bool isValue(const APInt &C) { return C.isMaxValue(); }
+};
+
+struct is_float_nan_pred {
+ bool isValue(const APFloat &C) { return C.isNaN(); }
+};
+
+} // namespace
+
+TEST_F(PatternMatchTest, ConstantPredicateType) {
+
+ // Scalar integer
+ APInt U32Max = APInt::getAllOnesValue(32);
+ APInt U32Zero = APInt::getNullValue(32);
+ APInt U32DeadBeef(32, 0xDEADBEEF);
+
+ Type *U32Ty = Type::getInt32Ty(Ctx);
+
+ Constant *CU32Max = Constant::getIntegerValue(U32Ty, U32Max);
+ Constant *CU32Zero = Constant::getIntegerValue(U32Ty, U32Zero);
+ Constant *CU32DeadBeef = Constant::getIntegerValue(U32Ty, U32DeadBeef);
+
+ EXPECT_TRUE(match(CU32Max, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CU32Max, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CU32Max, cst_pred_ty<always_false_pred<APInt>>()));
+
+ EXPECT_FALSE(match(CU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_TRUE(match(CU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+
+ EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+
+ // Scalar float
+ APFloat F32NaN = APFloat::getNaN(APFloat::IEEEsingle());
+ APFloat F32Zero = APFloat::getZero(APFloat::IEEEsingle());
+ APFloat F32Pi(3.14f);
+
+ Type *F32Ty = Type::getFloatTy(Ctx);
+
+ Constant *CF32NaN = ConstantFP::get(F32Ty, F32NaN);
+ Constant *CF32Zero = ConstantFP::get(F32Ty, F32Zero);
+ Constant *CF32Pi = ConstantFP::get(F32Ty, F32Pi);
+
+ EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+ EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+ EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+ ElementCount FixedEC(4, false);
+ ElementCount ScalableEC(4, true);
+
+ // Vector splat
+
+ for (auto EC : {FixedEC, ScalableEC}) {
+ // integer
+
+ Constant *CSplatU32Max = ConstantVector::getSplat(EC, CU32Max);
+ Constant *CSplatU32Zero = ConstantVector::getSplat(EC, CU32Zero);
+ Constant *CSplatU32DeadBeef = ConstantVector::getSplat(EC, CU32DeadBeef);
+
+ EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<always_false_pred<APInt>>()));
+
+ EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+
+ EXPECT_FALSE(match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(
+ match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(
+ match(CSplatU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CSplatU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+
+ // float
+
+ Constant *CSplatF32NaN = ConstantVector::getSplat(EC, CF32NaN);
+ Constant *CSplatF32Zero = ConstantVector::getSplat(EC, CF32Zero);
+ Constant *CSplatF32Pi = ConstantVector::getSplat(EC, CF32Pi);
+
+ EXPECT_TRUE(match(CSplatF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CSplatF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CSplatF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CSplatF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+ EXPECT_FALSE(match(CSplatF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_TRUE(match(CSplatF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CSplatF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CSplatF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+ EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CSplatF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CSplatF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+ }
+
+ // Int arbitrary vector
+
+ Constant *CMixedU32 = ConstantVector::get({CU32Max, CU32Zero, CU32DeadBeef});
+ Constant *CU32Undef = UndefValue::get(U32Ty);
+ Constant *CU32MaxWithUndef =
+ ConstantVector::get({CU32Undef, CU32Max, CU32Undef});
+
+ EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CMixedU32, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(match(CMixedU32, cst_pred_ty<always_false_pred<APInt>>()));
+
+ EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_max_pred>()));
+ EXPECT_FALSE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_zero_pred>()));
+ EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<always_true_pred<APInt>>()));
+ EXPECT_FALSE(
+ match(CU32MaxWithUndef, cst_pred_ty<always_false_pred<APInt>>()));
+
+ // Float arbitrary vector
+
+ Constant *CMixedF32 = ConstantVector::get({CF32NaN, CF32Zero, CF32Pi});
+ Constant *CF32Undef = UndefValue::get(F32Ty);
+ Constant *CF32NaNWithUndef =
+ ConstantVector::get({CF32Undef, CF32NaN, CF32Undef});
+
+ EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(match(CMixedF32, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+ EXPECT_TRUE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_nan_pred>()));
+ EXPECT_FALSE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_zero_pred>()));
+ EXPECT_TRUE(
+ match(CF32NaNWithUndef, cstfp_pred_ty<always_true_pred<APFloat>>()));
+ EXPECT_FALSE(
+ match(CF32NaNWithUndef, cstfp_pred_ty<always_false_pred<APFloat>>()));
+}
+
template <typename T> struct MutableConstTest : PatternMatchTest { };
typedef ::testing::Types<std::tuple<Value*, Instruction*>,
More information about the llvm-commits mailing list