[llvm] [RISCV] Sink vp.splat operands of VP intrinsic. (PR #133245)

via llvm-commits llvm-commits at lists.llvm.org
Thu Mar 27 05:53:37 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-backend-risc-v

Author: MingYan (NexMing)

Changes

This patch sinks `vp.splat` operands of VP intrinsics back into the same basic block as the VP operation that uses them. Placing the splat next to its user lets instruction selection fold it into a `.vx` instruction form, which keeps the splatted scalar in a GPR and reduces vector register pressure.
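
As an illustrative sketch of the pattern this targets (hypothetical IR, not taken from the patch): a `vp.splat` defined outside the block of its VP user, with a matching EVL, keeps a whole vector register live across the block boundary; once sunk next to the user, instruction selection can fold it into a `.vx` form with the scalar held in a GPR.

```llvm
define void @example(ptr %p, i32 %x, i32 %evl) {
entry:
  ; Defined here, the splat would otherwise stay live in a vector
  ; register from entry into %body.
  %splat = call <vscale x 4 x i32> @llvm.experimental.vp.splat.nxv4i32(i32 %x, <vscale x 4 x i1> splat (i1 true), i32 %evl)
  br label %body

body:
  %v = call <vscale x 4 x i32> @llvm.vp.load.nxv4i32.p0(ptr %p, <vscale x 4 x i1> splat (i1 true), i32 %evl)
  ; With this patch, %splat is sunk next to this vp.mul (same block, same
  ; EVL), so ISel can select vmul.vx with %x kept in a scalar register.
  %m = call <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32> %v, <vscale x 4 x i32> %splat, <vscale x 4 x i1> splat (i1 true), i32 %evl)
  call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> %m, ptr %p, <vscale x 4 x i1> splat (i1 true), i32 %evl)
  ret void
}
```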

---
Full diff: https://github.com/llvm/llvm-project/pull/133245.diff


3 Files Affected:

- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp (+37) 
- (modified) llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (+2) 
- (modified) llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll (+67) 


``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index f49ad2f5bd20e..5f20b63b0127f 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -2772,6 +2772,40 @@ bool RISCVTTIImpl::canSplatOperand(Instruction *I, int Operand) const {
   }
 }
 
+bool RISCVTTIImpl::tryToSinkVPSplat(VPIntrinsic *VPI,
+                                    SmallVectorImpl<Use *> &Ops) const {
+  Value *EVL = VPI->getVectorLengthParam();
+  if (!EVL)
+    return false;
+
+  for (auto &Op : VPI->operands()) {
+    auto *I = dyn_cast<Instruction>(Op.get());
+    if (!I || I->getParent() == VPI->getParent() ||
+        llvm::is_contained(Ops, &Op))
+      continue;
+
+    // We are looking for a vp.splat that can be sunk.
+    if (!match(I, m_Intrinsic<Intrinsic::experimental_vp_splat>(
+                      m_Value(), m_AllOnes(), m_Specific(EVL))))
+      continue;
+
+    // Don't sink i1 splats.
+    if (I->getType()->getScalarType()->isIntegerTy(1))
+      continue;
+
+    // All uses of the vp.splat should be sunk to avoid duplicating it across
+    // GPR and vector registers.
+    if (all_of(I->uses(), [&](Use &U) {
+          auto *VPI = dyn_cast<VPIntrinsic>(U.getUser());
+          return VPI && VPI->getVectorLengthParam() == EVL &&
+                 canSplatOperand(VPI, U.getOperandNo());
+        }))
+      Ops.push_back(&Op);
+  }
+
+  return !Ops.empty();
+}
+
 /// Check if sinking \p I's operands to I's basic block is profitable, because
 /// the operands can be folded into a target instruction, e.g.
 /// splats of scalars can fold into vector instructions.
@@ -2823,6 +2857,9 @@ bool RISCVTTIImpl::isProfitableToSinkOperands(
   if (!ST->sinkSplatOperands())
     return false;
 
+  if (isa<VPIntrinsic>(I) && tryToSinkVPSplat(cast<VPIntrinsic>(I), Ops))
+    return true;
+
   for (auto OpIdx : enumerate(I->operands())) {
     if (!canSplatOperand(I, OpIdx.index()))
       continue;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
index 8ffe1b08d1e26..53fc76ef02372 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
@@ -454,6 +454,8 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
   /// able to splat the given operand.
   bool canSplatOperand(unsigned Opcode, int Operand) const;
 
+  bool tryToSinkVPSplat(VPIntrinsic *VPI, SmallVectorImpl<Use *> &Ops) const;
+
   bool isProfitableToSinkOperands(Instruction *I,
                                   SmallVectorImpl<Use *> &Ops) const;
 
diff --git a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
index 9b794538c404e..663a50ed377a7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
@@ -5890,3 +5890,70 @@ vector.body:                                      ; preds = %vector.body, %entry
 for.cond.cleanup:                                 ; preds = %vector.body
   ret void
 }
+
+define dso_local void @sink_vp_splat(ptr nocapture %out, ptr nocapture %in) {
+; CHECK-LABEL: sink_vp_splat:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a2, 0
+; CHECK-NEXT:    li a3, 1024
+; CHECK-NEXT:    li a4, 3
+; CHECK-NEXT:    lui a5, 1
+; CHECK-NEXT:  .LBB129_1: # %vector.body
+; CHECK-NEXT:    # =>This Loop Header: Depth=1
+; CHECK-NEXT:    # Child Loop BB129_2 Depth 2
+; CHECK-NEXT:    vsetvli a6, a3, e32, m4, ta, ma
+; CHECK-NEXT:    slli a7, a2, 2
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    add t0, a1, a7
+; CHECK-NEXT:    li t1, 1024
+; CHECK-NEXT:  .LBB129_2: # %for.body424
+; CHECK-NEXT:    # Parent Loop BB129_1 Depth=1
+; CHECK-NEXT:    # => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    vle32.v v12, (t0)
+; CHECK-NEXT:    addi t1, t1, -1
+; CHECK-NEXT:    vmacc.vx v8, a4, v12
+; CHECK-NEXT:    add t0, t0, a5
+; CHECK-NEXT:    bnez t1, .LBB129_2
+; CHECK-NEXT:  # %bb.3: # %vector.latch
+; CHECK-NEXT:    # in Loop: Header=BB129_1 Depth=1
+; CHECK-NEXT:    add a7, a0, a7
+; CHECK-NEXT:    sub a3, a3, a6
+; CHECK-NEXT:    vse32.v v8, (a7)
+; CHECK-NEXT:    add a2, a2, a6
+; CHECK-NEXT:    bnez a3, .LBB129_1
+; CHECK-NEXT:  # %bb.4: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.latch, %entry
+  %scalar.ind = phi i64 [ 0, %entry ], [ %next.ind, %vector.latch ]
+  %trip.count = phi i64 [ 1024, %entry ], [ %remaining.trip.count, %vector.latch ]
+  %evl = tail call i32 @llvm.experimental.get.vector.length.i64(i64 %trip.count, i32 8, i1 true)
+  %vp.splat1 = tail call <vscale x 8 x i32> @llvm.experimental.vp.splat.nxv8i32(i32 0, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.splat2 = tail call <vscale x 8 x i32> @llvm.experimental.vp.splat.nxv8i32(i32 3, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %evl.cast = zext i32 %evl to i64
+  br label %for.body424
+
+for.body424:                                      ; preds = %for.body424, %vector.body
+  %scalar.phi = phi i64 [ 0, %vector.body ], [ %indvars.iv.next27, %for.body424 ]
+  %vector.phi = phi <vscale x 8 x i32> [ %vp.splat1, %vector.body ], [ %vp.binary26, %for.body424 ]
+  %arrayidx625 = getelementptr inbounds [1024 x i32], ptr %in, i64 %scalar.phi, i64 %scalar.ind
+  %widen.load = tail call <vscale x 8 x i32> @llvm.vp.load.nxv8i32.p0(ptr %arrayidx625, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary = tail call <vscale x 8 x i32> @llvm.vp.mul.nxv8i32(<vscale x 8 x i32> %widen.load, <vscale x 8 x i32> %vp.splat2, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary26 = tail call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %vector.phi, <vscale x 8 x i32> %vp.binary, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %indvars.iv.next27 = add nuw nsw i64 %scalar.phi, 1
+  %exitcond.not28 = icmp eq i64 %indvars.iv.next27, 1024
+  br i1 %exitcond.not28, label %vector.latch, label %for.body424
+
+vector.latch:                                     ; preds = %for.body424
+  %arrayidx830 = getelementptr inbounds i32, ptr %out, i64 %scalar.ind
+  tail call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.binary26, ptr %arrayidx830, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %remaining.trip.count = sub nuw i64 %trip.count, %evl.cast
+  %next.ind = add i64 %scalar.ind, %evl.cast
+  %done = icmp eq i64 %remaining.trip.count, 0
+  br i1 %done, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.latch
+  ret void
+}

``````````
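
One detail worth noting from the patch: a `vp.splat` is sunk only when every one of its uses is a VP intrinsic with the same EVL that can take a splatted operand, since a partially sunk splat would be duplicated across a GPR and a vector register. A hypothetical case the check deliberately skips:

```llvm
define <vscale x 4 x i32> @not_sunk(<vscale x 4 x i32> %v0, <vscale x 4 x i32> %v1, i32 %x, i32 %evl0, i32 %evl1) {
entry:
  %s = call <vscale x 4 x i32> @llvm.experimental.vp.splat.nxv4i32(i32 %x, <vscale x 4 x i1> splat (i1 true), i32 %evl0)
  br label %body

body:
  ; %s is not sunk: its second user runs under a different EVL (%evl1),
  ; so the all-uses check in tryToSinkVPSplat fails.
  %a = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %v0, <vscale x 4 x i32> %s, <vscale x 4 x i1> splat (i1 true), i32 %evl0)
  %b = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %v1, <vscale x 4 x i32> %s, <vscale x 4 x i1> splat (i1 true), i32 %evl1)
  %r = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %a, <vscale x 4 x i32> %b, <vscale x 4 x i1> splat (i1 true), i32 %evl0)
  ret <vscale x 4 x i32> %r
}
```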



https://github.com/llvm/llvm-project/pull/133245


More information about the llvm-commits mailing list