[llvm] 5352c79 - [RISCV] Add a combine to form masked.load from unit strided load (#65674)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 11 13:01:18 PDT 2023


Author: Philip Reames
Date: 2023-09-11T13:01:14-07:00
New Revision: 5352c7939806b19429f6c7918c400be43797b4c7

URL: https://github.com/llvm/llvm-project/commit/5352c7939806b19429f6c7918c400be43797b4c7
DIFF: https://github.com/llvm/llvm-project/commit/5352c7939806b19429f6c7918c400be43797b4c7.diff

LOG: [RISCV] Add a combine to form masked.load from unit strided load (#65674)

Add a DAG combine that forms a masked.load from a riscv_masked_strided_load
intrinsic whose stride equals the element size in bytes. This covers a couple
of extra test cases, and allows us to simplify and share (common up) some
existing code in the concat_vectors(load, ...) to strided load transform.

This is the first patch in a short series that aims to generalize our
strided load and gather matching to handle more cases, and to unify the
different approaches currently used for the same problems in different places.

Added: 
    

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
    llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 4ff264635cda248..1158d14002e1d2a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13371,27 +13371,6 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
       return SDValue();
   }
 
-  // A special case is if the stride is exactly the width of one of the loads,
-  // in which case it's contiguous and can be combined into a regular vle
-  // without changing the element size
-  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed &&
-      ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
-    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-        BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
-        VT.getStoreSize(), Align);
-    // Can't do the combine if the load isn't naturally aligned with the element
-    // type
-    if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
-                                            DAG.getDataLayout(), VT, *MMO))
-      return SDValue();
-
-    SDValue WideLoad = DAG.getLoad(VT, DL, BaseLd->getChain(), BasePtr, MMO);
-    for (SDValue Ld : N->ops())
-      DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), WideLoad);
-    return WideLoad;
-  }
-
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
       BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements();
@@ -13406,20 +13385,22 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
-  MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT);
-  SDValue VL =
-      getDefaultVLOps(WideVecVT, ContainerVT, DL, DAG, Subtarget).second;
-  SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+  SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID =
-      DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+    DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
+                          Subtarget.getXLenVT());
   if (Reversed)
     Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+  SDValue AllOneMask =
+    DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
+                 DAG.getConstant(1, DL, MVT::i1));
+
   SDValue Ops[] = {BaseLd->getChain(),
                    IntID,
-                   DAG.getUNDEF(ContainerVT),
+                   DAG.getUNDEF(WideVecVT),
                    BasePtr,
                    Stride,
-                   VL};
+                   AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
@@ -13441,11 +13422,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   for (SDValue Ld : N->ops())
     DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);
 
-  // Note: Perform the bitcast before the convertFromScalableVector so we have
-  // balanced pairs of convertFromScalable/convertToScalable
-  SDValue Res = DAG.getBitcast(
-      TLI.getContainerForFixedLengthVector(VT.getSimpleVT()), StridedLoad);
-  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+  return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
 }
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
@@ -14184,6 +14161,25 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       // By default we do not combine any intrinsic.
     default:
       return SDValue();
+    case Intrinsic::riscv_masked_strided_load: {
+      MVT VT = N->getSimpleValueType(0);
+      auto *Load = cast<MemIntrinsicSDNode>(N);
+      SDValue PassThru = N->getOperand(2);
+      SDValue Base = N->getOperand(3);
+      SDValue Stride = N->getOperand(4);
+      SDValue Mask = N->getOperand(5);
+
+      // If the stride is equal to the element size in bytes,  we can use
+      // a masked.load.
+      const unsigned ElementSize = VT.getScalarStoreSize();
+      if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride);
+          StrideC && StrideC->getZExtValue() == ElementSize)
+        return DAG.getMaskedLoad(VT, DL, Load->getChain(), Base,
+                                 DAG.getUNDEF(XLenVT), Mask, PassThru,
+                                 Load->getMemoryVT(), Load->getMemOperand(),
+                                 ISD::UNINDEXED, ISD::NON_EXTLOAD);
+      return SDValue();
+    }
     case Intrinsic::riscv_vcpop:
     case Intrinsic::riscv_vcpop_mask:
     case Intrinsic::riscv_vfirst:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
index f7352b4659e5a9b..f3af177ac0ff27e 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -13010,9 +13010,8 @@ define <4 x i32> @mgather_broadcast_load_masked(ptr %base, <4 x i1> %m) {
 define <4 x i32> @mgather_unit_stride_load(ptr %base) {
 ; RV32-LABEL: mgather_unit_stride_load:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    li a1, 4
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vlse32.v v8, (a0), a1
+; RV32-NEXT:    vle32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_unit_stride_load:
@@ -13082,9 +13081,8 @@ define <4 x i32> @mgather_unit_stride_load_with_offset(ptr %base) {
 ; RV32-LABEL: mgather_unit_stride_load_with_offset:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    addi a0, a0, 16
-; RV32-NEXT:    li a1, 4
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vlse32.v v8, (a0), a1
+; RV32-NEXT:    vle32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_unit_stride_load_with_offset:

diff  --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
index c4653954302796e..06af0fc8971a543 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
@@ -55,9 +55,8 @@ define <32 x i8> @strided_load_i8_nostride(ptr %p, <32 x i1> %m) {
 ; CHECK-LABEL: strided_load_i8_nostride:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
-; CHECK-NEXT:    li a2, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m2, ta, ma
-; CHECK-NEXT:    vlse8.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
 ; CHECK-NEXT:    ret
   %res = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr %p, i64 1, <32 x i1> %m)
   ret <32 x i8> %res


        


More information about the llvm-commits mailing list