[llvm] [RISCV] Fold vp.store(vp.reverse(VAL), ADDR, MASK) -> vp.strided.store(VAL, NEW_ADDR, -1, MASK) (PR #123123)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 16 08:56:45 PST 2025


https://github.com/topperc updated https://github.com/llvm/llvm-project/pull/123123

From 8367247e1906a53ad9ca1667cf54423a61b9cb23 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 15 Jan 2025 13:15:55 -0800
Subject: [PATCH 1/2] [RISCV] Fold vp.store(vp.reverse(VAL), ADDR, MASK) ->
 vp.strided.store(VAL, NEW_ADDR, -1, MASK)

This was extracted from our downstream with only a quick re-review. It was originally written about 1.5 years ago, so there may be existing helper functions added upstream since then that could simplify it.
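
The fold replaces the reversed contiguous store with a single negatively-strided store: the new base points at the last active element, ADDR + (EVL - 1) * sizeof(element), and the stride is -sizeof(element). Written out at the IR level for the nxv2f32 case (illustrative only; the combine itself runs on SelectionDAG nodes, and the function names here are made up):

  ; Before: reverse the value, then store contiguously.
  define void @before(<vscale x 2 x float> %val, ptr %ptr, i32 zeroext %evl) {
    %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
    call void @llvm.vp.store.nxv2f32.p0(<vscale x 2 x float> %rev, ptr %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
    ret void
  }

  ; After: store %val directly, walking backwards from the last element.
  define void @after(<vscale x 2 x float> %val, ptr %ptr, i32 zeroext %evl) {
    %evl.zext = zext i32 %evl to i64
    %last.idx = add i64 %evl.zext, -1
    %offset = shl i64 %last.idx, 2      ; (EVL - 1) * 4 bytes for f32
    %base = getelementptr i8, ptr %ptr, i64 %offset
    call void @llvm.experimental.vp.strided.store.nxv2f32.p0.i64(<vscale x 2 x float> %val, ptr %base, i64 -4, <vscale x 2 x i1> splat (i1 true), i32 %evl)
    ret void
  }

  declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
  declare void @llvm.vp.store.nxv2f32.p0(<vscale x 2 x float>, ptr, <vscale x 2 x i1>, i32)
  declare void @llvm.experimental.vp.strided.store.nxv2f32.p0.i64(<vscale x 2 x float>, ptr, i64, <vscale x 2 x i1>, i32)

This matches the test_store_reverse_combiner output in the patch below: slli/add/addi compute ADDR + EVL * 4 - 4, and vsse32.v stores with a constant stride of -4.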

Co-authored-by: Brandon Wu <brandon.wu at sifive.com>
---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp   | 78 ++++++++++++++++--
 .../RISCV/rvv/vp-combine-store-reverse.ll     | 81 +++++++++++++++++++
 2 files changed, 153 insertions(+), 6 deletions(-)
 create mode 100644 llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index b25cb128bce9fb..fc2f7781eae443 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1524,13 +1524,17 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                          ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
-    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
-                         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
-                         ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
+    setTargetDAGCombine({ISD::FCOPYSIGN,    ISD::MGATHER,
+                         ISD::MSCATTER,     ISD::VP_GATHER,
+                         ISD::VP_SCATTER,   ISD::SRA,
+                         ISD::SRL,          ISD::SHL,
+                         ISD::STORE,        ISD::SPLAT_VECTOR,
                          ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
-                         ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
-                         ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
-                         ISD::INSERT_VECTOR_ELT, ISD::ABS, ISD::CTPOP});
+                         ISD::VP_STORE,     ISD::EXPERIMENTAL_VP_REVERSE,
+                         ISD::MUL,          ISD::SDIV,
+                         ISD::UDIV,         ISD::SREM,
+                         ISD::UREM,         ISD::INSERT_VECTOR_ELT,
+                         ISD::ABS,          ISD::CTPOP});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
@@ -16229,6 +16233,66 @@ static SDValue performBITREVERSECombine(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(RISCVISD::BREV8, DL, VT, Src.getOperand(0));
 }
 
+static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+  // Fold:
+  //   vp.store(vp.reverse(VAL), ADDR, MASK)
+  //     -> vp.strided.store(VAL, NEW_ADDR, -1, MASK)
+  auto *VPStore = cast<VPStoreSDNode>(N);
+
+  if (VPStore->getValue().getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE)
+    return SDValue();
+
+  SDValue VPReverse = VPStore->getValue();
+  EVT ReverseVT = VPReverse->getValueType(0);
+
+  // There is no strided store for mask vectors, and the EVLs of the
+  // vp.reverse and the vp.store must be the same.
+  if (!ReverseVT.getVectorElementType().isByteSized() ||
+      VPStore->getVectorLength() != VPReverse.getOperand(2) ||
+      !VPReverse.hasOneUse())
+    return SDValue();
+
+  SDValue StoreMask = VPStore->getMask();
+  // If Mask is not all ones, try to replace the mask if its opcode
+  // is EXPERIMENTAL_VP_REVERSE and its operand can be directly extracted.
+  if (!isOneOrOneSplat(StoreMask)) {
+    // Check that the mask of the vp.reverse in the vp.store is all ones
+    // and that the length of the mask is the same as the EVL.
+    if (StoreMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
+        !isOneOrOneSplat(StoreMask.getOperand(1)) ||
+        StoreMask.getOperand(2) != VPStore->getVectorLength())
+      return SDValue();
+    StoreMask = StoreMask.getOperand(0);
+  }
+
+  // Base = StoreAddr + (NumElem - 1) * ElemWidthByte
+  SDLoc DL(N);
+  MVT XLenVT = Subtarget.getXLenVT();
+  SDValue NumElem = VPStore->getVectorLength();
+  uint64_t ElemWidthByte = VPReverse.getValueType().getScalarSizeInBits() / 8;
+
+  SDValue Temp1 = DAG.getNode(ISD::SUB, DL, XLenVT, NumElem,
+                              DAG.getConstant(1, DL, XLenVT));
+  SDValue Temp2 = DAG.getNode(ISD::MUL, DL, XLenVT, Temp1,
+                              DAG.getConstant(ElemWidthByte, DL, XLenVT));
+  SDValue Base =
+      DAG.getNode(ISD::ADD, DL, XLenVT, VPStore->getBasePtr(), Temp2);
+  SDValue Stride = DAG.getConstant(0 - ElemWidthByte, DL, XLenVT);
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  MachinePointerInfo PtrInfo(VPStore->getAddressSpace());
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, VPStore->getMemOperand()->getFlags(),
+      LocationSize::beforeOrAfterPointer(), VPStore->getAlign());
+
+  return DAG.getStridedStoreVP(
+      VPStore->getChain(), DL, VPReverse.getOperand(0), Base,
+      VPStore->getOffset(), Stride, StoreMask, VPStore->getVectorLength(),
+      VPStore->getMemoryVT(), MMO, VPStore->getAddressingMode(),
+      VPStore->isTruncatingStore(), VPStore->isCompressingStore());
+}
+
 // Convert from one FMA opcode to another based on whether we are negating the
 // multiply result and/or the accumulator.
 // NOTE: Only supports RVV operations with VL.
@@ -18372,6 +18436,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     }
     }
   }
+  case ISD::VP_STORE:
+    return performVP_STORECombine(N, DAG, Subtarget);
   case ISD::BITCAST: {
     assert(Subtarget.useRVVForFixedLengthVectors());
     SDValue N0 = N->getOperand(0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll b/llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll
new file mode 100644
index 00000000000000..4896a1367935ac
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-combine-store-reverse.ll
@@ -0,0 +1,81 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+f,+v -verify-machineinstrs < %s | FileCheck %s
+
+define void @test_store_reverse_combiner(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_reverse_combiner:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vsse32.v v8, (a0), a2
+; CHECK-NEXT:    ret
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  ret void
+}
+
+define void @test_store_mask_is_vp_reverse(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_mask_is_vp_reverse:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    slli a2, a1, 2
+; CHECK-NEXT:    add a0, a2, a0
+; CHECK-NEXT:    addi a0, a0, -4
+; CHECK-NEXT:    li a2, -4
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vsse32.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    ret
+  %storemask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %storemask, i32 %evl)
+  ret void
+}
+
+define void @test_store_mask_not_all_one(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 zeroext %evl) {
+; CHECK-LABEL: test_store_mask_not_all_one:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, ta, ma
+; CHECK-NEXT:    vid.v v9, v0.t
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vrsub.vx v9, v9, a1, v0.t
+; CHECK-NEXT:    vrgather.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vse32.v v10, (a0), v0.t
+; CHECK-NEXT:    ret
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> %notallones, i32 %evl)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %notallones, i32 %evl)
+  ret void
+}
+
+define void @test_different_evl(<vscale x 2 x float> %val, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %mask, i32 zeroext %evl1, i32 zeroext %evl2) {
+; CHECK-LABEL: test_different_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e16, mf2, ta, ma
+; CHECK-NEXT:    vid.v v9
+; CHECK-NEXT:    addi a1, a1, -1
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vmv.v.i v10, 0
+; CHECK-NEXT:    vmerge.vim v10, v10, 1, v0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vid.v v11
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
+; CHECK-NEXT:    vrsub.vx v9, v9, a1
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vrsub.vx v11, v11, a1
+; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
+; CHECK-NEXT:    vrgatherei16.vv v12, v10, v9
+; CHECK-NEXT:    vmsne.vi v0, v12, 0
+; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
+; CHECK-NEXT:    vrgather.vv v9, v8, v11
+; CHECK-NEXT:    vsetvli zero, a2, e32, m1, ta, ma
+; CHECK-NEXT:    vse32.v v9, (a0), v0.t
+; CHECK-NEXT:    ret
+  %storemask = call <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1> %mask, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  %rev = call <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float> %val, <vscale x 2 x i1> splat (i1 true), i32 %evl1)
+  call void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float> %rev, <vscale x 2 x float>* %ptr, <vscale x 2 x i1> %storemask, i32 %evl2)
+  ret void
+}
+
+declare <vscale x 2 x float> @llvm.experimental.vp.reverse.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
+declare <vscale x 2 x i1> @llvm.experimental.vp.reverse.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
+declare void @llvm.vp.store.nxv2f32.p0nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>* nocapture, <vscale x 2 x i1>, i32)
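
A note on the masked case handled by the combine: for vp.store(vp.reverse(VAL), ADDR, vp.reverse(MASK)), lane j of memory receives VAL[EVL-1-j] under MASK[EVL-1-j]; substituting i = EVL-1-j shows this is exactly a negatively-strided store of VAL guarded by the un-reversed MASK, which is why the combine peels the vp.reverse off the mask instead of materializing a reversed mask. An IR-level sketch of what test_store_mask_is_vp_reverse folds to (names are illustrative; the combine operates on DAG nodes):

  define void @masked_after(<vscale x 2 x float> %val, ptr %ptr,
                            <vscale x 2 x i1> %mask, i32 zeroext %evl) {
    ; Same base computation as the unmasked case: point at the last element.
    %evl.zext = zext i32 %evl to i64
    %last.idx = add i64 %evl.zext, -1
    %offset = shl i64 %last.idx, 2      ; (EVL - 1) * 4 bytes for f32
    %base = getelementptr i8, ptr %ptr, i64 %offset
    ; The original (un-reversed) mask guards the strided store directly.
    call void @llvm.experimental.vp.strided.store.nxv2f32.p0.i64(<vscale x 2 x float> %val, ptr %base, i64 -4, <vscale x 2 x i1> %mask, i32 %evl)
    ret void
  }

  declare void @llvm.experimental.vp.strided.store.nxv2f32.p0.i64(<vscale x 2 x float>, ptr, i64, <vscale x 2 x i1>, i32)

This lines up with the second test above: vsse32.v gains a v0.t mask operand while the base and stride computation is unchanged.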

From 395376d54e5b9bc30b9cdcb32bf40cfcb47f0101 Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Thu, 16 Jan 2025 08:52:02 -0800
Subject: [PATCH 2/2] fixup! update comments

---
 llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 350a647fdfc9f0..6d3b1bf2051709 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16316,11 +16316,10 @@ static SDValue performVP_STORECombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   SDValue StoreMask = VPStore->getMask();
-  // If Mask is not all ones, try to replace the mask if its opcode
-  // is EXPERIMENTAL_VP_REVERSE and its operand can be directly extracted.
+  // If Mask is all ones, then the store is unmasked and can be reversed.
   if (!isOneOrOneSplat(StoreMask)) {
-    // Check that the mask of the vp.reverse in the vp.store is all ones
-    // and that the length of the mask is the same as the EVL.
+    // If the mask is not all ones, we can reverse the store if the mask was
+    // also reversed by an unmasked vp.reverse with the same EVL.
     if (StoreMask.getOpcode() != ISD::EXPERIMENTAL_VP_REVERSE ||
         !isOneOrOneSplat(StoreMask.getOperand(1)) ||
         StoreMask.getOperand(2) != VPStore->getVectorLength())


