[llvm] 686a038 - [Analysis] add query to get splat value from array of ints
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 5 11:56:44 PST 2020
Author: Sanjay Patel
Date: 2020-02-05T14:55:02-05:00
New Revision: 686a038ed8f96e5539c54fea28aedac63145cf71
URL: https://github.com/llvm/llvm-project/commit/686a038ed8f96e5539c54fea28aedac63145cf71
DIFF: https://github.com/llvm/llvm-project/commit/686a038ed8f96e5539c54fea28aedac63145cf71.diff
LOG: [Analysis] add query to get splat value from array of ints
I was debug stepping through an x86 shuffle lowering and
noticed we were doing an N^2 search for splat index. I
didn't find the equivalent functionality anywhere else in
LLVM, so here's a helper that takes an array of int and
returns a splatted index while ignoring undefs (any
negative value).
This might also be used inside existing
ShuffleVectorInst/ShuffleVectorSDNode functions and/or
help with D72467.
Differential Revision: https://reviews.llvm.org/D74064
Added:
Modified:
llvm/include/llvm/Analysis/VectorUtils.h
llvm/lib/Analysis/VectorUtils.cpp
llvm/unittests/Analysis/VectorUtilsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7726cf0c2220..e8d62416ccda 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -301,6 +301,11 @@ Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);
+/// If all non-negative \p Mask elements are the same value, return that value.
+/// If all elements are negative (undefined) or \p Mask contains different
+/// non-negative values, return -1. An empty \p Mask also returns -1.
+/// The returned index is not checked against the size of \p Mask.
+int getSplatIndex(ArrayRef<int> Mask);
+
/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index e4b00108cfef..d2c521ac9c9d 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -307,6 +307,24 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) {
return nullptr;
}
+/// Return the single non-negative value shared by every defined element of
+/// \p Mask, or -1 if there is no defined element or two defined elements
+/// disagree. Negative elements are treated as undefined and skipped.
+int llvm::getSplatIndex(ArrayRef<int> Mask) {
+  int Splat = -1;
+  for (int Elt : Mask) {
+    // Undefined (negative) lanes never prevent a splat.
+    if (Elt < 0)
+      continue;
+
+    if (Splat == -1) {
+      // The first defined lane establishes the candidate splat index.
+      Splat = Elt;
+    } else if (Elt != Splat) {
+      // A second, different defined lane rules out a splat.
+      return -1;
+    }
+  }
+  assert((Splat == -1 || Splat >= 0) && "Negative index?");
+  return Splat;
+}
+
/// Get splat value if the input is a splat vector or return nullptr.
/// This function is not fully general. It checks only 2 cases:
/// the input value is (1) a splat constant vector or (2) a sequence
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
index ea5282f9d74d..df744ac71657 100644
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -98,6 +98,17 @@ TEST_F(BasicTest, isSplat) {
EXPECT_FALSE(isSplatValue(SplatWithUndefC));
}
+TEST_F(BasicTest, getSplatIndex) {
+  EXPECT_EQ(getSplatIndex({0, 0, 0}), 0);
+  EXPECT_EQ(getSplatIndex({1, 0, 0}), -1);     // no splat
+  EXPECT_EQ(getSplatIndex({0, 1, 1}), -1);     // no splat
+  EXPECT_EQ(getSplatIndex({42, 42, 42}), 42);  // splat index may exceed mask size
+  EXPECT_EQ(getSplatIndex({42, 42, -1}), 42);  // ignore negative
+  EXPECT_EQ(getSplatIndex({-1, 42, -1}), 42);  // ignore negatives
+  EXPECT_EQ(getSplatIndex({-4, 42, -42}), 42); // ignore all negatives
+  EXPECT_EQ(getSplatIndex({-4, -1, -42}), -1); // all negative values map to -1
+  EXPECT_EQ(getSplatIndex({0, -1, 1}), -1);    // undef lane does not hide a mismatch
+  EXPECT_EQ(getSplatIndex({7}), 7);            // single defined element
+  EXPECT_EQ(getSplatIndex({-1}), -1);          // single undefined element
+  EXPECT_EQ(getSplatIndex({}), -1);            // empty mask
+}
+
TEST_F(VectorUtilsTest, isSplatValue_00) {
parseAssembly(
"define <2 x i8> @test(<2 x i8> %x) {\n"
More information about the llvm-commits
mailing list