[llvm] d49f2c6 - [RISCV] Handle non-recursive muls of strides in gather/scatter lowering

Luke Lau via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 27 12:22:21 PDT 2023


Author: Luke Lau
Date: 2023-03-27T20:21:50+01:00
New Revision: d49f2c6d40da986c09262c3e1a421a4cb7983c4b

URL: https://github.com/llvm/llvm-project/commit/d49f2c6d40da986c09262c3e1a421a4cb7983c4b
DIFF: https://github.com/llvm/llvm-project/commit/d49f2c6d40da986c09262c3e1a421a4cb7983c4b.diff

LOG: [RISCV] Handle non-recursive muls of strides in gather/scatter lowering

The gather/scatter lowering pass can already fold multiplies of a step
vector into the stride in the recursive case; this patch extends that
folding to the non-recursive case.
The logic could probably be shared between the two cases at some point,
which would also make it straightforward to extend it to shls and ors.

Reviewed By: reames

Differential Revision: https://reviews.llvm.org/D146983

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 8a440ed29ac3..b1171dac6a09 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -148,9 +148,11 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
     return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
   }
 
-  // Not a constant, maybe it's a strided constant with a splat added to it.
+  // Not a constant, maybe it's a strided constant with a splat added or
+  // multipled.
   auto *BO = dyn_cast<BinaryOperator>(Start);
-  if (!BO || BO->getOpcode() != Instruction::Add)
+  if (!BO || (BO->getOpcode() != Instruction::Add &&
+              BO->getOpcode() != Instruction::Mul))
     return std::make_pair(nullptr, nullptr);
 
   // Look for an operand that is splatted.
@@ -169,10 +171,17 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   if (!Start)
     return std::make_pair(nullptr, nullptr);
 
-  // Add the splat value to the start.
   Builder.SetInsertPoint(BO);
   Builder.SetCurrentDebugLocation(DebugLoc());
-  Start = Builder.CreateAdd(Start, Splat);
+  // Add the splat value to the start
+  if (BO->getOpcode() == Instruction::Add) {
+    Start = Builder.CreateAdd(Start, Splat);
+  }
+  // Or multiply the start and stride by the splat.
+  else if (BO->getOpcode() == Instruction::Mul) {
+    Start = Builder.CreateMul(Start, Splat);
+    Stride = Builder.CreateMul(Stride, Splat);
+  }
   return std::make_pair(Start, Stride);
 }
 

diff  --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index 31fcf10fa380..bcc73e039977 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -91,12 +91,11 @@ for.cond.cleanup:                                 ; preds = %vector.body
 
 define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
 ; CHECK-LABEL: @gather_loopless(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
-; CHECK-NEXT:    [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
-; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
@@ -115,12 +114,11 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
 
 define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
 ; CHECK-LABEL: @scatter_loopless(
-; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[STRIDE:%.*]], i64 0
-; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
-; CHECK-NEXT:    [[OFFSETS:%.*]] = mul <vscale x 1 x i64> [[STEP]], [[SPLAT]]
-; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSETS]]
-; CHECK-NEXT:    call void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64> [[X:%.*]], <vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i64 1, [[STRIDE]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT:    call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> [[X:%.*]], ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    ret void
 ;
   %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()


        


More information about the llvm-commits mailing list