[llvm] 021d56a - [SVE] Make Constant::getSplatValue work for scalable vector splats

Christopher Tetreault via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 7 13:46:07 PDT 2020


Author: Christopher Tetreault
Date: 2020-07-07T13:45:51-07:00
New Revision: 021d56abb9ee3028cb88895144d71365e566c32f

URL: https://github.com/llvm/llvm-project/commit/021d56abb9ee3028cb88895144d71365e566c32f
DIFF: https://github.com/llvm/llvm-project/commit/021d56abb9ee3028cb88895144d71365e566c32f.diff

LOG: [SVE] Make Constant::getSplatValue work for scalable vector splats

Summary:
Make Constant::getSplatValue recognize scalable vector splats of the
form created by ConstantVector::getSplat. Add a unit test to verify that
C == ConstantVector::getSplat(EC, C)->getSplatValue() for fixed-width and
scalable vector splats.

Reviewers: efriedma, spatel, fpetrogalli, c-rhodes

Reviewed By: efriedma

Subscribers: sdesmalen, tschuett, hiraditya, rkruppe, psnobl, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D82416

Added: 
    

Modified: 
    llvm/lib/IR/Constants.cpp
    llvm/test/Transforms/InstSimplify/vscale.ll
    llvm/unittests/IR/ConstantsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index d8e044ee4bdc..cbbcca20ea51 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1585,6 +1585,27 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const {
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
     return CV->getSplatValue(AllowUndefs);
+
+  // Check if this is a constant expression splat of the form returned by
+  // ConstantVector::getSplat()
+  const auto *Shuf = dyn_cast<ConstantExpr>(this);
+  if (Shuf && Shuf->getOpcode() == Instruction::ShuffleVector &&
+      isa<UndefValue>(Shuf->getOperand(1))) {
+
+    const auto *IElt = dyn_cast<ConstantExpr>(Shuf->getOperand(0));
+    if (IElt && IElt->getOpcode() == Instruction::InsertElement &&
+        isa<UndefValue>(IElt->getOperand(0))) {
+
+      ArrayRef<int> Mask = Shuf->getShuffleMask();
+      Constant *SplatVal = IElt->getOperand(1);
+      ConstantInt *Index = dyn_cast<ConstantInt>(IElt->getOperand(2));
+
+      if (Index && Index->getValue() == 0 &&
+          std::all_of(Mask.begin(), Mask.end(), [](int I) { return I == 0; }))
+        return SplatVal;
+    }
+  }
+
   return nullptr;
 }
 

diff  --git a/llvm/test/Transforms/InstSimplify/vscale.ll b/llvm/test/Transforms/InstSimplify/vscale.ll
index 669c824685e8..d396f0289196 100644
--- a/llvm/test/Transforms/InstSimplify/vscale.ll
+++ b/llvm/test/Transforms/InstSimplify/vscale.ll
@@ -95,6 +95,15 @@ define i32 @insert_extract_element_same_vec_idx_2(<vscale x 4 x i32> %a) {
   ret i32 %r
 }
 
+; More complicated expressions
+; sle against a splat of INT64_MAX holds for every %x, so it folds to true.
+define <vscale x 2 x i1> @cmp_le_smax_always_true(<vscale x 2 x i64> %x) {
+; CHECK-LABEL: @cmp_le_smax_always_true(
+; CHECK-NEXT:    ret <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> undef, i1 true, i32 0), <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer)
+   %cmp = icmp sle <vscale x 2 x i64> %x, shufflevector (<vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 9223372036854775807, i32 0), <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer)
+   ret <vscale x 2 x i1> %cmp
+}
+
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
 ;; Memory Access and Addressing Operations
 ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

diff  --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
index e20039e8d9c4..3fed395daee4 100644
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -638,5 +638,34 @@ TEST(ConstantsTest, isElementWiseEqual) {
   EXPECT_FALSE(CP00U->isElementWiseEqual(CP00U0));
 }
 
+TEST(ConstantsTest, GetSplatValueRoundTrip) {
+  // getSplat(EC, C)->getSplatValue() must return C, fixed or scalable EC.
+  LLVMContext Context;
+  Type *FloatTy = Type::getFloatTy(Context);
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  Type *Int8Ty = Type::getInt8Ty(Context);
+
+  for (unsigned Min : {1, 2, 8}) {
+    ElementCount ScalableEC = {Min, true};
+    ElementCount FixedEC = {Min, false};
+
+    for (auto EC : {ScalableEC, FixedEC}) {
+      for (auto *Ty : {FloatTy, Int32Ty, Int8Ty}) {
+        // getAllOnesValue sets every bit (e.g. i8 -1); it is not the value 1.
+        Constant *Zero = Constant::getNullValue(Ty);
+        Constant *AllOnes = Constant::getAllOnesValue(Ty);
+        for (auto *C : {Zero, AllOnes}) {
+          Constant *Splat = ConstantVector::getSplat(EC, C);
+          ASSERT_NE(nullptr, Splat);
+
+          Constant *SplatVal = Splat->getSplatValue();
+          EXPECT_NE(nullptr, SplatVal);
+          EXPECT_EQ(SplatVal, C);
+        }
+      }
+    }
+  }
+}
+
 }  // end anonymous namespace
 }  // end namespace llvm


        


More information about the llvm-commits mailing list