[llvm] d49f2c6 - [RISCV] Handle non-recursive muls of strides in gather/scatter lowering
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 27 12:22:21 PDT 2023
Author: Luke Lau
Date: 2023-03-27T20:21:50+01:00
New Revision: d49f2c6d40da986c09262c3e1a421a4cb7983c4b
URL: https://github.com/llvm/llvm-project/commit/d49f2c6d40da986c09262c3e1a421a4cb7983c4b
DIFF: https://github.com/llvm/llvm-project/commit/d49f2c6d40da986c09262c3e1a421a4cb7983c4b.diff
LOG: [RISCV] Handle non-recursive muls of strides in gather/scatter lowering
The gather/scatter lowering pass can already fold multiplies of a step
vector into the stride for the recursive case, so this extends that
folding to the non-recursive case.
The logic can probably be shared between the two at some point to extend
it to shls and ors.
Reviewed By: reames
Differential Revision: https://reviews.llvm.org/D146983
Added:
Modified:
llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 8a440ed29ac3..b1171dac6a09 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -148,9 +148,11 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
}
- // Not a constant, maybe it's a strided constant with a splat added to it.
+ // Not a constant, maybe it's a strided constant with a splat added or
+ // multiplied.
auto *BO = dyn_cast<BinaryOperator>(Start);
- if (!BO || BO->getOpcode() != Instruction::Add)
+ if (!BO || (BO->getOpcode() != Instruction::Add &&
+ BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);
// Look for an operand that is splatted.
@@ -169,10 +171,17 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
if (!Start)
return std::make_pair(nullptr, nullptr);
- // Add the splat value to the start.
Builder.SetInsertPoint(BO);
Builder.SetCurrentDebugLocation(DebugLoc());
- Start = Builder.CreateAdd(Start, Splat);
+ // Add the splat value to the start
+ if (BO->getOpcode() == Instruction::Add) {
+ Start = Builder.CreateAdd(Start, Splat);
+ }
+ // Or multiply the start and stride by the splat.
+ else if (BO->getOpcode() == Instruction::Mul) {
+ Start = Builder.CreateMul(Start, Splat);
+ Stride = Builder.CreateMul(Stride, Splat);
+ }
return std::make_pair(Start, Stride);
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 31fcf10fa380..bcc73e039977 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -91,12 +91,11 @@ for.cond.cleanup: ; preds = %vector.body
define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
; CHECK-LABEL: @gather_loopless(
-; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
-; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
-; CHECK-NEXT: [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
-; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
-; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -115,12 +114,11 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
; CHECK-LABEL: @scatter_loopless(
-; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
-; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
-; CHECK-NEXT: [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
-; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
-; CHECK-NEXT: call void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64> [[X:%.*]], <vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> [[X:%.*]], ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: ret void
;
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
More information about the llvm-commits
mailing list