[llvm] [RISCVGatherScatterLowering] Support vp_gather and vp_scatter (PR #73612)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Nov 27 20:41:48 PST 2023
https://github.com/ShivaChen created https://github.com/llvm/llvm-project/pull/73612
Support transforming vp_gather into experimental_vp_strided_load and vp_scatter into experimental_vp_strided_store when the pointer operand is a recognizable strided access pattern.
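For illustration, a rough before/after sketch of the intended rewrite (the names %base, %stride, %mask and %evl are placeholders; in practice the index vector must come from a strided recurrence that determineBaseAndStride can analyze, as in the test added below, and the stride passed to the strided intrinsic is in bytes):

  declare <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr>, <vscale x 4 x i1>, i32)
  declare <vscale x 4 x i32> @llvm.experimental.vp.strided.load.nxv4i32.p0.i64(ptr, i64, <vscale x 4 x i1>, i32)

  ; Before: vp.gather over pointers of the form %base + i * %stride
  %ptrs = getelementptr i32, ptr %base, <vscale x 4 x i64> %vec.ind
  %v = call <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> align 4 %ptrs, <vscale x 4 x i1> %mask, i32 %evl)

  ; After: a single strided VP load, reusing the gather's mask and EVL operands
  %v = call <vscale x 4 x i32> @llvm.experimental.vp.strided.load.nxv4i32.p0.i64(ptr %base, i64 %stride, <vscale x 4 x i1> %mask, i32 %evl)

vp_scatter is handled the same way, mapping its value, mask and EVL operands onto experimental_vp_strided_store. Unlike masked_gather/masked_scatter, the VP intrinsics carry no explicit alignment argument, so the alignment is taken from the align parameter attribute on the pointer operand; that is why tryCreateStridedLoadStore now takes a MaybeAlign instead of the alignment operand.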
From 12c50c30d4963051836016978297ddfe3c2dcb47 Mon Sep 17 00:00:00 2001
From: Shiva Chen <shiva.chen at imgtec.com>
Date: Tue, 28 Nov 2023 03:08:22 +0000
Subject: [PATCH] [RISCVGatherScatterLowering] Support vp_gather and vp_scatter
---
.../RISCV/RISCVGatherScatterLowering.cpp | 48 +++++++++++---
.../CodeGen/RISCV/rvv/strided-load-store.ll | 63 +++++++++++++++++++
2 files changed, 101 insertions(+), 10 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 5ad1e082344e77a..973b970f8ce131d 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -65,7 +65,7 @@ class RISCVGatherScatterLowering : public FunctionPass {
private:
bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
- Value *AlignOp);
+ MaybeAlign MA);
std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
IRBuilderBase &Builder);
@@ -459,9 +459,8 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
Type *DataType,
Value *Ptr,
- Value *AlignOp) {
+ MaybeAlign MA) {
// Make sure the operation will be supported by the backend.
- MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
EVT DataTypeVT = TLI->getValueType(*DL, DataType);
if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
return false;
@@ -493,11 +492,22 @@ bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
Intrinsic::riscv_masked_strided_load,
{DataType, BasePtr->getType(), Stride->getType()},
{II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
- else
+ else if (II->getIntrinsicID() == Intrinsic::vp_gather)
+ Call = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vp_strided_load,
+ {DataType, BasePtr->getType(), Stride->getType()},
+ {BasePtr, Stride, II->getArgOperand(1), II->getArgOperand(2)});
+ else if (II->getIntrinsicID() == Intrinsic::masked_scatter)
Call = Builder.CreateIntrinsic(
Intrinsic::riscv_masked_strided_store,
{DataType, BasePtr->getType(), Stride->getType()},
{II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});
+ else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
+ Call = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vp_strided_store,
+ {DataType, BasePtr->getType(), Stride->getType()},
+ {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(2),
+ II->getArgOperand(3)});
Call->takeName(II);
II->replaceAllUsesWith(Call);
@@ -533,22 +543,40 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
- if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
+ if (!II)
+ continue;
+ if (II->getIntrinsicID() == Intrinsic::masked_gather ||
+ II->getIntrinsicID() == Intrinsic::vp_gather) {
Gathers.push_back(II);
- } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
+ } else if (II->getIntrinsicID() == Intrinsic::masked_scatter ||
+ II->getIntrinsicID() == Intrinsic::vp_scatter) {
Scatters.push_back(II);
}
}
}
// Rewrite gather/scatter to form strided load/store if possible.
- for (auto *II : Gathers)
+ MaybeAlign MA;
+ for (auto *II : Gathers) {
+ if (II->getIntrinsicID() == Intrinsic::masked_gather)
+ MA = cast<ConstantInt>(II->getArgOperand(1))->getMaybeAlignValue();
+ else if (II->getIntrinsicID() == Intrinsic::vp_gather)
+ MA = II->getAttributes().getParamAttrs(0).getAlignment();
+
Changed |= tryCreateStridedLoadStore(
- II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
- for (auto *II : Scatters)
+ II, II->getType(), II->getArgOperand(0), MA);
+ }
+
+ for (auto *II : Scatters) {
+ if (II->getIntrinsicID() == Intrinsic::masked_scatter)
+ MA = cast<ConstantInt>(II->getArgOperand(2))->getMaybeAlignValue();
+ else if (II->getIntrinsicID() == Intrinsic::vp_scatter)
+ MA = II->getAttributes().getParamAttrs(1).getAlignment();
+
Changed |=
tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
- II->getArgOperand(1), II->getArgOperand(2));
+ II->getArgOperand(1), MA);
+ }
// Remove any dead phis.
while (!MaybeDeadPHIs.empty()) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index fcb3742eb2363ba..e722bea1919b06f 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -230,6 +230,69 @@ define void @constant_stride(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
ret void
}
+define void @vp_gather_scatter(ptr %A, i64 %vl, i64 %stride, i64 %n.vec) {
+; CHECK-LABEL: @vp_gather_scatter(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[VECTOR_PH:%.*]]
+; CHECK: vector.ph:
+; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
+; CHECK: vector.body:
+; CHECK-NEXT: [[INDVARS_IV36:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDVARS_IV_NEXT37:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[EVL_BASED_IV:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_EVL_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[AVL:%.*]] = sub i64 100, [[EVL_BASED_IV]]
+; CHECK-NEXT: [[EVL:%.*]] = tail call i32 @llvm.experimental.get.vector.length.i64(i64 [[AVL]], i32 4, i1 true)
+; CHECK-NEXT: [[TMP0:%.*]] = getelementptr [200 x i32], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i64 [[INDVARS_IV36]]
+; CHECK-NEXT: [[WIDE_MASKED_GATHER:%.*]] = call <vscale x 4 x i32> @llvm.experimental.vp.strided.load.nxv4i32.p0.i64(ptr [[TMP0]], i64 800, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT: [[TMP1:%.*]] = shl nsw <vscale x 4 x i32> [[WIDE_MASKED_GATHER]], shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+; CHECK-NEXT: call void @llvm.experimental.vp.strided.store.nxv4i32.p0.i64(<vscale x 4 x i32> [[TMP1]], ptr [[TMP0]], i64 800, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT: [[INDEX_EVL_NEXT]] = add i64 [[EVL_BASED_IV]], [[TMP2]]
+; CHECK-NEXT: [[INDEX_NEXT]] = add i64 [[INDEX]], [[VL:%.*]]
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STRIDE:%.*]]
+; CHECK-NEXT: [[INDVARS_IV_NEXT37]] = add nuw nsw i64 [[INDVARS_IV36]], 1
+; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC:%.*]]
+; CHECK-NEXT: br i1 [[TMP3]], label [[END:%.*]], label [[VECTOR_BODY]]
+; CHECK: end:
+; CHECK-NEXT: ret void
+;
+entry:
+ %stepvector = tail call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+ %.splatinsert = insertelement <vscale x 4 x i64> poison, i64 %stride, i64 0
+ %.splat = shufflevector <vscale x 4 x i64> %.splatinsert, <vscale x 4 x i64> poison, <vscale x 4 x i32> zeroinitializer
+ br label %vector.ph
+
+vector.ph:                                        ; preds = %entry
+ br label %vector.body
+
+vector.body: ; preds = %vector.body, %vector.ph
+ %indvars.iv36 = phi i64 [ 0, %vector.ph ], [ %indvars.iv.next37, %vector.body ]
+ %index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
+ %evl.based.iv = phi i64 [ 0, %vector.ph ], [ %index.evl.next, %vector.body ]
+ %vec.ind = phi <vscale x 4 x i64> [ %stepvector, %vector.ph ], [ %vec.ind.next, %vector.body ]
+ %avl = sub i64 100, %evl.based.iv
+ %evl = tail call i32 @llvm.experimental.get.vector.length.i64(i64 %avl, i32 4, i1 true)
+ %0 = getelementptr inbounds [200 x i32], ptr %A, <vscale x 4 x i64> %vec.ind, i64 %indvars.iv36
+ %wide.masked.gather = tail call <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> align 4 %0, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 %evl)
+ %1 = shl nsw <vscale x 4 x i32> %wide.masked.gather, shufflevector (<vscale x 4 x i32> insertelement (<vscale x 4 x i32> poison, i32 1, i64 0), <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer)
+ tail call void @llvm.vp.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> %1, <vscale x 4 x ptr> align 4 %0, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer), i32 %evl)
+ %2 = zext i32 %evl to i64
+ %index.evl.next = add i64 %evl.based.iv, %2
+ %index.next = add i64 %index, %vl
+ %vec.ind.next = add <vscale x 4 x i64> %vec.ind, %.splat
+ %indvars.iv.next37 = add nuw nsw i64 %indvars.iv36, 1
+ %3 = icmp eq i64 %index.next, %n.vec
+ br i1 %3, label %end, label %vector.body
+
+end:
+ ret void
+}
+
declare i64 @llvm.vscale.i64()
declare void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64>, <vscale x 1 x ptr>, i32, <vscale x 1 x i1>)
declare <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr>, i32, <vscale x 1 x i1>, <vscale x 1 x i64>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+declare i32 @llvm.experimental.get.vector.length.i64(i64, i32 immarg, i1 immarg)
+declare <vscale x 4 x i32> @llvm.vp.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr>, <vscale x 4 x i1>, i32)
+declare void @llvm.vp.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32>, <vscale x 4 x ptr>, <vscale x 4 x i1>, i32)