[llvm] [RISCVGatherScatterLowering] Support vp_gather and vp_scatter (PR #73612)

via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 27 20:41:48 PST 2023


https://github.com/ShivaChen created https://github.com/llvm/llvm-project/pull/73612

Support transforming vp_gather into experimental_vp_strided_load and vp_scatter into experimental_vp_strided_store when the vector of pointers forms a strided access.
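For illustration (not taken from the patch itself; the %ptrs, %base, %mask, and %evl names and the stride of 8 below are made up), the rewrite looks roughly like this: a vp.gather whose pointers advance by a constant byte stride becomes a single experimental.vp.strided.load from the scalar base pointer, reusing the original mask and EVL operands, and vp.scatter is rewritten to experimental.vp.strided.store in the same way.

  ; Before: gather through a vector of pointers stepping by 8 bytes.
  %v = call <vscale x 1 x i64> @llvm.vp.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> align 8 %ptrs, <vscale x 1 x i1> %mask, i32 %evl)

  ; After: one strided load from the scalar base pointer, same mask and EVL.
  %v = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i64(ptr %base, i64 8, <vscale x 1 x i1> %mask, i32 %evl)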

From 12c50c30d4963051836016978297ddfe3c2dcb47 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Tue, 28 Nov 2023 03:08:22 +0000
Subject: [PATCH] [RISCVGatherScatterLowering] Support vp_gather and vp_scatter

---
 .../RISCV/RISCVGatherScatterLowering.cpp      | 48 +++++++++++---
 .../CodeGen/RISCV/rvv/strided-load-store.ll   | 63 +++++++++++++++++++
 2 files changed, 101 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 5ad1e082344e77a..973b970f8ce131d 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -65,7 +65,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
 
 private:
   bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
-                                 Value *AlignOp);
+                                 MaybeAlign MA);
 
   std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                      IRBuilderBase &Builder);
@@ -459,9 +459,8 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
 bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                            Type *DataType,
                                                            Value *Ptr,
-                                                           Value *AlignOp) {
+                                                           MaybeAlign MA) {
   // Make sure the operation will be supported by the backend.
-  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
   EVT DataTypeVT = TLI->getValueType(*DL, DataType);
   if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
     return false;
@@ -493,11 +492,22 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
         Intrinsic::riscv_masked_strided_load,
         {DataType, BasePtr->getType(), Stride->getType()},
         {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
-  else
+  else if (II->getIntrinsicID() == Intrinsic::vp_gather)
+    Call = Builder.CreateIntrinsic(
+        Intrinsic::experimental_vp_strided_load,
+        {DataType, BasePtr->getType(), Stride->getType()},
+        {BasePtr, Stride, II->getArgOperand(1), II->getArgOperand(2)});
+  else if (II->getIntrinsicID() == Intrinsic::masked_scatter)
     Call = Builder.CreateIntrinsic(
         Intrinsic::riscv_masked_strided_store,
         {DataType, BasePtr->getType(), Stride->getType()},
         {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
+  else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
+    Call = Builder.CreateIntrinsic(
+        Intrinsic::experimental_vp_strided_store,
+        {DataType, BasePtr->getType(), Stride->getType()},
+        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(2),
+         II->getArgOperand(3)});
 
   Call->takeName(II);
   II->replaceAllUsesWith(Call);
@@ -533,22 +543,40 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
   for (BasicBlock &BB : F) {
     for (Instruction &I : BB) {
       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
-      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
+      if (!II)
+        continue;
+      if (II->getIntrinsicID() == Intrinsic::masked_gather ||
+          II->getIntrinsicID() == Intrinsic::vp_gather) {
         Gathers.push_back(II);
-      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
+      } else if (II->getIntrinsicID() == Intrinsic::masked_scatter ||
+                 II->getIntrinsicID() == Intrinsic::vp_scatter) {
         Scatters.push_back(II);
       }
     }
   }
 
   // Rewrite gather/scatter to form strided load/store if possible.
-  for (auto *II : Gathers)
+  MaybeAlign MA;
+  for (auto *II : Gathers) {
+    if (II->getIntrinsicID() == Intrinsic::masked_gather)
+      MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
+    else if (II->getIntrinsicID() == Intrinsic::vp_gather)
+      MA = II->getAttributes().getParamAttrs(0).getAlignment();
+
     Changed |= tryCreateStridedLoadStore(
-        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
-  for (auto *II : Scatters)
+        II, II->getType(), II->getArgOperand(0), MA);
+  }
+
+  for (auto *II : Scatters) {
+    if (II->getIntrinsicID() == Intrinsic::masked_scatter)
+      MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
+    else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
+      MA = II->getAttributes().getParamAttrs(1).getAlignment();
+
     Changed |=
         tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
-                                  II->getArgOperand(1), II->getArgOperand(2));
+                                  II->getArgOperand(1), MA);
+  }
 
   // Remove any dead phis.
   while (!MaybeDeadPHIs.empty()) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index fcb3742eb2363ba..e722bea1919b06f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -230,6 +230,69 @@ define void @constant_stride(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
   ret void
 }
 
+define void @vp_gather_scatter(ptr %A, i64 %vl, i64 %stride, i64 %n.vec) {
+; CHECK-LABEL: @vp_gather_scatter(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[VECTOR_PH:%.*]]
+; CHECK:       vector.ph:
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDVARS_IV36:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDVARS_IV_NEXT37:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[EVL_BASED_IV:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[AVL:%.*]] = sub i64 100, [[EVL_BASED_IV]]
+; CHECK-NEXT:    [[EVL:%.*]] = tail call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 4, i1 true)
+; CHECK-NEXT:    [[TMP0:%.*]] = getelementptr [200 x i32], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i64 [[INDVARS_IV36]]
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vp.strided.load.nxv4i32.p0.i64(ptr [[TMP0]], i64 800, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP1:%.*]] = shl nsw <vscale x 4 x i32> [[WIDE_MASKED_GATHER]], shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+; CHECK-NEXT:    call void @llvm.experimental.vp.strided.store.nxv4i32.p0.i64(<vscale x 4 x i32> [[TMP1]], ptr [[TMP0]], i64 800, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT:    [[INDEX_EVL_NEXT]] = add i64 [[EVL_BASED_IV]], [[TMP2]]
+; CHECK-NEXT:    [[INDEX_NEXT]] = add i64 [[INDEX]], [[VL:%.*]]
+; CHECK-NEXT:    [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STRIDE:%.*]]
+; CHECK-NEXT:    [[INDVARS_IV_NEXT37]] = add nuw nsw i64 [[INDVARS_IV36]], 1
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC:%.*]]
+; CHECK-NEXT:    br i1 [[TMP3]], label [[END:%.*]], label [[VECTOR_BODY]]
+; CHECK:       end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %stepvector = tail call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %.splatinsert = insertelement <vscale x 4 x i64> poison, i64 %stride, i64 0
+  %.splat = shufflevector <vscale x 4 x i64> %.splatinsert, <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+  br label %vector.ph
+
+vector.ph:                                        ; preds = %for.inc14, %entry
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %vector.ph
+  %indvars.iv36 = phi i64 [ 0, %vector.ph ], [ %indvars.iv.next37, %vector.body ]
+  %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+  %evl.based.iv = phi i64 [ 0, %vector.ph ], [ %index.evl.next, %vector.body ]
+  %vec.ind = phi <vscale x 4 x i64> [ %stepvector, %vector.ph ], [ %vec.ind.next, %vector.body ]
+  %avl = sub i64 100, %evl.based.iv
+  %evl = tail call i32 @llvm.experimental.get.vector.length.i64(i64 %avl, i32 4, i1 true)
+  %0 = getelementptr inbounds [200 x i32], ptr %A, <vscale x 4 x i64> %vec.ind, i64 %indvars.iv36
+  %wide.masked.gather = tail call <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> align 4 %0, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 %evl)
+  %1 = shl nsw <vscale x 4 x i32> %wide.masked.gather, shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+  tail call void @llvm.vp.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> %1, <vscale x 4 x ptr> align 4 %0, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 %evl)
+  %2 = zext i32 %evl to i64
+  %index.evl.next = add i64 %evl.based.iv, %2
+  %index.next = add i64 %index, %vl
+  %vec.ind.next = add <vscale x 4 x i64> %vec.ind, %.splat
+  %indvars.iv.next37 = add nuw nsw i64 %indvars.iv36, 1
+  %3 = icmp eq i64 %index.next, %n.vec
+  br i1 %3, label %end, label %vector.body
+
+end:
+  ret void
+}
+
 declare i64 @llvm.vscale.i64()
 declare void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64>, <vscale x 1 x ptr>, i32, <vscale x 1 x i1>)
 declare <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr>, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+declare i32 @llvm.experimental.get.vector.length.i64(i64, i32 immarg, i1 immarg)
+declare <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr>, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32>, <vscale x 4 x ptr>, <vscale x 4 x i1>, i32)



More information about the llvm-commits mailing list