[llvm] 954ea0f - [SLP] Simplify indices processing for insertelements

Anton Afanasyev via llvm-commits llvm-commits at lists.llvm.org
Mon Feb 14 03:55:11 PST 2022


Author: Anton Afanasyev
Date: 2022-02-14T14:50:44+03:00
New Revision: 954ea0f044e0cd8368011dbef6efca70f03358e6

URL: https://github.com/llvm/llvm-project/commit/954ea0f044e0cd8368011dbef6efca70f03358e6
DIFF: https://github.com/llvm/llvm-project/commit/954ea0f044e0cd8368011dbef6efca70f03358e6.diff

LOG: [SLP] Simplify indices processing for insertelements

Get rid of insertelements with non-constant and undef indices at the
`buildTree()` stage: such bundles are now gathered rather than carried
further with invalid indices. This fixes crashes on negative, out-of-range,
and poison insert indices (see the new test).

Differential Revision: https://reviews.llvm.org/D119623
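
For illustration, the key change is that every problematic insertelement index
(non-constant, undef/poison, or out of range for the fixed vector type) now
maps to `None`, so callers can simply bail out. A simplified standalone sketch
of that contract follows; the helper name is hypothetical and the Offset /
InsertValueInst handling of the real getInsertIndex() is omitted:

  #include "llvm/ADT/Optional.h"
  #include "llvm/IR/Constants.h"
  #include "llvm/IR/DerivedTypes.h"
  #include "llvm/IR/Instructions.h"
  using namespace llvm;

  // Simplified sketch of the patched getInsertIndex() contract
  // (hypothetical name, not the exact function in the patch).
  static Optional<unsigned> getUsableInsertIndex(const InsertElementInst *IE) {
    auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
    if (!CI) // variable or undef/poison index
      return None;
    auto *VT = cast<FixedVectorType>(IE->getType());
    // Unsigned compare also rejects "negative" indices, which wrap to large values.
    if (CI->getValue().uge(VT->getNumElements()))
      return None;
    return unsigned(CI->getZExtValue());
  }

With this, `buildTree()` gathers (does not vectorize) bundles whose
insertelements report no usable index, instead of propagating UndefMaskElem
sentinels through cost modeling and codegen.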

Added: 
    llvm/test/Transforms/SLPVectorizer/X86/insert-crash-index.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 04e78c924f005..7e9d8058d1d91 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -728,19 +728,18 @@ static void inversePermutation(ArrayRef<unsigned> Indices,
 
 /// \returns inserting index of InsertElement or InsertValue instruction,
 /// using Offset as base offset for index.
-static Optional<int> getInsertIndex(Value *InsertInst, unsigned Offset = 0) {
+static Optional<unsigned> getInsertIndex(Value *InsertInst,
+                                         unsigned Offset = 0) {
   int Index = Offset;
   if (auto *IE = dyn_cast<InsertElementInst>(InsertInst)) {
     if (auto *CI = dyn_cast<ConstantInt>(IE->getOperand(2))) {
       auto *VT = cast<FixedVectorType>(IE->getType());
       if (CI->getValue().uge(VT->getNumElements()))
-        return UndefMaskElem;
+        return None;
       Index *= VT->getNumElements();
       Index += CI->getZExtValue();
       return Index;
     }
-    if (isa<UndefValue>(IE->getOperand(2)))
-      return UndefMaskElem;
     return None;
   }
 
@@ -4022,13 +4021,15 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       // Check that we have a buildvector and not a shuffle of 2 or more
       // different vectors.
       ValueSet SourceVectors;
-      int MinIdx = std::numeric_limits<int>::max();
       for (Value *V : VL) {
         SourceVectors.insert(cast<Instruction>(V)->getOperand(0));
-        Optional<int> Idx = *getInsertIndex(V);
-        if (!Idx || *Idx == UndefMaskElem)
-          continue;
-        MinIdx = std::min(MinIdx, *Idx);
+        if (getInsertIndex(V) == None) {
+          LLVM_DEBUG(dbgs() << "SLP: Gather of insertelement vectors with "
+                               "non-constant or undef index.\n");
+          newTreeEntry(VL, None /*not vectorized*/, S, UserTreeIdx);
+          BS.cancelScheduling(VL, VL0);
+          return;
+        }
       }
 
       if (count_if(VL, [&SourceVectors](Value *V) {
@@ -4050,10 +4051,8 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
                     decltype(OrdCompare)>
           Indices(OrdCompare);
       for (int I = 0, E = VL.size(); I < E; ++I) {
-        Optional<int> Idx = *getInsertIndex(VL[I]);
-        if (!Idx || *Idx == UndefMaskElem)
-          continue;
-        Indices.emplace(*Idx, I);
+        unsigned Idx = *getInsertIndex(VL[I]);
+        Indices.emplace(Idx, I);
       }
       OrdersType CurrentOrder(VL.size(), VL.size());
       bool IsIdentity = true;
@@ -5214,12 +5213,10 @@ InstructionCost BoUpSLP::getEntryCost(const TreeEntry *E,
       SmallVector<int> PrevMask(NumElts, UndefMaskElem);
       Mask.swap(PrevMask);
       for (unsigned I = 0; I < NumScalars; ++I) {
-        Optional<int> InsertIdx = getInsertIndex(VL[PrevMask[I]]);
-        if (!InsertIdx || *InsertIdx == UndefMaskElem)
-          continue;
-        DemandedElts.setBit(*InsertIdx);
-        IsIdentity &= *InsertIdx - Offset == I;
-        Mask[*InsertIdx - Offset] = I;
+        unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]);
+        DemandedElts.setBit(InsertIdx);
+        IsIdentity &= InsertIdx - Offset == I;
+        Mask[InsertIdx - Offset] = I;
       }
       assert(Offset < NumElts && "Failed to find vector index offset");
 
@@ -5923,9 +5920,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
     // to detect it as a final shuffled/identity match.
     if (auto *VU = dyn_cast_or_null<InsertElementInst>(EU.User)) {
       if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) {
-        Optional<int> InsertIdx = getInsertIndex(VU);
-        if (!InsertIdx || *InsertIdx == UndefMaskElem)
-          continue;
+        unsigned InsertIdx = *getInsertIndex(VU);
         auto *It = find_if(FirstUsers, [VU](Value *V) {
           return areTwoInsertFromSameBuildVector(VU,
                                                  cast<InsertElementInst>(V));
@@ -5955,9 +5950,8 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
         } else {
           VecId = std::distance(FirstUsers.begin(), It);
         }
-        int Idx = *InsertIdx;
-        ShuffleMask[VecId][Idx] = EU.Lane;
-        DemandedElts[VecId].setBit(Idx);
+        ShuffleMask[VecId][InsertIdx] = EU.Lane;
+        DemandedElts[VecId].setBit(InsertIdx);
         continue;
       }
     }
@@ -6709,11 +6703,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
       Mask.swap(PrevMask);
       for (unsigned I = 0; I < NumScalars; ++I) {
         Value *Scalar = E->Scalars[PrevMask[I]];
-        Optional<int> InsertIdx = getInsertIndex(Scalar);
-        if (!InsertIdx || *InsertIdx == UndefMaskElem)
-          continue;
-        IsIdentity &= *InsertIdx - Offset == I;
-        Mask[*InsertIdx - Offset] = I;
+        unsigned InsertIdx = *getInsertIndex(Scalar);
+        IsIdentity &= InsertIdx - Offset == I;
+        Mask[InsertIdx - Offset] = I;
       }
       if (!IsIdentity || NumElts != NumScalars) {
         V = Builder.CreateShuffleVector(V, Mask);
@@ -9599,21 +9591,22 @@ static Optional<unsigned> getAggregateSize(Instruction *InsertInst) {
   } while (true);
 }
 
-static bool findBuildAggregate_rec(Instruction *LastInsertInst,
+static void findBuildAggregate_rec(Instruction *LastInsertInst,
                                    TargetTransformInfo *TTI,
                                    SmallVectorImpl<Value *> &BuildVectorOpds,
                                    SmallVectorImpl<Value *> &InsertElts,
                                    unsigned OperandOffset) {
   do {
     Value *InsertedOperand = LastInsertInst->getOperand(1);
-    Optional<int> OperandIndex = getInsertIndex(LastInsertInst, OperandOffset);
+    Optional<unsigned> OperandIndex =
+        getInsertIndex(LastInsertInst, OperandOffset);
     if (!OperandIndex)
-      return false;
+      return;
     if (isa<InsertElementInst>(InsertedOperand) ||
         isa<InsertValueInst>(InsertedOperand)) {
-      if (!findBuildAggregate_rec(cast<Instruction>(InsertedOperand), TTI,
-                                  BuildVectorOpds, InsertElts, *OperandIndex))
-        return false;
+      findBuildAggregate_rec(cast<Instruction>(InsertedOperand), TTI,
+                             BuildVectorOpds, InsertElts, *OperandIndex);
+
     } else {
       BuildVectorOpds[*OperandIndex] = InsertedOperand;
       InsertElts[*OperandIndex] = LastInsertInst;
@@ -9623,7 +9616,6 @@ static bool findBuildAggregate_rec(Instruction *LastInsertInst,
            (isa<InsertValueInst>(LastInsertInst) ||
             isa<InsertElementInst>(LastInsertInst)) &&
            LastInsertInst->hasOneUse());
-  return true;
 }
 
 /// Recognize construction of vectors like
@@ -9658,13 +9650,11 @@ static bool findBuildAggregate(Instruction *LastInsertInst,
   BuildVectorOpds.resize(*AggregateSize);
   InsertElts.resize(*AggregateSize);
 
-  if (findBuildAggregate_rec(LastInsertInst, TTI, BuildVectorOpds, InsertElts,
-                             0)) {
-    llvm::erase_value(BuildVectorOpds, nullptr);
-    llvm::erase_value(InsertElts, nullptr);
-    if (BuildVectorOpds.size() >= 2)
-      return true;
-  }
+  findBuildAggregate_rec(LastInsertInst, TTI, BuildVectorOpds, InsertElts, 0);
+  llvm::erase_value(BuildVectorOpds, nullptr);
+  llvm::erase_value(InsertElts, nullptr);
+  if (BuildVectorOpds.size() >= 2)
+    return true;
 
   return false;
 }

diff --git a/llvm/test/Transforms/SLPVectorizer/X86/insert-crash-index.ll b/llvm/test/Transforms/SLPVectorizer/X86/insert-crash-index.ll
new file mode 100644
index 0000000000000..882681027df76
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/X86/insert-crash-index.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -slp-vectorizer < %s | FileCheck %s
+
+; These all crashed before this patch.
+
+define  <2 x i8> @negative_index(<2 x i8> %aux_vec, i8* %in) {
+; CHECK-LABEL: @negative_index(
+; CHECK-NEXT:    [[IN0:%.*]] = load i8, i8* [[IN:%.*]], align 4
+; CHECK-NEXT:    [[G1:%.*]] = getelementptr inbounds i8, i8* [[IN]], i64 1
+; CHECK-NEXT:    [[IN1:%.*]] = load i8, i8* [[G1]], align 4
+; CHECK-NEXT:    [[V0:%.*]] = insertelement <2 x i8> [[AUX_VEC:%.*]], i8 [[IN0]], i8 -1
+; CHECK-NEXT:    [[V1:%.*]] = insertelement <2 x i8> [[V0]], i8 [[IN1]], i64 1
+; CHECK-NEXT:    [[OUT:%.*]] = add <2 x i8> [[V1]], [[V1]]
+; CHECK-NEXT:    ret <2 x i8> [[OUT]]
+;
+  %in0 = load i8, i8* %in, align 4
+  %g1 = getelementptr inbounds i8, i8* %in, i64 1
+  %in1 = load i8, i8* %g1, align 4
+
+  %v0 = insertelement <2 x i8> %aux_vec, i8 %in0, i8 -1
+  %v1 = insertelement <2 x i8> %v0, i8 %in1, i64 1
+
+  %out = add <2 x i8> %v1, %v1
+  ret <2 x i8> %out
+}
+
+define  <2 x i8> @exceed_index(<2 x i8> %aux_vec, i8* %in) {
+; CHECK-LABEL: @exceed_index(
+; CHECK-NEXT:    [[IN0:%.*]] = load i8, i8* [[IN:%.*]], align 4
+; CHECK-NEXT:    [[G1:%.*]] = getelementptr inbounds i8, i8* [[IN]], i64 1
+; CHECK-NEXT:    [[IN1:%.*]] = load i8, i8* [[G1]], align 4
+; CHECK-NEXT:    [[V0:%.*]] = insertelement <2 x i8> [[AUX_VEC:%.*]], i8 [[IN0]], i8 2
+; CHECK-NEXT:    [[V1:%.*]] = insertelement <2 x i8> [[V0]], i8 [[IN1]], i64 1
+; CHECK-NEXT:    [[OUT:%.*]] = add <2 x i8> [[V1]], [[V1]]
+; CHECK-NEXT:    ret <2 x i8> [[OUT]]
+;
+  %in0 = load i8, i8* %in, align 4
+  %g1 = getelementptr inbounds i8, i8* %in, i64 1
+  %in1 = load i8, i8* %g1, align 4
+
+  %v0 = insertelement <2 x i8> %aux_vec, i8 %in0, i8 2
+  %v1 = insertelement <2 x i8> %v0, i8 %in1, i64 1
+
+  %out = add <2 x i8> %v1, %v1
+  ret <2 x i8> %out
+}
+
+define  <2 x i8> @poison_index(<2 x i8> %aux_vec, i8* %in) {
+; CHECK-LABEL: @poison_index(
+; CHECK-NEXT:    [[IN0:%.*]] = load i8, i8* [[IN:%.*]], align 4
+; CHECK-NEXT:    [[G1:%.*]] = getelementptr inbounds i8, i8* [[IN]], i64 1
+; CHECK-NEXT:    [[IN1:%.*]] = load i8, i8* [[G1]], align 4
+; CHECK-NEXT:    [[V0:%.*]] = insertelement <2 x i8> [[AUX_VEC:%.*]], i8 [[IN0]], i8 poison
+; CHECK-NEXT:    [[V1:%.*]] = insertelement <2 x i8> [[V0]], i8 [[IN1]], i64 1
+; CHECK-NEXT:    [[OUT:%.*]] = add <2 x i8> [[V1]], [[V1]]
+; CHECK-NEXT:    ret <2 x i8> [[OUT]]
+;
+  %in0 = load i8, i8* %in, align 4
+  %g1 = getelementptr inbounds i8, i8* %in, i64 1
+  %in1 = load i8, i8* %g1, align 4
+
+  %v0 = insertelement <2 x i8> %aux_vec, i8 %in0, i8 poison
+  %v1 = insertelement <2 x i8> %v0, i8 %in1, i64 1
+
+  %out = add <2 x i8> %v1, %v1
+  ret <2 x i8> %out
+}


        

