[llvm] 9b9e2da - [Analysis] add optional index parameter to isSplatValue()

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun Feb 2 08:14:31 PST 2020


Author: Sanjay Patel
Date: 2020-02-02T10:52:00-05:00
New Revision: 9b9e2da07dd3b103e5a41a3519d839117d994ffa

URL: https://github.com/llvm/llvm-project/commit/9b9e2da07dd3b103e5a41a3519d839117d994ffa
DIFF: https://github.com/llvm/llvm-project/commit/9b9e2da07dd3b103e5a41a3519d839117d994ffa.diff

LOG: [Analysis] add optional index parameter to isSplatValue()

We want to allow splat value transforms to improve PR44588 and related bugs:
https://bugs.llvm.org/show_bug.cgi?id=44588
...but to do that, we need to know if values are splatted from the same,
specific index (lane) rather than splatted from an arbitrary index.

We can improve the undef handling with one-line follow-ups because the
Constant API optionally allows undefs now.

Differential Revision: https://reviews.llvm.org/D73549

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/unittests/Analysis/VectorUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index f0b0f15d9476..7726cf0c2220 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -306,11 +306,13 @@ Value *findScalarElement(Value *V, unsigned EltNo);
 /// a sequence of instructions that broadcast a single value into a vector.
 const Value *getSplatValue(const Value *V);
 
-/// Return true if the input value is known to be a vector with all identical
-/// elements (potentially including undefined elements).
+/// Return true if each element of the vector value \p V is poisoned or equal to
+/// every other non-poisoned element. If an index element is specified, either
+/// every element of the vector is poisoned or the element at that index is not
+/// poisoned and equal to every other non-poisoned element.
 /// This may be more powerful than the related getSplatValue() because it is
 /// not limited by finding a scalar source value to a splatted vector.
-bool isSplatValue(const Value *V, unsigned Depth = 0);
+bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
 
 /// Compute a map of integer instructions to their minimum legal type
 /// size.

diff  --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 0fb09e3e0d4f..e4b00108cfef 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -330,21 +330,32 @@ const llvm::Value *llvm::getSplatValue(const Value *V) {
 // adjusted if needed.
 const unsigned MaxDepth = 6;
 
-bool llvm::isSplatValue(const Value *V, unsigned Depth) {
+bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
   assert(Depth <= MaxDepth && "Limit Search Depth");
 
   if (isa<VectorType>(V->getType())) {
     if (isa<UndefValue>(V))
       return true;
-    // FIXME: Constant splat analysis does not allow undef elements.
+    // FIXME: We can allow undefs, but if Index was specified, we may want to
+    //        check that the constant is defined at that index.
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue() != nullptr;
   }
 
-  // FIXME: Constant splat analysis does not allow undef elements.
-  Constant *Mask;
-  if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask))))
-    return Mask->getSplatValue() != nullptr;
+  if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
+    // FIXME: We can safely allow undefs here. If Index was specified, we will
+    //        check that the mask elt is defined at the required index.
+    if (!Shuf->getMask()->getSplatValue())
+      return false;
+
+    // Match any index.
+    if (Index == -1)
+      return true;
+
+    // Match a specific element. The mask should be defined at and match the
+    // specified index.
+    return Shuf->getMaskValue(Index) == Index;
+  }
 
   // The remaining tests are all recursive, so bail out if we hit the limit.
   if (Depth++ == MaxDepth)
@@ -353,12 +364,12 @@ bool llvm::isSplatValue(const Value *V, unsigned Depth) {
   // If both operands of a binop are splats, the result is a splat.
   Value *X, *Y, *Z;
   if (match(V, m_BinOp(m_Value(X), m_Value(Y))))
-    return isSplatValue(X, Depth) && isSplatValue(Y, Depth);
+    return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth);
 
   // If all operands of a select are splats, the result is a splat.
   if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z))))
-    return isSplatValue(X, Depth) && isSplatValue(Y, Depth) &&
-           isSplatValue(Z, Depth);
+    return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth) &&
+           isSplatValue(Z, Index, Depth);
 
   // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops).
 

diff  --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index 074316082d18..ea5282f9d74d 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -107,6 +107,24 @@ TEST_F(VectorUtilsTest, isSplatValue_00) {
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_00_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_00_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_11) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -116,6 +134,24 @@ TEST_F(VectorUtilsTest, isSplatValue_11) {
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_11_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_11_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_01) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -125,7 +161,25 @@ TEST_F(VectorUtilsTest, isSplatValue_01) {
   EXPECT_FALSE(isSplatValue(A));
 }
 
-// FIXME: Constant (mask) splat analysis does not allow undef elements.
+TEST_F(VectorUtilsTest, isSplatValue_01_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_01_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
+// FIXME: Allow undef matching with Constant (mask) splat analysis.
 
 TEST_F(VectorUtilsTest, isSplatValue_0u) {
   parseAssembly(
@@ -136,6 +190,26 @@ TEST_F(VectorUtilsTest, isSplatValue_0u) {
   EXPECT_FALSE(isSplatValue(A));
 }
 
+// FIXME: Allow undef matching with Constant (mask) splat analysis.
+
+TEST_F(VectorUtilsTest, isSplatValue_0u_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_0u_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -147,6 +221,28 @@ TEST_F(VectorUtilsTest, isSplatValue_Binop) {
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_Binop_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = udiv <2 x i8> %v0, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_Binop_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = udiv <2 x i8> %v0, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -157,6 +253,26 @@ TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) {
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop_Not_Op0) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"


        


More information about the llvm-commits mailing list