[llvm] 5358968 - [RISCV] Pattern match scalable strided load/store
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Sat Sep 24 17:42:12 PDT 2022
Author: Philip Reames
Date: 2022-09-24T17:41:58-07:00
New Revision: 5358968e131cc486f236b0a5593d2201a609cdc1
URL: https://github.com/llvm/llvm-project/commit/5358968e131cc486f236b0a5593d2201a609cdc1
DIFF: https://github.com/llvm/llvm-project/commit/5358968e131cc486f236b0a5593d2201a609cdc1.diff
LOG: [RISCV] Pattern match scalable strided load/store
A very straightforward extension of the existing pattern-matching pass to handle scalable types as well as fixed-length types. The only extra bit, beyond removing a bailout, is recognizing stepvector.
Differential Revision: https://reviews.llvm.org/D134502
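To illustrate (a hand-written sketch, not taken from the patch's tests): with the new stepvector base case, a scalable gather whose pointers are built from a bare stepvector over an i64 array decomposes to start 0 and stride 1 element, so the pass can rewrite it to a single strided load whose byte stride is the element size:

  ; before: scalable gather with stepvector-derived indices
  %step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
  %ptrs = getelementptr i64, ptr %base, <vscale x 1 x i64> %step
  %v = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> %ptrs, i32 8, <vscale x 1 x i1> %mask, <vscale x 1 x i64> undef)

  ; after: one strided load, byte stride 8 = sizeof(i64)
  %v = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> undef, ptr %base, i64 8, <vscale x 1 x i1> %mask)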
Added:
Modified:
llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 2410cc1f8859..eada2bf5c09b 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -21,9 +21,11 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
+#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
+using namespace PatternMatch;
#define DEBUG_TYPE "riscv-gather-scatter-lowering"
@@ -139,6 +141,12 @@ static std::pair<Value *, Value *> matchStridedStart(Value *Start,
if (StartC)
return matchStridedConstant(StartC);
+ // Base case, start is a stepvector
+ if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
+ auto *Ty = Start->getType()->getScalarType();
+ return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
+ }
+
// Not a constant, maybe it's a strided constant with a splat added to it.
auto *BO = dyn_cast<BinaryOperator>(Start);
if (!BO || BO->getOpcode() != Instruction::Add)
@@ -482,11 +490,9 @@ bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
for (BasicBlock &BB : F) {
for (Instruction &I : BB) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
- if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
- isa<FixedVectorType>(II->getType())) {
+ if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
Gathers.push_back(II);
- } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
- isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
+ } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
Scatters.push_back(II);
}
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
index dd663be047a7..be151e4377c4 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -10,21 +10,18 @@ define <vscale x 1 x i64> @gather(ptr %a, i32 %len) {
; CHECK-NEXT: vector.ph:
; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64
; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[TMP0]], i64 0
-; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_IND:%.*]] = phi <vscale x 1 x i64> [ [[TMP1]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
; CHECK-NEXT: [[ACCUM:%.*]] = phi <vscale x 1 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[ACCUM_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], <vscale x 1 x i64> [[VEC_IND]], i32 3
-; CHECK-NEXT: [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[TMP2]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> undef)
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3
+; CHECK-NEXT: [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> undef, ptr [[TMP1]], i64 16, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: [[ACCUM_NEXT]] = add <vscale x 1 x i64> [[ACCUM]], [[GATHER]]
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]]
-; CHECK-NEXT: [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT]]
-; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT: br i1 [[TMP2]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
; CHECK: for.cond.cleanup:
; CHECK-NEXT: ret <vscale x 1 x i64> [[ACCUM_NEXT]]
;
@@ -57,19 +54,16 @@ define void @scatter(ptr %a, i32 %len) {
; CHECK-NEXT: vector.ph:
; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64
; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT: [[TMP1:%.*]] = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[TMP0]], i64 0
-; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[VEC_IND:%.*]] = phi <vscale x 1 x i64> [ [[TMP1]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], <vscale x 1 x i64> [[VEC_IND]], i32 3
-; CHECK-NEXT: tail call void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64> zeroinitializer, <vscale x 1 x ptr> [[TMP2]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT: [[TMP1:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3
+; CHECK-NEXT: call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> zeroinitializer, ptr [[TMP1]], i64 16, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]]
-; CHECK-NEXT: [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT]]
-; CHECK-NEXT: [[TMP3:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT: br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[TMP0]]
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT: br i1 [[TMP2]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
; CHECK: for.cond.cleanup:
; CHECK-NEXT: ret void
;
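The net effect visible in the test diffs: the riscv.masked.strided.* intrinsics take a scalar base pointer and a byte stride instead of a vector of pointers, so the vector induction variable (and the stepvector/splat setup feeding it) collapses to a scalar one. The operand shapes, copied from the CHECK lines above with the FileCheck captures replaced by plain names:

  ; strided load: (passthru, base ptr, byte stride, mask)
  %v = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> undef, ptr %p, i64 16, <vscale x 1 x i1> %m)
  ; strided store: (value, base ptr, byte stride, mask)
  call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> zeroinitializer, ptr %p, i64 16, <vscale x 1 x i1> %m)

The 16-byte stride here is the size of %struct.foo in these tests: successive lanes step by one struct element, so the distance between consecutive field-3 loads is the struct size.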