[llvm] 9c87c55 - [SVE] Make cstfp_pred_ty and cst_pred_ty work with scalable splats

Christopher Tetreault via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 14 14:20:54 PDT 2020


Author: Christopher Tetreault
Date: 2020-07-14T14:20:39-07:00
New Revision: 9c87c5580575cefdebb02cc6685fb6b66fb375c9

URL: https://github.com/llvm/llvm-project/commit/9c87c5580575cefdebb02cc6685fb6b66fb375c9
DIFF: https://github.com/llvm/llvm-project/commit/9c87c5580575cefdebb02cc6685fb6b66fb375c9.diff

LOG: [SVE] Make cstfp_pred_ty and cst_pred_ty work with scalable splats

Reviewers: efriedma, lebedev.ri, fhahn, c-rhodes, david-arm

Reviewed By: efriedma, david-arm

Subscribers: tschuett, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D83001

Added: 
    

Modified: 
    llvm/include/llvm/IR/Constants.h
    llvm/include/llvm/IR/PatternMatch.h
    llvm/test/Transforms/InstCombine/fmul.ll
    llvm/test/Transforms/InstCombine/mul.ll
    llvm/unittests/IR/PatternMatch.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 3579c9f1ee33..8e2dba9b2417 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -308,6 +308,7 @@ class ConstantFP final : public ConstantData {
   /// Return true if Ty is big enough to represent V.
   static bool isValueValidForType(Type *Ty, const APFloat &V);
   inline const APFloat &getValueAPF() const { return Val; }
+  inline const APFloat &getValue() const { return Val; }
 
   /// Return true if the value is positive or negative zero.
   bool isZero() const { return Val.isZero(); }

diff  --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 98182bc3d85d..4c11bc82510b 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -262,17 +262,23 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
   return constantint_match<Val>();
 }
 
-/// This helper class is used to match scalar and fixed width vector integer
-/// constants that satisfy a specified predicate.
-/// For vector constants, undefined elements are ignored.
-template <typename Predicate> struct cst_pred_ty : public Predicate {
+/// This helper class is used to match constant scalars, vector splats,
+/// and fixed width vectors that satisfy a specified predicate.
+/// For fixed width vector constants, undefined elements are ignored.
+template <typename Predicate, typename ConstantVal>
+struct cstval_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
-    if (const auto *CI = dyn_cast<ConstantInt>(V))
-      return this->isValue(CI->getValue());
-    if (const auto *FVTy = dyn_cast<FixedVectorType>(V->getType())) {
+    if (const auto *CV = dyn_cast<ConstantVal>(V))
+      return this->isValue(CV->getValue());
+    if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
       if (const auto *C = dyn_cast<Constant>(V)) {
-        if (const auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
-          return this->isValue(CI->getValue());
+        if (const auto *CV = dyn_cast_or_null<ConstantVal>(C->getSplatValue()))
+          return this->isValue(CV->getValue());
+
+        // The element count of a scalable vector is unknown at compile time.
+        auto *FVTy = dyn_cast<FixedVectorType>(VTy);
+        if (!FVTy)
+          return false;
 
         // Non-splat vector constant: check each element for a match.
         unsigned NumElts = FVTy->getNumElements();
@@ -284,8 +290,8 @@ template <typename Predicate> struct cst_pred_ty : public Predicate {
             return false;
           if (isa<UndefValue>(Elt))
             continue;
-          auto *CI = dyn_cast<ConstantInt>(Elt);
-          if (!CI || !this->isValue(CI->getValue()))
+          auto *CV = dyn_cast<ConstantVal>(Elt);
+          if (!CV || !this->isValue(CV->getValue()))
             return false;
           HasNonUndefElements = true;
         }
@@ -296,6 +302,14 @@ template <typename Predicate> struct cst_pred_ty : public Predicate {
   }
 };
 
+/// Alias of cstval_pred_ty that matches ConstantInt values (APInt).
+template <typename Predicate>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+
+/// Alias of cstval_pred_ty that matches ConstantFP values (APFloat).
+template <typename Predicate>
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APInt.
 template <typename Predicate> struct api_pred_ty : public Predicate {
@@ -321,44 +335,6 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
   }
 };
 
-/// This helper class is used to match scalar and vector floating-point
-/// constants that satisfy a specified predicate.
-/// For vector constants, undefined elements are ignored.
-template <typename Predicate> struct cstfp_pred_ty : public Predicate {
-  template <typename ITy> bool match(ITy *V) {
-    if (const auto *CF = dyn_cast<ConstantFP>(V))
-      return this->isValue(CF->getValueAPF());
-    if (V->getType()->isVectorTy()) {
-      if (const auto *C = dyn_cast<Constant>(V)) {
-        if (const auto *CF = dyn_cast_or_null<ConstantFP>(C->getSplatValue()))
-          return this->isValue(CF->getValueAPF());
-
-        // Number of elements of a scalable vector unknown at compile time
-        if (isa<ScalableVectorType>(V->getType()))
-          return false;
-
-        // Non-splat vector constant: check each element for a match.
-        unsigned NumElts = cast<VectorType>(V->getType())->getNumElements();
-        assert(NumElts != 0 && "Constant vector with no elements?");
-        bool HasNonUndefElements = false;
-        for (unsigned i = 0; i != NumElts; ++i) {
-          Constant *Elt = C->getAggregateElement(i);
-          if (!Elt)
-            return false;
-          if (isa<UndefValue>(Elt))
-            continue;
-          auto *CF = dyn_cast<ConstantFP>(Elt);
-          if (!CF || !this->isValue(CF->getValueAPF()))
-            return false;
-          HasNonUndefElements = true;
-        }
-        return HasNonUndefElements;
-      }
-    }
-    return false;
-  }
-};
-
 ///////////////////////////////////////////////////////////////////////////////
 //
 // Encapsulate constant value queries for use in templated predicate matchers.

diff  --git a/llvm/test/Transforms/InstCombine/fmul.ll b/llvm/test/Transforms/InstCombine/fmul.ll
index 8e168f252978..4162973f0bed 100644
--- a/llvm/test/Transforms/InstCombine/fmul.ll
+++ b/llvm/test/Transforms/InstCombine/fmul.ll
@@ -1164,3 +1164,12 @@ define double @fmul_sqrt_select(double %x, i1 %c) {
   %mul = fmul fast double %sqr, %sel
   ret double %mul
 }
+
+; fastmath => z * splat(0) = splat(0), even for scalable vectors
+define <vscale x 2 x float> @mul_scalable_splat_zero(<vscale x 2 x float> %z) {
+; CHECK-LABEL: @mul_scalable_splat_zero(
+; CHECK-NEXT:    ret <vscale x 2 x float> zeroinitializer
+  %shuf = shufflevector <vscale x 2 x float> insertelement (<vscale x 2 x float> undef, float 0.0, i32 0), <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
+  %t3 = fmul fast <vscale x 2 x float> %shuf, %z
+  ret <vscale x 2 x float> %t3
+}

diff  --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
index 059b18d30b90..d2844561ca7a 100644
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -857,3 +857,12 @@ define <4 x i32> @combine_mul_nabs_v4i32(<4 x i32> %0) {
   %m = mul <4 x i32> %r, %r
   ret <4 x i32> %m
 }
+
+; z * splat(0) = splat(0), even for scalable vectors
+define <vscale x 2 x i64> @mul_scalable_splat_zero(<vscale x 2 x i64> %z) {
+; CHECK-LABEL: @mul_scalable_splat_zero(
+; CHECK-NEXT:    ret <vscale x 2 x i64> zeroinitializer
+  %shuf = shufflevector <vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 0, i32 0), <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %t3 = mul <vscale x 2 x i64> %shuf, %z
+  ret <vscale x 2 x i64> %t3
+}

diff  --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index bbcbd91c8f1f..34dc2767021a 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1325,6 +1325,183 @@ TEST_F(PatternMatchTest, IntrinsicMatcher) {
                             m_SpecificInt(10))));
 }
 
