[llvm] [SLP]Transform stores + reverse to strided stores with stride -1, if profitable. (PR #90464)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Wed May 1 04:32:16 PDT 2024


https://github.com/alexey-bataev updated https://github.com/llvm/llvm-project/pull/90464

>From 6ff922637e514f1a9efebe2fc07c73b087830dd1 Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Mon, 29 Apr 2024 13:02:58 +0000
Subject: [PATCH 1/2] [𝘀𝗽𝗿] initial version
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Created using spr 1.3.5
---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 73 +++++++++++++++++--
 .../RISCV/strided-stores-vectorized.ll        | 31 ++------
 2 files changed, 70 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f62270fe62ebea..2cca7130a0be70 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7860,6 +7860,32 @@ void BoUpSLP::transformNodes() {
       }
       break;
     }
+    case Instruction::Store: {
+      Type *ScalarTy = cast<StoreInst>(E.getMainOp())->getValueOperand()->getType();
+      auto *VecTy = FixedVectorType::get(ScalarTy, E.Scalars.size());
+      Align CommonAlignment = computeCommonAlignment<StoreInst>(E.Scalars);
+      // Check if profitable to represent consecutive load + reverse as strided
+      // load with stride -1.
+      if (isReverseOrder(E.ReorderIndices) &&
+          TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
+        SmallVector<int> Mask;
+        inversePermutation(E.ReorderIndices, Mask);
+        auto *BaseSI = cast<StoreInst>(E.Scalars.back());
+        InstructionCost OriginalVecCost =
+            TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
+                                 BaseSI->getPointerAddressSpace(), CostKind,
+                                 TTI::OperandValueInfo()) +
+            ::getShuffleCost(*TTI, TTI::SK_Reverse, VecTy, Mask, CostKind);
+        InstructionCost StridedCost = TTI->getStridedMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind, BaseSI);
+        if (StridedCost < OriginalVecCost)
+          // Strided store is more profitable than reverse + consecutive store -
+          // transform the node to strided store.
+          E.State = TreeEntry::StridedVectorize;
+      }
+      break;
+    }
     default:
       break;
     }
