[llvm] 715a043 - [RISCVGatherScatterLowering] Support shl in non-recursive matching
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Fri May 12 12:22:08 PDT 2023
Author: Philip Reames
Date: 2023-05-12T12:21:33-07:00
New Revision: 715a04309069aaa87ee4952918fde339492e4295
URL: https://github.com/llvm/llvm-project/commit/715a04309069aaa87ee4952918fde339492e4295
DIFF: https://github.com/llvm/llvm-project/commit/715a04309069aaa87ee4952918fde339492e4295.diff
LOG: [RISCVGatherScatterLowering] Support shl in non-recursive matching
We can apply the same logic as for multiply since a left shift is just a multiply by a power of two. Note that since shl is not commutative, we do need to be careful to make sure that the splat is the RHS of the instruction.
Differential Revision: https://reviews.llvm.org/D150471
Added:
Modified:
llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 30c1e39022d5..5e527c60ca5f 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -135,15 +135,16 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
// multipled.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || (BO->getOpcode() != Instruction::Add &&
+ BO->getOpcode() != Instruction::Shl &&
BO->getOpcode() != Instruction::Mul))
return std::make_pair(nullptr, nullptr);
// Look for an operand that is splatted.
- unsigned OtherIndex = 1;
- Value *Splat = getSplatValue(BO->getOperand(0));
- if (!Splat) {
- Splat = getSplatValue(BO->getOperand(1));
- OtherIndex = 0;
+ unsigned OtherIndex = 0;
+ Value *Splat = getSplatValue(BO->getOperand(1));
+ if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
+ Splat = getSplatValue(BO->getOperand(0));
+ OtherIndex = 1;
}
if (!Splat)
return std::make_pair(nullptr, nullptr);
@@ -158,13 +159,22 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
Builder.SetCurrentDebugLocation(DebugLoc());
// Add the splat value to the start or multiply the start and stride by the
// splat.
- if (BO->getOpcode() == Instruction::Add) {
+ switch (BO->getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode");
+ case Instruction::Add:
Start = Builder.CreateAdd(Start, Splat);
- } else {
- assert(BO->getOpcode() == Instruction::Mul && "Unexpected opcode");
+ break;
+ case Instruction::Mul:
Start = Builder.CreateMul(Start, Splat);
Stride = Builder.CreateMul(Stride, Splat);
+ break;
+ case Instruction::Shl:
+ Start = Builder.CreateShl(Start, Splat);
+ Stride = Builder.CreateShl(Stride, Splat);
+ break;
}
+
return std::make_pair(Start, Stride);
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index f5ba1683ef4c..e5a2da6b8060 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -112,6 +112,94 @@ define <vscale x 1 x i64> @gather_loopless(ptr %p, i64 %stride) {
ret <vscale x 1 x i64> %x
}
+define <vscale x 1 x i64> @straightline_offset_add(ptr %p, i64 %offset) {
+; CHECK-LABEL: @straightline_offset_add(
+; CHECK-NEXT: [[TMP1:%.*]] = add i64 0, [[OFFSET:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP2]], i64 4, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
+;
+ %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+ %splat.insert = insertelement <vscale x 1 x i64> poison, i64 %offset, i64 0
+ %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+ %offsetv = add <vscale x 1 x i64> %step, %splat
+ %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offsetv
+ %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+ <vscale x 1 x ptr> %ptrs,
+ i32 8,
+ <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+ <vscale x 1 x i64> poison
+ )
+ ret <vscale x 1 x i64> %x
+}
+
+define <vscale x 1 x i64> @straightline_offset_shl(ptr %p) {
+; CHECK-LABEL: @straightline_offset_shl(
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 0
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP1]], i64 32, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
+;
+ %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+ %splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
+ %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+ %offset = shl <vscale x 1 x i64> %step, %splat
+ %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
+ %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+ <vscale x 1 x ptr> %ptrs,
+ i32 8,
+ <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+ <vscale x 1 x i64> poison
+ )
+ ret <vscale x 1 x i64> %x
+}
+
+define <vscale x 1 x i64> @neg_shl_is_not_commutative(ptr %p) {
+; CHECK-LABEL: @neg_shl_is_not_commutative(
+; CHECK-NEXT: [[STEP:%.*]] = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+; CHECK-NEXT: [[SPLAT_INSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
+; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[SPLAT_INSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+; CHECK-NEXT: [[OFFSET:%.*]] = shl <vscale x 1 x i64> [[SPLAT]], [[STEP]]
+; CHECK-NEXT: [[PTRS:%.*]] = getelementptr i32, ptr [[P:%.*]], <vscale x 1 x i64> [[OFFSET]]
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[PTRS]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> poison)
+; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
+;
+ %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+ %splat.insert = insertelement <vscale x 1 x i64> poison, i64 3, i64 0
+ %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+ %offset = shl <vscale x 1 x i64> %splat, %step
+ %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
+ %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+ <vscale x 1 x ptr> %ptrs,
+ i32 8,
+ <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+ <vscale x 1 x i64> poison
+ )
+ ret <vscale x 1 x i64> %x
+}
+
+define <vscale x 1 x i64> @straightline_offset_shl_nonc(ptr %p, i64 %shift) {
+; CHECK-LABEL: @straightline_offset_shl_nonc(
+; CHECK-NEXT: [[TMP1:%.*]] = shl i64 0, [[SHIFT:%.*]]
+; CHECK-NEXT: [[TMP2:%.*]] = shl i64 1, [[SHIFT]]
+; CHECK-NEXT: [[TMP3:%.*]] = getelementptr i32, ptr [[P:%.*]], i64 [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP2]], 4
+; CHECK-NEXT: [[X:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> poison, ptr [[TMP3]], i64 [[TMP4]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT: ret <vscale x 1 x i64> [[X]]
+;
+ %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
+ %splat.insert = insertelement <vscale x 1 x i64> poison, i64 %shift, i64 0
+ %splat = shufflevector <vscale x 1 x i64> %splat.insert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
+ %offset = shl <vscale x 1 x i64> %step, %splat
+ %ptrs = getelementptr i32, ptr %p, <vscale x 1 x i64> %offset
+ %x = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(
+ <vscale x 1 x ptr> %ptrs,
+ i32 8,
+ <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 1, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer),
+ <vscale x 1 x i64> poison
+ )
+ ret <vscale x 1 x i64> %x
+}
+
define void @scatter_loopless(<vscale x 1 x i64> %x, ptr %p, i64 %stride) {
; CHECK-LABEL: @scatter_loopless(
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 0, [[STRIDE:%.*]]
More information about the llvm-commits
mailing list