[llvm] [RISCV] Add a combine to form masked.load from unit strided load (PR #65674)
Philip Reames via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 7 14:05:01 PDT 2023
https://github.com/preames created https://github.com/llvm/llvm-project/pull/65674:
Add a DAG combine to form a masked.load from a masked_strided_load intrinsic whose stride equals the element size. This picks up a couple of extra test cases and lets us simplify and common up some existing code in the concat_vector(load, ...) to strided-load transform.
This is the first in a mini patch series that tries to generalize our strided-load and gather matching to handle more cases, and to common up the different approaches we currently take to the same problems in different places.
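To illustrate the equivalence the combine exploits, here is a minimal IR-level sketch (not taken from the patch or its tests; the function name, types, and the assumption of natural alignment are made up for the example). A masked strided load whose constant stride matches the element size reads contiguous memory, so it is equivalent to an ordinary masked load:
declare <4 x i32> @llvm.riscv.masked.strided.load.v4i32.p0.i64(<4 x i32>, ptr, i64, <4 x i1>)
define <4 x i32> @unit_stride_sketch(ptr %p, <4 x i1> %m, <4 x i32> %passthru) {
  ; A stride of 4 bytes equals sizeof(i32), so the four lanes are contiguous in memory.
  %v = call <4 x i32> @llvm.riscv.masked.strided.load.v4i32.p0.i64(<4 x i32> %passthru, ptr %p, i64 4, <4 x i1> %m)
  ; The combine can therefore select the equivalent of:
  ;   call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4, <4 x i1> %m, <4 x i32> %passthru)
  ; which ends up as a unit-stride vle32.v under the mask instead of a strided vlse32.v.
  ret <4 x i32> %v
}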
From 05bcba30908886c556c3ebdbefe2f137f4e7afd8 Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Thu, 7 Sep 2023 12:52:34 -0700
Subject: [PATCH] [RISCV] Add a combine to form masked.load from unit strided
load
Add a DAG combine to form a masked.load from a masked_strided_load intrinsic whose stride equals the element size. This picks up a couple of extra test cases and lets us simplify and common up some existing code in the concat_vector(load, ...) to strided-load transform.
---
llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 62 +++++++++----------
.../RISCV/rvv/fixed-vectors-masked-gather.ll | 6 +-
.../rvv/strided-load-store-intrinsics.ll | 3 +-
3 files changed, 32 insertions(+), 39 deletions(-)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 05e656ac817027c..d8a5015cb8a46e9 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13263,27 +13263,6 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
- // A special case is if the stride is exactly the width of one of the loads,
- // in which case it's contiguous and can be combined into a regular vle
- // without changing the element size
- if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
- ConstStride && !Reversed &&
- ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
- MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
- BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
- VT.getStoreSize(), Align);
- // Can't do the combine if the load isn't naturally aligned with the element
- // type
- if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
- DAG.getDataLayout(), VT, *MMO))
- return SDValue();
-
- SDValue WideLoad = DAG.getLoad(VT, DL, BaseLd->getChain(), BasePtr, MMO);
- for (SDValue Ld : N->ops())
- DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), WideLoad);
- return WideLoad;
- }
-
// Get the widened scalar type, e.g. v4i8 -> i64
unsigned WideScalarBitWidth =
BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements();
@@ -13298,20 +13277,22 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
return SDValue();
- MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT);
- SDValue VL =
- getDefaultVLOps(WideVecVT, ContainerVT, DL, DAG, Subtarget).second;
- SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+ SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
SDValue IntID =
- DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+ DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
+ Subtarget.getXLenVT());
if (Reversed)
Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+ SDValue AllOneMask =
+ DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
+ DAG.getConstant(1, DL, MVT::i1));
+
SDValue Ops[] = {BaseLd->getChain(),
IntID,
- DAG.getUNDEF(ContainerVT),
+ DAG.getUNDEF(WideVecVT),
BasePtr,
Stride,
- VL};
+ AllOneMask};
uint64_t MemSize;
if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
@@ -13333,11 +13314,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
for (SDValue Ld : N->ops())
DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);
- // Note: Perform the bitcast before the convertFromScalableVector so we have
- // balanced pairs of convertFromScalable/convertToScalable
- SDValue Res = DAG.getBitcast(
- TLI.getContainerForFixedLengthVector(VT.getSimpleVT()), StridedLoad);
- return convertFromScalableVector(VT, Res, DAG, Subtarget);
+ return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
}
static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
@@ -14076,6 +14053,25 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
// By default we do not combine any intrinsic.
default:
return SDValue();
+ case Intrinsic::riscv_masked_strided_load: {
+ MVT VT = N->getSimpleValueType(0);
+ auto *Load = cast<MemIntrinsicSDNode>(N);
+ SDValue PassThru = N->getOperand(2);
+ SDValue Base = N->getOperand(3);
+ SDValue Stride = N->getOperand(4);
+ SDValue Mask = N->getOperand(5);
+
+ // If the stride is equal to the element size in bytes, we can use
+ // a masked.load.
+ const unsigned ElementSizeInBits = VT.getScalarType().getSizeInBits();
+ if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride);
+ StrideC && StrideC->getZExtValue() == ElementSizeInBits/8)
+ return DAG.getMaskedLoad(VT, DL, Load->getChain(), Base,
+ DAG.getUNDEF(XLenVT), Mask, PassThru,
+ Load->getMemoryVT(), Load->getMemOperand(),
+ ISD::UNINDEXED, ISD::NON_EXTLOAD);
+ return SDValue();
+ }
case Intrinsic::riscv_vcpop:
case Intrinsic::riscv_vcpop_mask:
case Intrinsic::riscv_vfirst:
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index dc52e69e5364d89..b8c3d8ee6dd6dc0 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -13051,9 +13051,8 @@ define <4 x i32> @mgather_broadcast_load_masked(ptr %base, <4 x i1> %m) {
define <4 x i32> @mgather_unit_stride_load(ptr %base) {
; RV32-LABEL: mgather_unit_stride_load:
; RV32: # %bb.0:
-; RV32-NEXT: li a1, 4
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT: vlse32.v v8, (a0), a1
+; RV32-NEXT: vle32.v v8, (a0)
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_unit_stride_load:
@@ -13123,9 +13122,8 @@ define <4 x i32> @mgather_unit_stride_load_with_offset(ptr %base) {
; RV32-LABEL: mgather_unit_stride_load_with_offset:
; RV32: # %bb.0:
; RV32-NEXT: addi a0, a0, 16
-; RV32-NEXT: li a1, 4
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT: vlse32.v v8, (a0), a1
+; RV32-NEXT: vle32.v v8, (a0)
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_unit_stride_load_with_offset:
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
index c4653954302796e..06af0fc8971a543 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
@@ -55,9 +55,8 @@ define <32 x i8> @strided_load_i8_nostride(ptr %p, <32 x i1> %m) {
; CHECK-LABEL: strided_load_i8_nostride:
; CHECK: # %bb.0:
; CHECK-NEXT: li a1, 32
-; CHECK-NEXT: li a2, 1
; CHECK-NEXT: vsetvli zero, a1, e8, m2, ta, ma
-; CHECK-NEXT: vlse8.v v8, (a0), a2, v0.t
+; CHECK-NEXT: vle8.v v8, (a0), v0.t
; CHECK-NEXT: ret
%res = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr %p, i64 1, <32 x i1> %m)
ret <32 x i8> %res