[llvm] 715a043 - [RISCVGatherScatterLowering] Support shl in non-recursive matching

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri May 12 12:22:08 PDT 2023


Author: Philip Reames
Date: 2023-05-12T12:21:33-07:00
New Revision: 715a04309069aaa87ee4952918fde339492e4295

URL: https://github.com/llvm/llvm-project/commit/715a04309069aaa87ee4952918fde339492e4295
DIFF: https://github.com/llvm/llvm-project/commit/715a04309069aaa87ee4952918fde339492e4295.diff

LOG: [RISCVGatherScatterLowering] Support shl in non-recursive matching

We can apply the same logic as for multiply, since a left shift is just a multiply by a power of two. Note that since shl is not commutative, we need to be careful to make sure that the splat is the RHS of the instruction.

Differential Revision: https://reviews.llvm.org/D150471

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 30c1e39022d5..5e527c60ca5f 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -135,15 +135,16 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   // multipled.
   auto *BO = dyn_cast<BinaryOperator>(Start);
   if (!BO || (BO->getOpcode() != Instruction::Add &&
+              BO->getOpcode() != Instruction::Shl &&
               BO->getOpcode() != Instruction::Mul))
     return std::make_pair(nullptr, nullptr);
 
   // Look for an operand that is splatted.
-  unsigned OtherIndex = 1;
-  Value *Splat = getSplatValue(BO->getOperand(0));
-  if (!Splat) {
-    Splat = getSplatValue(BO->getOperand(1));
-    OtherIndex = 0;
+  unsigned OtherIndex = 0;
+  Value *Splat = getSplatValue(BO->getOperand(1));
+  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
+    Splat = getSplatValue(BO->getOperand(0));
+    OtherIndex = 1;
   }
   if (!Splat)
     return std::make_pair(nullptr, nullptr);
@@ -158,13 +159,22 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
   Builder.SetCurrentDebugLocation(DebugLoc());
   // Add the splat value to the start or multiply the start and stride by the
   // splat.
-  if (BO->getOpcode() == Instruction::Add) {
+  switch (BO->getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case Instruction::Add:
     Start = Builder.CreateAdd(Start, Splat);
-  } else {
-    assert(BO->getOpcode() == Instruction::Mul && "Unexpected opcode");
+    break;
+  case Instruction::Mul:
     Start = Builder.CreateMul(Start, Splat);
     Stride = Builder.CreateMul(Stride, Splat);
+    break;
+  case Instruction::Shl:
+    Start = Builder.CreateShl(Start, Splat);
+    Stride = Builder.CreateShl(Stride, Splat);
+    break;
   }
+
   return std::make_pair(Start, Stride);
 }
 

diff  --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index f5ba1683ef4c..e5a2da6b8060 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -112,6 +112,94 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
   ret <vscale x 1 x i64> %x
 }
 
+define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
+; CHECK-LABEL: @straightline_offset_add(
+; CHECK-NEXT:    [[TMP1:%.*]] = add i64 0, [[OFFSET:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP2]], i64 4, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
+;
+  %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+  %splat.insert = insertelement <vscale x 1 x i64> poison, i64 %offset, i64 0
+  %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+  %offsetv = add <vscale x 1 x i64> %step, %splat
+  %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offsetv
+  %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+  <vscale x 1 x ptr> %ptrs,
+  i32 8,
+  <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+  <vscale x 1 x i64> poison
+  )
+  ret <vscale x 1 x i64> %x
+}
+
+define <vscale x 1 x i64> @straightline_offset_shl(ptr %p) {
+; CHECK-LABEL: @straightline_offset_shl(
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 0
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 32, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
+;
+  %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+  %splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
+  %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+  %offset = shl <vscale x 1 x i64> %step, %splat
+  %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
+  %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+  <vscale x 1 x ptr> %ptrs,
+  i32 8,
+  <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+  <vscale x 1 x i64> poison
+  )
+  ret <vscale x 1 x i64> %x
+}
+
+define <vscale x 1 x i64> @neg_shl_is_not_commutative(ptr %p) {
+; CHECK-LABEL: @neg_shl_is_not_commutative(
+; CHECK-NEXT:    [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+; CHECK-NEXT:    [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
+; CHECK-NEXT:    [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+; CHECK-NEXT:    [[OFFSET:%.*]] = shl <vscale x 1 x i64> [[SPLAT]], [[STEP]]
+; CHECK-NEXT:    [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSET]]
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
+;
+  %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+  %splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
+  %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+  %offset = shl <vscale x 1 x i64> %splat, %step
+  %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
+  %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+  <vscale x 1 x ptr> %ptrs,
+  i32 8,
+  <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+  <vscale x 1 x i64> poison
+  )
+  ret <vscale x 1 x i64> %x
+}
+
+define <vscale x 1 x i64> @straightline_offset_shl_nonc(ptr %p, i64 %shift) {
+; CHECK-LABEL: @straightline_offset_shl_nonc(
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i64 0, [[SHIFT:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 1, [[SHIFT]]
+; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT:    [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 1 x i64> [[X]]
+;
+  %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+  %splat.insert = insertelement <vscale x 1 x i64> poison, i64 %shift, i64 0
+  %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+  %offset = shl <vscale x 1 x i64> %step, %splat
+  %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
+  %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+  <vscale x 1 x ptr> %ptrs,
+  i32 8,
+  <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+  <vscale x 1 x i64> poison
+  )
+  ret <vscale x 1 x i64> %x
+}
+
 define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
 ; CHECK-LABEL: @scatter_loopless(
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]


        


More information about the llvm-commits mailing list