[llvm] fc7a1ed - [RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK). (#123115)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 16 08:20:20 PST 2025


Author: Craig Topper
Date: 2025-01-16T08:20:17-08:00
New Revision: fc7a1ed0ba5f437bc7f262f562e83488225f0152

URL: https://github.com/llvm/llvm-project/commit/fc7a1ed0ba5f437bc7f262f562e83488225f0152
DIFF: https://github.com/llvm/llvm-project/commit/fc7a1ed0ba5f437bc7f262f562e83488225f0152.diff

LOG: [RISCV] Fold vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK). (#123115)

Co-authored-by: Brandon Wu <brandon.wu at sifive.com>

Added: 
    llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll

Modified: 
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b25cb128bce9fb..f8a5ccc3023a4d 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16229,6 +16229,68 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_REVERSECombine(SDNode *N, SelectionDAG &DAG,
+                                        const RISCVSubtarget &Subtarget) {
+  // Fold:
+  //    vp.reverse(vp.load(ADDR, MASK)) -> vp.strided.load(ADDR, -1, MASK)
+  auto *VPLoad = dyn_cast<VPLoadSDNode>(N->getOperand(0));
+  // Only plain loads qualify: indexed loads have extra results, extending
+  // loads read elements narrower than LoadVT's (the stride is derived from
+  // LoadVT), and expanding loads do not map onto a reversed strided load.
+  if (!VPLoad || !VPLoad->isUnindexed() || VPLoad->isExpandingLoad() ||
+      VPLoad->getExtensionType() != ISD::NON_EXTLOAD)
+    return SDValue();
+
+  EVT LoadVT = VPLoad->getValueType(0);
+  SDValue EVL = VPLoad->getVectorLength();
+  // We do not have a strided_load version for masks, the EVL of vp.reverse
+  // and vp.load must be the same, and the outer vp.reverse must be unmasked.
+  if (!LoadVT.getVectorElementType().isByteSized() ||
+      N->getOperand(2) != EVL || !N->getOperand(0).hasOneUse() ||
+      !isOneOrOneSplat(N->getOperand(1)))
+    return SDValue();
+
+  SDValue LoadMask = VPLoad->getMask();
+  // If Mask is all ones, then load is unmasked and can be reversed.
+  if (!isOneOrOneSplat(LoadMask)) {
+    // If the mask is not all ones, we can reverse the load if the mask was also
+    // reversed by an unmasked vp.reverse with the same EVL.
+    if (LoadMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(LoadMask.getOperand(1)) ||
+        LoadMask.getOperand(2) != EVL)
+      return SDValue();
+    LoadMask = LoadMask.getOperand(0);
+  }
+
+  // Base = LoadAddr + (EVL - 1) * ElemWidthByte, Stride = -ElemWidthByte.
+  // The stride is built with getSignedConstant so the negative value is
+  // represented correctly when XLenVT is i32 (RV32).
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  int64_t ElemWidthByte = LoadVT.getScalarSizeInBits() / 8;
+  SDValue Temp1 =
+      DAG.getNode(ISD::SUB, DL, XLenVT, EVL, DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base = DAG.getNode(ISD::ADD, DL, XLenVT, VPLoad->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getSignedConstant(-ElemWidthByte, DL, XLenVT);
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo(VPLoad->getAddressSpace());
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, VPLoad->getMemOperand()->getFlags(),
+      LocationSize::beforeOrAfterPointer(), VPLoad->getAlign());
+
+  SDValue Ret = DAG.getStridedLoadVP(LoadVT, DL, VPLoad->getChain(), Base,
+                                     Stride, LoadMask, EVL, MMO,
+                                     /*IsExpanding=*/false);
+
+  // Redirect users of the old load's chain (value 1) to the strided load.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(VPLoad, 1), Ret.getValue(1));
+
+  return Ret;
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18434,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     }
   }
+  case ISD::EXPERIMENTAL_VP_REVERSE:
+    return performVP_REVERSECombine(N, DAG, Subtarget);
   case ISD::BITCAST: {
     assert(Subtarget.useRVVForFixedLengthVectors());
     SDValue N0 = N->getOperand(0);

diff  --git a/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll b/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
new file mode 100644
index 00000000000000..50e26bd1410700
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-combine-reverse-load.ll
@@ -0,0 +1,79 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+define <vscale x 2 x float> @test_reverse_load_combiner(ptr %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_reverse_load_combiner:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vlse32.v v8, (a0), a2
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(ptr %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_load_mask_is_vp_reverse(ptr %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_is_vp_reverse:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vlse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    ret
+  %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(ptr %ptr, <vscale x 2 x i1> %loadmask, i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_load_mask_not_all_one(ptr %ptr, <vscale x 2 x i1> %notallones, i32 zeroext %evl) {
+; CHECK-LABEL: test_load_mask_not_all_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v9, (a0), v0.t
+; CHECK-NEXT:    vid.v v8, v0.t
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vrsub.vx v10, v8, a1, v0.t
+; CHECK-NEXT:    vrgather.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(ptr %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> %notallones, i32 %evl)
+  ret <vscale x 2 x float> %rev
+}
+
+define <vscale x 2 x float> @test_different_evl(ptr %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_different_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    addi a3, a1, -1
+; CHECK-NEXT:    vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT:    vid.v v8
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.i v9, 0
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vrsub.vx v8, v8, a3
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmerge.vim v9, v9, 1, v0
+; CHECK-NEXT:    vrgatherei16.vv v10, v9, v8
+; CHECK-NEXT:    vmsne.vi v0, v10, 0
+; CHECK-NEXT:    vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT:    vle32.v v9, (a0), v0.t
+; CHECK-NEXT:    addi a2, a2, -1
+; CHECK-NEXT:    vid.v v8
+; CHECK-NEXT:    vrsub.vx v10, v8, a2
+; CHECK-NEXT:    vrgather.vv v8, v9, v10
+; CHECK-NEXT:    ret
+  %loadmask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  %load = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(ptr %ptr, <vscale x 2 x i1> %loadmask, i32 %evl2)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %load, <vscale x 2 x i1> splat (i1 true), i32 %evl2)
+  ret <vscale x 2 x float> %rev
+}
+
+declare <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0nxv2f32(ptr nocapture, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)


        


More information about the llvm-commits mailing list