[llvm] 686a038 - [Analysis] add query to get splat value from array of ints

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 5 11:56:44 PST 2020


Author: Sanjay Patel
Date: 2020-02-05T14:55:02-05:00
New Revision: 686a038ed8f96e5539c54fea28aedac63145cf71

URL: https://github.com/llvm/llvm-project/commit/686a038ed8f96e5539c54fea28aedac63145cf71
DIFF: https://github.com/llvm/llvm-project/commit/686a038ed8f96e5539c54fea28aedac63145cf71.diff

LOG: [Analysis] add query to get splat value from array of ints

I was debug stepping through an x86 shuffle lowering and
noticed we were doing an N^2 search for splat index. I
didn't find the equivalent functionality anywhere else in
LLVM, so here's a helper that takes an array of int and
returns a splatted index while ignoring undefs (any
negative value).

This might also be used inside existing
ShuffleVectorInst/ShuffleVectorSDNode functions and/or
help with D72467.

Differential Revision: https://reviews.llvm.org/D74064

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/VectorUtils.h
    llvm/lib/Analysis/VectorUtils.cpp
    llvm/unittests/Analysis/VectorUtilsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7726cf0c2220..e8d62416ccda 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -301,6 +301,11 @@ Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
 /// from the vector.
 Value *findScalarElement(Value *V, unsigned EltNo);
 
+/// If all non-negative \p Mask elements are the same value, return that value.
+/// If all elements are negative (undefined) or \p Mask contains different
+/// non-negative values, return -1.
+int getSplatIndex(ArrayRef<int> Mask);
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.

diff  --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index e4b00108cfef..d2c521ac9c9d 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -307,6 +307,24 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
   return nullptr;
 }
 
+int llvm::getSplatIndex(ArrayRef<int> Mask) {
+  // The single defined mask value seen so far; -1 means none seen yet.
+  int Candidate = -1;
+  for (unsigned I = 0, E = Mask.size(); I != E; ++I) {
+    int Elt = Mask[I];
+    // Negative elements are undefined lanes and never break a splat.
+    if (Elt >= 0) {
+      // A second, different defined value means this mask is not a splat.
+      if (Candidate >= 0 && Elt != Candidate)
+        return -1;
+      Candidate = Elt;
+    }
+  }
+  // Candidate is either -1 (no defined lanes) or the shared defined value.
+  assert((Candidate == -1 || Candidate >= 0) && "Negative index?");
+  return Candidate;
+}
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence

diff  --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index ea5282f9d74d..df744ac71657 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -98,6 +98,17 @@ TEST_F(BasicTest, isSplat) {
   EXPECT_FALSE(isSplatValue(SplatWithUndefC));
 }
 
+TEST_F(BasicTest, getSplatIndex) {
+  EXPECT_EQ(0, getSplatIndex({0, 0, 0}));       // every lane reads element 0
+  EXPECT_EQ(-1, getSplatIndex({1, 0, 0}));      // mixed values: not a splat
+  EXPECT_EQ(-1, getSplatIndex({0, 1, 1}));      // mixed values: not a splat
+  EXPECT_EQ(42, getSplatIndex({42, 42, 42}));   // value is not bounded by mask size
+  EXPECT_EQ(42, getSplatIndex({42, 42, -1}));   // a trailing undef lane is ignored
+  EXPECT_EQ(42, getSplatIndex({-1, 42, -1}));   // undef lanes on both sides ignored
+  EXPECT_EQ(42, getSplatIndex({-4, 42, -42}));  // any negative value counts as undef
+  EXPECT_EQ(-1, getSplatIndex({-4, -1, -42}));  // no defined lanes at all yields -1
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_00) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"


        


More information about the llvm-commits mailing list