[llvm] [RISCV] Combine vp_strided_load with zero stride to scalar load + splat (PR #97798)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 5 01:28:39 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)


This is another version of #97394, but it performs the transform as a DAGCombine instead of during lowering, so that we have a better chance of detecting non-zero EVLs before they are legalized.
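
As an illustration of the shape of input the combine looks for, here is a minimal sketch (the function name is made up and this is not part of the patch; it essentially mirrors the new zero_strided_vec_length_avl test in the diff). The comments describe the intended SelectionDAG rewrite:

```llvm
; Sketch only: a zero-strided VP load with an all-ones mask and an EVL that is
; provably non-zero (vscale_range(2, 1024) guarantees vscale >= 2, so
; vscale * 2 >= 4). Stride 0 means every active lane reads the same address,
; so the combine replaces the strided load node with a scalar load of %ptr
; plus a splat of the loaded value (the scalar load's chain takes over the
; strided load's chain result).
define <vscale x 2 x i64> @zero_stride_sketch(ptr %ptr, <vscale x 2 x i64> %v) vscale_range(2, 1024) {
  %vscale = call i32 @llvm.vscale.i32()
  %evl = mul i32 %vscale, 2
  %load = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %ptr, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %evl)
  %res = add <vscale x 2 x i64> %v, %load
  ret <vscale x 2 x i64> %res
}

declare i32 @llvm.vscale.i32()
declare <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr, i32, <vscale x 2 x i1>, i32)
```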

The existing riscv_masked_strided_load combine already does this, but this combine also checks that the vector element type is legal. Currently a riscv_masked_strided_load of nxv1i64 with a zero stride will crash on rv32, but I'm hoping we can remove the masked_strided intrinsics and replace them with their VP counterparts.
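
To make the legality check concrete, here is a hedged sketch of the nxv1i64 case (function name made up; it mirrors the new zero_strided_known_notzero_avl test in the diff):

```llvm
; Sketch only: i64 elements, zero stride, all-ones mask, constant EVL of 1.
; On rv64 the i64 element type is a legal scalar type, so the combine fires
; and the test below checks for "ld" + "vadd.vx"; on rv32 it is not, so the
; combine is skipped and the zero-strided "vlse64.v ..., zero" remains.
define <vscale x 1 x i64> @rv32_legality_sketch(ptr %ptr, <vscale x 1 x i64> %v) {
  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
  %res = add <vscale x 1 x i64> %v, %load
  ret <vscale x 1 x i64> %res
}

declare <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr, i32, <vscale x 1 x i1>, i32)
```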

RISCVISelDAGToDAG will lower splats of scalar loads back to zero-strided loads anyway, so the test changes show how combining to a scalar load allows some .vx patterns to be matched.
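
For the case where the loaded value is only ever splatted, a hedged sketch (function name made up, not part of the patch):

```llvm
; Sketch only: the loaded value is returned directly and never consumed as a
; scalar operand. Per the note above, RISCVISelDAGToDAG is expected to select
; the splat of the scalar load back into a zero-strided vlse, so this kind of
; use should not get worse; the visible wins are cases like the vadd.vx in the
; new tests, where the splatted scalar can instead feed a .vx pattern.
define <vscale x 1 x i64> @splat_only_use_sketch(ptr %ptr) {
  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
  ret <vscale x 1 x i64> %load
}

declare <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr, i32, <vscale x 1 x i1>, i32)
```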


---
Full diff: https://github.com/llvm/llvm-project/pull/97798.diff


2 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+31-1) 
- (modified) llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll (+66) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 022b8bcedda4d..24e384fa64f1a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1502,7 +1502,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
     setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
-                         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
+                         ISD::VP_GATHER, ISD::VP_SCATTER,
+                         ISD::EXPERIMENTAL_VP_STRIDED_LOAD, ISD::SRA, ISD::SRL,
                          ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
                          ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
@@ -17108,6 +17109,35 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                               VPSN->getMemOperand(), IndexType);
     break;
   }
+  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: {
+    if (DCI.isBeforeLegalize())
+      break;
+    auto *Load = cast<VPStridedLoadSDNode>(N);
+    MVT VT = N->getSimpleValueType(0);
+
+    // Combine a zero strided load -> scalar load + splat
+    // The mask must be all ones and the EVL must be known to not be zero
+    if (!DAG.isKnownNeverZero(Load->getVectorLength()) ||
+        !Load->getOffset().isUndef() || !Load->isSimple() ||
+        !ISD::isConstantSplatVectorAllOnes(Load->getMask().getNode()) ||
+        !isNullConstant(Load->getStride()) ||
+        !isTypeLegal(VT.getVectorElementType()))
+      break;
+
+    SDValue ScalarLoad;
+    if (VT.isInteger())
+      ScalarLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, XLenVT, Load->getChain(),
+                                  Load->getBasePtr(), VT.getVectorElementType(),
+                                  Load->getMemOperand());
+    else
+      ScalarLoad = DAG.getLoad(VT.getVectorElementType(), DL, Load->getChain(),
+                               Load->getBasePtr(), Load->getMemOperand());
+    SDValue Splat = VT.isFixedLengthVector()
+                        ? DAG.getSplatBuildVector(VT, DL, ScalarLoad)
+                        : DAG.getSplatVector(VT, DL, ScalarLoad);
+    return DAG.getMergeValues({Splat, SDValue(ScalarLoad.getNode(), 1)}, DL);
+    break;
+  }
   case RISCVISD::SHL_VL:
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
index 4d3bced0bcb50..c19ecbb75d818 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
@@ -780,3 +780,69 @@ define <vscale x 16 x double> @strided_load_nxv17f64(ptr %ptr, i64 %stride, <vsc
 declare <vscale x 17 x double> @llvm.experimental.vp.strided.load.nxv17f64.p0.i64(ptr, i64, <vscale x 17 x i1>, i32)
 declare <vscale x 1 x double> @llvm.experimental.vector.extract.nxv1f64(<vscale x 17 x double> %vec, i64 %idx)
 declare <vscale x 16 x double> @llvm.experimental.vector.extract.nxv16f64(<vscale x 17 x double> %vec, i64 %idx)
+
+define <vscale x 1 x i64> @zero_strided_zero_evl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-LABEL: zero_strided_zero_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 0)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_not_known_notzero_evl(ptr %ptr, <vscale x 1 x i64> %v, i32 zeroext %evl) {
+; CHECK-LABEL: zero_strided_not_known_notzero_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v9, (a0), zero
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_known_notzero_avl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-RV32-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
+; CHECK-RV32-NEXT:    vlse64.v v9, (a0), zero
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-RV32-NEXT:    vadd.vv v8, v8, v9
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    ld a0, 0(a0)
+; CHECK-RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
+; CHECK-RV64-NEXT:    vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 2 x i64> @zero_strided_vec_length_avl(ptr %ptr, <vscale x 2 x i64> %v) vscale_range(2, 1024) {
+; CHECK-RV32-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    csrr a1, vlenb
+; CHECK-RV32-NEXT:    srli a1, a1, 2
+; CHECK-RV32-NEXT:    vsetvli zero, a1, e64, m2, ta, ma
+; CHECK-RV32-NEXT:    vlse64.v v10, (a0), zero
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-RV32-NEXT:    vadd.vv v8, v8, v10
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    ld a0, 0(a0)
+; CHECK-RV64-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
+; CHECK-RV64-NEXT:    vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT:    ret
+  %vscale = call i32 @llvm.vscale()
+  %veclen = mul i32 %vscale, 2
+  %load = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %ptr, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %veclen)
+  %res = add <vscale x 2 x i64> %v, %load
+  ret <vscale x 2 x i64> %res
+}

``````````



https://github.com/llvm/llvm-project/pull/97798