@@ -9343,11 +9369,22 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         cast<StoreInst>(IsReorder ? VL[E->ReorderIndices.front()] : VL0);
     auto GetVectorCost = [=](InstructionCost CommonCost) {
       // We know that we can merge the stores. Calculate the cost.
-      TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
-      return TTI->getMemoryOpCost(Instruction::Store, VecTy, BaseSI->getAlign(),
-                                  BaseSI->getPointerAddressSpace(), CostKind,
-                                  OpInfo) +
-             CommonCost;
+      InstructionCost VecStCost;
+      if (E->State == TreeEntry::StridedVectorize) {
+        Align CommonAlignment =
+            computeCommonAlignment<StoreInst>(UniqueValues.getArrayRef());
+        VecStCost = TTI->getStridedMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getPointerOperand(),
+            /*VariableMask=*/false, CommonAlignment, CostKind);
+      } else {
+        assert(E->State == TreeEntry::Vectorize &&
+               "Expected either strided or consecutive stores.");
+        TTI::OperandValueInfo OpInfo = getOperandInfo(E->getOperand(0));
+        VecStCost = TTI->getMemoryOpCost(
+            Instruction::Store, VecTy, BaseSI->getAlign(),
+            BaseSI->getPointerAddressSpace(), CostKind, OpInfo);
+      }
+      return VecStCost + CommonCost;
     };
     SmallVector<Value *> PointerOps(VL.size());
     for (auto [I, V] : enumerate(VL)) {
@@ -12251,7 +12288,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
   bool IsReverseOrder = isReverseOrder(E->ReorderIndices);
   auto FinalShuffle = [&](Value *V, const TreeEntry *E, VectorType *VecTy) {
     ShuffleInstructionBuilder ShuffleBuilder(Builder, *this);
-    if (E->getOpcode() == Instruction::Store) {
+    if (E->getOpcode() == Instruction::Store &&
+        E->State == TreeEntry::Vectorize) {
       ArrayRef<int> Mask =
           ArrayRef(reinterpret_cast<const int *>(E->ReorderIndices.begin()),
                    E->ReorderIndices.size());
@@ -12848,8 +12886,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       VecValue = FinalShuffle(VecValue, E, VecTy);
 
       Value *Ptr = SI->getPointerOperand();
-      StoreInst *ST =
-          Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
+      Instruction *ST;
+      if (E->State == TreeEntry::Vectorize) {
+        ST = Builder.CreateAlignedStore(VecValue, Ptr, SI->getAlign());
+      } else {
+        assert(E->State == TreeEntry::StridedVectorize &&
+               "Expected either strided or conseutive stores.");
+        Align CommonAlignment = computeCommonAlignment<StoreInst>(E->Scalars);
+        Type *StrideTy = DL->getIndexType(SI->getPointerOperandType());
+        auto *Inst = Builder.CreateIntrinsic(
+            Intrinsic::experimental_vp_strided_store,
+            {VecTy, Ptr->getType(), StrideTy},
+            {VecValue, Ptr,
+             ConstantInt::get(
+                 StrideTy, -static_cast<int>(DL->getTypeAllocSize(ScalarTy))),
+             Builder.getAllOnesMask(VecTy->getElementCount()),
+             Builder.getInt32(E->Scalars.size())});
+        Inst->addParamAttr(
+            /*ArgNo=*/1,
+            Attribute::getWithAlignment(Inst->getContext(), CommonAlignment));
+        ST = Inst;
+      }
 
       Value *V = propagateMetadata(ST, E->Scalars);
 
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll
index 0dfa45da9d87f4..56e8829b0ec68b 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/strided-stores-vectorized.ll
@@ -4,33 +4,12 @@
 define void @store_reverse(ptr %p3) {
 ; CHECK-LABEL: @store_reverse(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = load i64, ptr [[P3:%.*]], align 8
-; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 8
-; CHECK-NEXT:    [[TMP1:%.*]] = load i64, ptr [[ARRAYIDX1]], align 8
-; CHECK-NEXT:    [[SHL:%.*]] = shl i64 [[TMP0]], [[TMP1]]
-; CHECK-NEXT:    [[ARRAYIDX2:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 7
-; CHECK-NEXT:    store i64 [[SHL]], ptr [[ARRAYIDX2]], align 8
-; CHECK-NEXT:    [[ARRAYIDX3:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 1
-; CHECK-NEXT:    [[TMP2:%.*]] = load i64, ptr [[ARRAYIDX3]], align 8
-; CHECK-NEXT:    [[ARRAYIDX4:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 9
-; CHECK-NEXT:    [[TMP3:%.*]] = load i64, ptr [[ARRAYIDX4]], align 8
-; CHECK-NEXT:    [[SHL5:%.*]] = shl i64 [[TMP2]], [[TMP3]]
-; CHECK-NEXT:    [[ARRAYIDX6:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 6
-; CHECK-NEXT:    store i64 [[SHL5]], ptr [[ARRAYIDX6]], align 8
-; CHECK-NEXT:    [[ARRAYIDX7:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 2
-; CHECK-NEXT:    [[TMP4:%.*]] = load i64, ptr [[ARRAYIDX7]], align 8
-; CHECK-NEXT:    [[ARRAYIDX8:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 10
-; CHECK-NEXT:    [[TMP5:%.*]] = load i64, ptr [[ARRAYIDX8]], align 8
-; CHECK-NEXT:    [[SHL9:%.*]] = shl i64 [[TMP4]], [[TMP5]]
-; CHECK-NEXT:    [[ARRAYIDX10:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 5
-; CHECK-NEXT:    store i64 [[SHL9]], ptr [[ARRAYIDX10]], align 8
-; CHECK-NEXT:    [[ARRAYIDX11:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 3
-; CHECK-NEXT:    [[TMP6:%.*]] = load i64, ptr [[ARRAYIDX11]], align 8
-; CHECK-NEXT:    [[ARRAYIDX12:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 11
-; CHECK-NEXT:    [[TMP7:%.*]] = load i64, ptr [[ARRAYIDX12]], align 8
-; CHECK-NEXT:    [[SHL13:%.*]] = shl i64 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[ARRAYIDX1:%.*]] = getelementptr inbounds i64, ptr [[P3:%.*]], i64 8
 ; CHECK-NEXT:    [[ARRAYIDX14:%.*]] = getelementptr inbounds i64, ptr [[P3]], i64 4
-; CHECK-NEXT:    store i64 [[SHL13]], ptr [[ARRAYIDX14]], align 8
+; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i64>, ptr [[P3]], align 8
+; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i64>, ptr [[ARRAYIDX1]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = shl <4 x i64> [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    call void @llvm.experimental.vp.strided.store.v4i64.p0.i64(<4 x i64> [[TMP2]], ptr align 8 [[ARRAYIDX14]], i64 -8, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, i32 4)
 ; CHECK-NEXT:    ret void
 ;
 entry:

>From 75b96edc24f1f7e7048de1ae019275c0f4c0256c Mon Sep 17 00:00:00 2001
From: Alexey Bataev <a.bataev at outlook.com>
Date: Mon, 29 Apr 2024 13:08:21 +0000
Subject: [PATCH 2/2] Fix formatting

Created using spr 1.3.5
---
 llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2cca7130a0be70..527a0a347b71b7 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7861,7 +7861,8 @@ void BoUpSLP::transformNodes() {
       break;
     }
     case Instruction::Store: {
-      Type *ScalarTy = cast<StoreInst>(E.getMainOp())->getValueOperand()->getType();
+      Type *ScalarTy =
+          cast<StoreInst>(E.getMainOp())->getValueOperand()->getType();
       auto *VecTy = FixedVectorType::get(ScalarTy, E.Scalars.size());
       Align CommonAlignment = computeCommonAlignment<StoreInst>(E.Scalars);
       // Check if profitable to represent consecutive load + reverse as strided



More information about the llvm-commits mailing list