[llvm] [RISCV] Combine vp_strided_load with zero stride to scalar load + splat (PR #97798)
Luke Lau via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 5 01:28:08 PDT 2024
https://github.com/lukel97 created https://github.com/llvm/llvm-project/pull/97798
This is another version of #97394, but performs the transform as a DAGCombine instead of during lowering, so that we have a better chance of detecting non-zero EVLs before they are legalized.
The riscv_masked_strided_load combine already does this, but this combine additionally checks that the vector element type is legal. Currently, a zero-strided riscv_masked_strided_load of nxv1i64 crashes on rv32, where i64 is not a legal scalar type. I'm hoping we can eventually remove the masked_strided intrinsics and replace them with their VP counterparts.
RISCVISelDAGToDAG will turn splats of scalar loads back into zero-strided loads anyway, so the test changes are there to show how combining to a scalar load lets some .vx patterns be matched.
>From 7bf10212fdb7500177825338c03a5b24927d0bcf Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 5 Jul 2024 12:04:24 +0800
Subject: [PATCH 1/2] Precommit tests
---
llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll | 52 +++++++++++++++++++
1 file changed, 52 insertions(+)
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
index 4d3bced0bcb50..39c0ed56188ef 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
@@ -780,3 +780,55 @@ define <vscale x 16 x double> @strided_load_nxv17f64(ptr %ptr, i64 %stride, <vsc
declare <vscale x 17 x double> @llvm.experimental.vp.strided.load.nxv17f64.p0.i64(ptr, i64, <vscale x 17 x i1>, i32)
declare <vscale x 1 x double> @llvm.experimental.vector.extract.nxv1f64(<vscale x 17 x double> %vec, i64 %idx)
declare <vscale x 16 x double> @llvm.experimental.vector.extract.nxv16f64(<vscale x 17 x double> %vec, i64 %idx)
+
+define <vscale x 1 x i64> @zero_strided_zero_evl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-LABEL: zero_strided_zero_evl:
+; CHECK: # %bb.0:
+; CHECK-NEXT: ret
+ %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 0)
+ %res = add <vscale x 1 x i64> %v, %load
+ ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_not_known_notzero_evl(ptr %ptr, <vscale x 1 x i64> %v, i32 zeroext %evl) {
+; CHECK-LABEL: zero_strided_not_known_notzero_evl:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a1, e64, m1, ta, ma
+; CHECK-NEXT: vlse64.v v9, (a0), zero
+; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+ %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
+ %res = add <vscale x 1 x i64> %v, %load
+ ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_known_notzero_avl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-LABEL: zero_strided_known_notzero_avl:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetivli zero, 1, e64, m1, ta, ma
+; CHECK-NEXT: vlse64.v v9, (a0), zero
+; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT: vadd.vv v8, v8, v9
+; CHECK-NEXT: ret
+ %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
+ %res = add <vscale x 1 x i64> %v, %load
+ ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 2 x i64> @zero_strided_vec_length_avl(ptr %ptr, <vscale x 2 x i64> %v) vscale_range(2, 1024) {
+; CHECK-LABEL: zero_strided_vec_length_avl:
+; CHECK: # %bb.0:
+; CHECK-NEXT: csrr a1, vlenb
+; CHECK-NEXT: srli a1, a1, 2
+; CHECK-NEXT: vsetvli zero, a1, e64, m2, ta, ma
+; CHECK-NEXT: vlse64.v v10, (a0), zero
+; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-NEXT: vadd.vv v8, v8, v10
+; CHECK-NEXT: ret
+ %vscale = call i32 @llvm.vscale()
+ %veclen = mul i32 %vscale, 2
+ %load = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %ptr, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %veclen)
+ %res = add <vscale x 2 x i64> %v, %load
+ ret <vscale x 2 x i64> %res
+}
>From af35a30f9bc0f938782c1b430b86db2747f09871 Mon Sep 17 00:00:00 2001
From: Luke Lau <luke at igalia.com>
Date: Fri, 5 Jul 2024 14:11:23 +0800
Subject: [PATCH 2/2] [RISCV] Combine vp_strided_load with zero stride to
scalar load + splat
This is another version of #97394, but performs the transform as a DAGCombine instead of during lowering, so that we have a better chance of detecting non-zero EVLs before they are legalized.
The riscv_masked_strided_load combine already does this, but this combine additionally checks that the vector element type is legal. Currently, a zero-strided riscv_masked_strided_load of nxv1i64 crashes on rv32, where i64 is not a legal scalar type. I'm hoping we can eventually remove the masked_strided intrinsics and replace them with their VP counterparts.
RISCVISelDAGToDAG will turn splats of scalar loads back into zero-strided loads anyway, so the test changes are there to show how combining to a scalar load lets some .vx patterns be matched.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 32 ++++++++++++-
llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll | 46 ++++++++++++-------
2 files changed, 61 insertions(+), 17 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 022b8bcedda4d..24e384fa64f1a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1502,7 +1502,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
if (Subtarget.hasVInstructions())
setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
- ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
+ ISD::VP_GATHER, ISD::VP_SCATTER,
+ ISD::EXPERIMENTAL_VP_STRIDED_LOAD, ISD::SRA, ISD::SRL,
ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
@@ -17108,6 +17109,35 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
VPSN->getMemOperand(), IndexType);
break;
}
+ case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: {
+ if (DCI.isBeforeLegalize())
+ break;
+ auto *Load = cast<VPStridedLoadSDNode>(N);
+ MVT VT = N->getSimpleValueType(0);
+
+ // Combine a zero strided load -> scalar load + splat
+ // The mask must be all ones and the EVL must be known to not be zero
+ if (!DAG.isKnownNeverZero(Load->getVectorLength()) ||
+ !Load->getOffset().isUndef() || !Load->isSimple() ||
+ !ISD::isConstantSplatVectorAllOnes(Load->getMask().getNode()) ||
+ !isNullConstant(Load->getStride()) ||
+ !isTypeLegal(VT.getVectorElementType()))
+ break;
+
+ SDValue ScalarLoad;
+ if (VT.isInteger())
+ ScalarLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, XLenVT, Load->getChain(),
+ Load->getBasePtr(), VT.getVectorElementType(),
+ Load->getMemOperand());
+ else
+ ScalarLoad = DAG.getLoad(VT.getVectorElementType(), DL, Load->getChain(),
+ Load->getBasePtr(), Load->getMemOperand());
+ SDValue Splat = VT.isFixedLengthVector()
+ ? DAG.getSplatBuildVector(VT, DL, ScalarLoad)
+ : DAG.getSplatVector(VT, DL, ScalarLoad);
+ return DAG.getMergeValues({Splat, SDValue(ScalarLoad.getNode(), 1)}, DL);
+ break;
+ }
case RISCVISD::SHL_VL:
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
return V;
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
index 39c0ed56188ef..c19ecbb75d818 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
@@ -804,28 +804,42 @@ define <vscale x 1 x i64> @zero_strided_not_known_notzero_evl(ptr %ptr, <vscale
}
define <vscale x 1 x i64> @zero_strided_known_notzero_avl(ptr %ptr, <vscale x 1 x i64> %v) {
-; CHECK-LABEL: zero_strided_known_notzero_avl:
-; CHECK: # %bb.0:
-; CHECK-NEXT: vsetivli zero, 1, e64, m1, ta, ma
-; CHECK-NEXT: vlse64.v v9, (a0), zero
-; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
-; CHECK-NEXT: vadd.vv v8, v8, v9
-; CHECK-NEXT: ret
+; CHECK-RV32-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV32: # %bb.0:
+; CHECK-RV32-NEXT: vsetivli zero, 1, e64, m1, ta, ma
+; CHECK-RV32-NEXT: vlse64.v v9, (a0), zero
+; CHECK-RV32-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-RV32-NEXT: vadd.vv v8, v8, v9
+; CHECK-RV32-NEXT: ret
+;
+; CHECK-RV64-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV64: # %bb.0:
+; CHECK-RV64-NEXT: ld a0, 0(a0)
+; CHECK-RV64-NEXT: vsetvli a1, zero, e64, m1, ta, ma
+; CHECK-RV64-NEXT: vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT: ret
%load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
%res = add <vscale x 1 x i64> %v, %load
ret <vscale x 1 x i64> %res
}
define <vscale x 2 x i64> @zero_strided_vec_length_avl(ptr %ptr, <vscale x 2 x i64> %v) vscale_range(2, 1024) {
-; CHECK-LABEL: zero_strided_vec_length_avl:
-; CHECK: # %bb.0:
-; CHECK-NEXT: csrr a1, vlenb
-; CHECK-NEXT: srli a1, a1, 2
-; CHECK-NEXT: vsetvli zero, a1, e64, m2, ta, ma
-; CHECK-NEXT: vlse64.v v10, (a0), zero
-; CHECK-NEXT: vsetvli a0, zero, e64, m2, ta, ma
-; CHECK-NEXT: vadd.vv v8, v8, v10
-; CHECK-NEXT: ret
+; CHECK-RV32-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV32: # %bb.0:
+; CHECK-RV32-NEXT: csrr a1, vlenb
+; CHECK-RV32-NEXT: srli a1, a1, 2
+; CHECK-RV32-NEXT: vsetvli zero, a1, e64, m2, ta, ma
+; CHECK-RV32-NEXT: vlse64.v v10, (a0), zero
+; CHECK-RV32-NEXT: vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-RV32-NEXT: vadd.vv v8, v8, v10
+; CHECK-RV32-NEXT: ret
+;
+; CHECK-RV64-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV64: # %bb.0:
+; CHECK-RV64-NEXT: ld a0, 0(a0)
+; CHECK-RV64-NEXT: vsetvli a1, zero, e64, m2, ta, ma
+; CHECK-RV64-NEXT: vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT: ret
%vscale = call i32 @llvm.vscale()
%veclen = mul i32 %vscale, 2
%load = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %ptr, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %veclen)
More information about the llvm-commits
mailing list