+namespace {
+
+struct is_unsigned_zero_pred {
+  bool isValue(const APInt &C) { return C.isNullValue(); }
+};
+
+struct is_float_zero_pred {
+  bool isValue(const APFloat &C) { return C.isZero(); }
+};
+
+template <typename T> struct always_true_pred {
+  bool isValue(const T &) { return true; }
+};
+
+template <typename T> struct always_false_pred {
+  bool isValue(const T &) { return false; }
+};
+
+struct is_unsigned_max_pred {
+  bool isValue(const APInt &C) { return C.isMaxValue(); }
+};
+
+struct is_float_nan_pred {
+  bool isValue(const APFloat &C) { return C.isNaN(); }
+};
+
+} // namespace
+
+TEST_F(PatternMatchTest, ConstantPredicateType) {
+
+  // Scalar integer
+  APInt U32Max = APInt::getAllOnesValue(32);
+  APInt U32Zero = APInt::getNullValue(32);
+  APInt U32DeadBeef(32, 0xDEADBEEF);
+
+  Type *U32Ty = Type::getInt32Ty(Ctx);
+
+  Constant *CU32Max = Constant::getIntegerValue(U32Ty, U32Max);
+  Constant *CU32Zero = Constant::getIntegerValue(U32Ty, U32Zero);
+  Constant *CU32DeadBeef = Constant::getIntegerValue(U32Ty, U32DeadBeef);
+
+  EXPECT_TRUE(match(CU32Max, cst_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(match(CU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CU32Max, cst_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(match(CU32Max, cst_pred_ty<always_false_pred<APInt>>()));
+
+  EXPECT_FALSE(match(CU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_TRUE(match(CU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(match(CU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+
+  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+
+  // Scalar float
+  APFloat F32NaN = APFloat::getNaN(APFloat::IEEEsingle());
+  APFloat F32Zero = APFloat::getZero(APFloat::IEEEsingle());
+  APFloat F32Pi(3.14f);
+
+  Type *F32Ty = Type::getFloatTy(Ctx);
+
+  Constant *CF32NaN = ConstantFP::get(F32Ty, F32NaN);
+  Constant *CF32Zero = ConstantFP::get(F32Ty, F32Zero);
+  Constant *CF32Pi = ConstantFP::get(F32Ty, F32Pi);
+
+  EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+  EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
+  EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(match(CF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+  ElementCount FixedEC(4, false);
+  ElementCount ScalableEC(4, true);
+
+  // Vector splat
+
+  for (auto EC : {FixedEC, ScalableEC}) {
+    // integer
+
+    Constant *CSplatU32Max = ConstantVector::getSplat(EC, CU32Max);
+    Constant *CSplatU32Zero = ConstantVector::getSplat(EC, CU32Zero);
+    Constant *CSplatU32DeadBeef = ConstantVector::getSplat(EC, CU32DeadBeef);
+
+    EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<is_unsigned_max_pred>()));
+    EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<is_unsigned_zero_pred>()));
+    EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty<always_true_pred<APInt>>()));
+    EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty<always_false_pred<APInt>>()));
+
+    EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_max_pred>()));
+    EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<is_unsigned_zero_pred>()));
+    EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty<always_true_pred<APInt>>()));
+    EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty<always_false_pred<APInt>>()));
+
+    EXPECT_FALSE(match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_max_pred>()));
+    EXPECT_FALSE(
+        match(CSplatU32DeadBeef, cst_pred_ty<is_unsigned_zero_pred>()));
+    EXPECT_TRUE(
+        match(CSplatU32DeadBeef, cst_pred_ty<always_true_pred<APInt>>()));
+    EXPECT_FALSE(
+        match(CSplatU32DeadBeef, cst_pred_ty<always_false_pred<APInt>>()));
+
+    // float
+
+    Constant *CSplatF32NaN = ConstantVector::getSplat(EC, CF32NaN);
+    Constant *CSplatF32Zero = ConstantVector::getSplat(EC, CF32Zero);
+    Constant *CSplatF32Pi = ConstantVector::getSplat(EC, CF32Pi);
+
+    EXPECT_TRUE(match(CSplatF32NaN, cstfp_pred_ty<is_float_nan_pred>()));
+    EXPECT_FALSE(match(CSplatF32NaN, cstfp_pred_ty<is_float_zero_pred>()));
+    EXPECT_TRUE(
+        match(CSplatF32NaN, cstfp_pred_ty<always_true_pred<APFloat>>()));
+    EXPECT_FALSE(
+        match(CSplatF32NaN, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+    EXPECT_FALSE(match(CSplatF32Zero, cstfp_pred_ty<is_float_nan_pred>()));
+    EXPECT_TRUE(match(CSplatF32Zero, cstfp_pred_ty<is_float_zero_pred>()));
+    EXPECT_TRUE(
+        match(CSplatF32Zero, cstfp_pred_ty<always_true_pred<APFloat>>()));
+    EXPECT_FALSE(
+        match(CSplatF32Zero, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+    EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_nan_pred>()));
+    EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty<is_float_zero_pred>()));
+    EXPECT_TRUE(match(CSplatF32Pi, cstfp_pred_ty<always_true_pred<APFloat>>()));
+    EXPECT_FALSE(
+        match(CSplatF32Pi, cstfp_pred_ty<always_false_pred<APFloat>>()));
+  }
+
+  // Int arbitrary vector
+
+  Constant *CMixedU32 = ConstantVector::get({CU32Max, CU32Zero, CU32DeadBeef});
+  Constant *CU32Undef = UndefValue::get(U32Ty);
+  Constant *CU32MaxWithUndef =
+      ConstantVector::get({CU32Undef, CU32Max, CU32Undef});
+
+  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CMixedU32, cst_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(match(CMixedU32, cst_pred_ty<always_false_pred<APInt>>()));
+
+  EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_max_pred>()));
+  EXPECT_FALSE(match(CU32MaxWithUndef, cst_pred_ty<is_unsigned_zero_pred>()));
+  EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty<always_true_pred<APInt>>()));
+  EXPECT_FALSE(
+      match(CU32MaxWithUndef, cst_pred_ty<always_false_pred<APInt>>()));
+
+  // Float arbitrary vector
+
+  Constant *CMixedF32 = ConstantVector::get({CF32NaN, CF32Zero, CF32Pi});
+  Constant *CF32Undef = UndefValue::get(F32Ty);
+  Constant *CF32NaNWithUndef =
+      ConstantVector::get({CF32Undef, CF32NaN, CF32Undef});
+
+  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(match(CMixedF32, cstfp_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty<always_false_pred<APFloat>>()));
+
+  EXPECT_TRUE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_nan_pred>()));
+  EXPECT_FALSE(match(CF32NaNWithUndef, cstfp_pred_ty<is_float_zero_pred>()));
+  EXPECT_TRUE(
+      match(CF32NaNWithUndef, cstfp_pred_ty<always_true_pred<APFloat>>()));
+  EXPECT_FALSE(
+      match(CF32NaNWithUndef, cstfp_pred_ty<always_false_pred<APFloat>>()));
+}
+
 template <typename T> struct MutableConstTest : PatternMatchTest { };
 
 typedef ::testing::Types<std::tuple<Value*, Instruction*>,


        


More information about the llvm-commits mailing list