[PATCH] D146983: [RISCV] Handle non-recursive muls of strides in gather/scatter lowering

Luke Lau via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 27 09:17:21 PDT 2023


luke created this revision.
luke added reviewers: craig.topper, reames.
Herald added subscribers: jobnoorman, asb, pmatos, VincentWu, vkmr, frasercrmck, evandro, luismarques, apazos, sameer.abuasal, s.egerton, Jim, benna, psnobl, jocewei, PkmX, the_o, brucehoult, MartinMosbeck, rogfer01, edward-jones, zzheng, jrtc27, shiva0217, kito-cheng, niosHD, sabuasal, simoncook, johnrusso, rbar, hiraditya, arichardson.
Herald added a project: All.
luke requested review of this revision.
Herald added subscribers: llvm-commits, pcwang-thead, eopXD, MaskRay.
Herald added a project: LLVM.

The gather/scatter lowering pass can already fold multiplies of a step vector
into the stride in the recursive case, so this patch extends that folding to
the non-recursive case.
The logic can probably be shared between the two cases at some point, which
would also allow extending it to shls and ors.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D146983

Files:
  llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
  llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll


Index: llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
===================================================================
--- llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -91,12 +91,11 @@
 
 define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
 ; CHECK-LABEL: @gather_loopless(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
-; CHECK-NEXT:    [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -115,12 +114,11 @@
 
 define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
 ; CHECK-LABEL: @scatter_loopless(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
-; CHECK-NEXT:    [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
-; CHECK-NEXT:    call void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64> [[X:%.*]], <vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT:    call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> [[X:%.*]], ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret void
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
Index: llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -148,9 +148,11 @@
     return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
   }
 
-  // Not a constant, maybe it's a strided constant with a splat added to it.
+  // Not a constant, maybe it's a strided constant with a splat added or
+  // multipled.
   auto *BO = dyn_cast<BinaryOperator>(Start);
-  if (!BO || BO->getOpcode() != Instruction::Add)
+  if (!BO || (BO->getOpcode() != Instruction::Add &&
+              BO->getOpcode() != Instruction::Mul))
     return std::make_pair(nullptr, nullptr);
 
   // Look for an operand that is splatted.
@@ -169,10 +171,16 @@
   if (!Start)
     return std::make_pair(nullptr, nullptr);
 
-  // Add the splat value to the start.
   Builder.SetInsertPoint(BO);
   Builder.SetCurrentDebugLocation(DebugLoc());
-  Start = Builder.CreateAdd(Start, Splat);
+  // Add the splat value to the start
+  if (BO->getOpcode() == Instruction::Add)
+    Start = Builder.CreateAdd(Start, Splat);
+  // Or multiply the start and stride by the splat.
+  else if (BO->getOpcode() == Instruction::Mul) {
+    Start = Builder.CreateMul(Start, Splat);
+    Stride = Builder.CreateMul(Stride, Splat);
+  }
   return std::make_pair(Start, Stride);
 }
 


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D146983.508692.patch
Type: text/x-patch
Size: 4974 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20230327/e42f8045/attachment.bin>


More information about the llvm-commits mailing list