[llvm] dab19da - [SLP]Fix a crash for the strided nodes with reversed order and externally used pointer.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 23 07:39:08 PDT 2024


Author: Alexey Bataev
Date: 2024-08-23T07:35:48-07:00
New Revision: dab19dac94eee19483ba1a7c37bdec4b8501acc3

URL: https://github.com/llvm/llvm-project/commit/dab19dac94eee19483ba1a7c37bdec4b8501acc3
DIFF: https://github.com/llvm/llvm-project/commit/dab19dac94eee19483ba1a7c37bdec4b8501acc3.diff

LOG: [SLP]Fix a crash for the strided nodes with reversed order and externally used pointer.

If the strided node is reversed, we need to check the last instruction,
not the first one in the list of scalars, when checking if the root
pointer must be extracted.

Added: 
    llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index caee3bf9c958d5..949579772b94d5 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1020,6 +1020,8 @@ static bool allSameType(ArrayRef<Value *> VL) {
 /// possible scalar operand in vectorized instruction.
 static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
                                         TargetLibraryInfo *TLI) {
+  if (!UserInst)
+    return false;
   unsigned Opcode = UserInst->getOpcode();
   switch (Opcode) {
   case Instruction::Load: {
@@ -2809,6 +2811,11 @@ class BoUpSLP {
   /// \ returns the graph entry for the \p Idx operand of the \p E entry.
   const TreeEntry *getOperandEntry(const TreeEntry *E, unsigned Idx) const;
 
+  /// Gets the root instruction for the given node. If the node is a strided
+  /// load/store node with the reverse order, the root instruction is the last
+  /// one.
+  Instruction *getRootEntryInstruction(const TreeEntry &Entry) const;
+
   /// \returns Cast context for the given graph node.
   TargetTransformInfo::CastContextHint
   getCastContextHint(const TreeEntry &TE) const;
@@ -5987,6 +5994,15 @@ void BoUpSLP::reorderBottomToTop(bool IgnoreReorder) {
     VectorizableTree.front()->ReorderIndices.clear();
 }
 
+Instruction *BoUpSLP::getRootEntryInstruction(const TreeEntry &Entry) const {
+  if ((Entry.getOpcode() == Instruction::Store ||
+       Entry.getOpcode() == Instruction::Load) &&
+      Entry.State == TreeEntry::StridedVectorize &&
+      !Entry.ReorderIndices.empty() && isReverseOrder(Entry.ReorderIndices))
+    return dyn_cast<Instruction>(Entry.Scalars[Entry.ReorderIndices.front()]);
+  return dyn_cast<Instruction>(Entry.Scalars.front());
+}
+
 void BoUpSLP::buildExternalUses(
     const ExtraValueToDebugLocsMap &ExternallyUsedValues) {
   DenseMap<Value *, unsigned> ScalarToExtUses;
@@ -6036,7 +6052,7 @@ void BoUpSLP::buildExternalUses(
           // be used.
           if (UseEntry->State == TreeEntry::ScatterVectorize ||
               !doesInTreeUserNeedToExtract(
-                  Scalar, cast<Instruction>(UseEntry->Scalars.front()), TLI)) {
+                  Scalar, getRootEntryInstruction(*UseEntry), TLI)) {
             LLVM_DEBUG(dbgs() << "SLP: \tInternal user will be removed:" << *U
                               << ".\n");
             assert(!UseEntry->isGather() && "Bad state");
@@ -8450,8 +8466,8 @@ void BoUpSLP::transformNodes() {
             Instruction::Store, VecTy, BaseSI->getPointerOperand(),
             /*VariableMask=*/false, CommonAlignment, CostKind, BaseSI);
         if (StridedCost < OriginalVecCost)
-          // Strided load is more profitable than consecutive load + reverse -
-          // transform the node to strided load.
+          // Strided store is more profitable than reverse + consecutive store -
+          // transform the node to strided store.
           E.State = TreeEntry::StridedVectorize;
       }
       break;
@@ -13776,7 +13792,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
         ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
       } else {
         assert(E->State == TreeEntry::StridedVectorize &&
-               "Expected either strided or conseutive stores.");
+               "Expected either strided or consecutive stores.");
         if (!E->ReorderIndices.empty()) {
           SI = cast<StoreInst>(E->Scalars[E->ReorderIndices.front()]);
           Ptr = SI->getPointerOperand();
@@ -14380,8 +14396,7 @@ Value *BoUpSLP::vectorizeTree(
                               (E->State == TreeEntry::Vectorize ||
                                E->State == TreeEntry::StridedVectorize) &&
                               doesInTreeUserNeedToExtract(
-                                  Scalar,
-                                  cast<Instruction>(UseEntry->Scalars.front()),
+                                  Scalar, getRootEntryInstruction(*UseEntry),
                                   TLI);
                      })) &&
              "Scalar with nullptr User must be registered in "

diff  --git a/llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll
new file mode 100644
index 00000000000000..3fa42047162e45
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/reversed-strided-node-with-external-ptr.ll
@@ -0,0 +1,49 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S --passes=slp-vectorizer -slp-threshold=-99999 -mtriple=riscv64 -mattr=+v < %s | FileCheck %s
+
+define void @test(ptr %a, i64 %0) {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ptr [[A:%.*]], i64 [[TMP0:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x ptr> poison, ptr [[A]], i32 0
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x ptr> [[TMP1]], <2 x ptr> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    br label %[[BB:.*]]
+; CHECK:       [[BB]]:
+; CHECK-NEXT:    [[TMP3:%.*]] = or disjoint i64 [[TMP0]], 1
+; CHECK-NEXT:    [[ARRAYIDX17_I28_1:%.*]] = getelementptr double, ptr [[A]], i64 [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x i64> poison, i64 [[TMP3]], i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x i64> [[TMP4]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr double, <2 x ptr> [[TMP2]], <2 x i64> [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = call <2 x double> @llvm.masked.gather.v2f64.v2p0(<2 x ptr> [[TMP6]], i32 8, <2 x i1> <i1 true, i1 true>, <2 x double> poison)
+; CHECK-NEXT:    [[TMP8:%.*]] = load <2 x double>, ptr [[A]], align 8
+; CHECK-NEXT:    [[TMP9:%.*]] = load <2 x double>, ptr [[A]], align 8
+; CHECK-NEXT:    [[TMP10:%.*]] = fsub <2 x double> [[TMP8]], [[TMP9]]
+; CHECK-NEXT:    [[TMP11:%.*]] = fsub <2 x double> [[TMP7]], [[TMP10]]
+; CHECK-NEXT:    call void @llvm.experimental.vp.strided.store.v2f64.p0.i64(<2 x double> [[TMP11]], ptr align 8 [[ARRAYIDX17_I28_1]], i64 -8, <2 x i1> <i1 true, i1 true>, i32 2)
+; CHECK-NEXT:    br label %[[BB]]
+;
+entry:
+  br label %bb
+
+bb:
+  %indvars.iv.next239.i = add i64 0, 0
+  %arrayidx.i.1 = getelementptr double, ptr %a, i64 %indvars.iv.next239.i
+  %1 = load double, ptr %arrayidx.i.1, align 8
+  %arrayidx10.i.1 = getelementptr double, ptr %a, i64 %0
+  %2 = or disjoint i64 %0, 1
+  %arrayidx17.i28.1 = getelementptr double, ptr %a, i64 %2
+  %3 = load double, ptr %arrayidx17.i28.1, align 8
+  %4 = load double, ptr %a, align 8
+  %5 = load double, ptr %a, align 8
+  %arrayidx38.i.1 = getelementptr double, ptr %a, i64 1
+  %6 = load double, ptr %arrayidx38.i.1, align 8
+  %arrayidx41.i.1 = getelementptr double, ptr %a, i64 1
+  %7 = load double, ptr %arrayidx41.i.1, align 8
+  %sub47.i.1 = fsub double %4, %5
+  %sub54.i.1 = fsub double %6, %7
+  %sub69.i.1 = fsub double %1, %sub54.i.1
+  store double %sub69.i.1, ptr %arrayidx10.i.1, align 8
+  %sub72.i.1 = fsub double %3, %sub47.i.1
+  store double %sub72.i.1, ptr %arrayidx17.i28.1, align 8
+  br label %bb
+}


        


More information about the llvm-commits mailing list