[llvm] 576085c - [SelectionDAG][RISCV] Add support for splitting vp.splice (#145184)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 23 08:55:16 PDT 2025


Author: Craig Topper
Date: 2025-06-23T08:55:13-07:00
New Revision: 576085c94855fc1536aa6343b272d9e87b7cb3ed

URL: https://github.com/llvm/llvm-project/commit/576085c94855fc1536aa6343b272d9e87b7cb3ed
DIFF: https://github.com/llvm/llvm-project/commit/576085c94855fc1536aa6343b272d9e87b7cb3ed.diff

LOG: [SelectionDAG][RISCV] Add support for splitting vp.splice (#145184)

Use a stack-based expansion similar to the non-VP splice.

This code has been in our downstream for a while, though I don't know how
often it is exercised. Our downstream was missing the clipping that keeps
the immediate value in range of the stack object, so I've added it.
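For context, the expansion for a non-negative immediate is roughly the
following, sketched here as hand-written LLVM IR rather than actual
compiler output. The function name is made up, the immediate of 5 mirrors
the existing tests, and the alignment and immediate-clamping details of
the real lowering are elided:

declare void @llvm.vp.store.nxv2f32.p0(<vscale x 2 x float>, ptr, <vscale x 2 x i1>, i32)
declare <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0(ptr, <vscale x 2 x i1>, i32)

define <vscale x 2 x float> @splice_expansion_sketch(<vscale x 2 x float> %va, <vscale x 2 x float> %vb, <vscale x 2 x i1> %mask, i32 %evla, i32 %evlb) {
  ; Stack temporary sized to hold both source vectors (2x the result type).
  %slot = alloca <vscale x 4 x float>
  ; Store the first %evla lanes of %va at the base of the slot.
  call void @llvm.vp.store.nxv2f32.p0(<vscale x 2 x float> %va, ptr %slot, <vscale x 2 x i1> splat (i1 1), i32 %evla)
  ; Store the first %evlb lanes of %vb immediately after them.
  %v2.ptr = getelementptr float, ptr %slot, i32 %evla
  call void @llvm.vp.store.nxv2f32.p0(<vscale x 2 x float> %vb, ptr %v2.ptr, <vscale x 2 x i1> splat (i1 1), i32 %evlb)
  ; Load the result starting at the immediate offset (5 here); the real
  ; lowering clamps this offset so it stays inside the stack object.
  %ld.ptr = getelementptr float, ptr %slot, i32 5
  %v = call <vscale x 2 x float> @llvm.vp.load.nxv2f32.p0(ptr %ld.ptr, <vscale x 2 x i1> %mask, i32 %evlb)
  ret <vscale x 2 x float> %v
}

For a negative immediate, the load instead starts |Imm| elements before
the end of the %va region, with the byte offset clamped so the load never
starts before the base of the slot; the loaded result is then split back
into Lo and Hi halves with EXTRACT_SUBVECTOR.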

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
    llvm/test/CodeGen/RISCV/rvv/vp-splice.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index a541833684f38..8643ae9d78159 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -985,6 +985,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void SplitVecRes_VECTOR_INTERLEAVE(SDNode *N);
   void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_VP_SPLICE(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_GET_ACTIVE_LANE_MASK(SDNode *N, SDValue &Lo, SDValue &Hi);

diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index c56cfec81acdd..32c5961195450 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1382,6 +1382,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
   case ISD::UDIVFIXSAT:
     SplitVecRes_FIX(N, Lo, Hi);
     break;
+  case ISD::EXPERIMENTAL_VP_SPLICE:
+    SplitVecRes_VP_SPLICE(N, Lo, Hi);
+    break;
   case ISD::EXPERIMENTAL_VP_REVERSE:
     SplitVecRes_VP_REVERSE(N, Lo, Hi);
     break;
@@ -3209,6 +3212,78 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
   std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
 }
 
+void DAGTypeLegalizer::SplitVecRes_VP_SPLICE(SDNode *N, SDValue &Lo,
+                                             SDValue &Hi) {
+  EVT VT = N->getValueType(0);
+  SDValue V1 = N->getOperand(0);
+  SDValue V2 = N->getOperand(1);
+  int64_t Imm = cast<ConstantSDNode>(N->getOperand(2))->getSExtValue();
+  SDValue Mask = N->getOperand(3);
+  SDValue EVL1 = N->getOperand(4);
+  SDValue EVL2 = N->getOperand(5);
+  SDLoc DL(N);
+
+  // Since EVL2 is considered the real VL it gets promoted during
+  // SelectionDAGBuilder. Promote EVL1 here if needed.
+  if (getTypeAction(EVL1.getValueType()) == TargetLowering::TypePromoteInteger)
+    EVL1 = ZExtPromotedInteger(EVL1);
+
+  Align Alignment = DAG.getReducedAlign(VT, /*UseABI=*/false);
+
+  EVT MemVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+                               VT.getVectorElementCount() * 2);
+  SDValue StackPtr = DAG.CreateStackTemporary(MemVT.getStoreSize(), Alignment);
+  EVT PtrVT = StackPtr.getValueType();
+  auto &MF = DAG.getMachineFunction();
+  auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+  auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
+
+  MachineMemOperand *StoreMMO = DAG.getMachineFunction().getMachineMemOperand(
+      PtrInfo, MachineMemOperand::MOStore, LocationSize::beforeOrAfterPointer(),
+      Alignment);
+  MachineMemOperand *LoadMMO = DAG.getMachineFunction().getMachineMemOperand(
+      PtrInfo, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
+      Alignment);
+
+  SDValue StackPtr2 = TLI.getVectorElementPointer(DAG, StackPtr, VT, EVL1);
+
+  SDValue TrueMask = DAG.getBoolConstant(true, DL, Mask.getValueType(), VT);
+  SDValue StoreV1 = DAG.getStoreVP(DAG.getEntryNode(), DL, V1, StackPtr,
+                                   DAG.getUNDEF(PtrVT), TrueMask, EVL1,
+                                   V1.getValueType(), StoreMMO, ISD::UNINDEXED);
+
+  SDValue StoreV2 =
+      DAG.getStoreVP(StoreV1, DL, V2, StackPtr2, DAG.getUNDEF(PtrVT), TrueMask,
+                     EVL2, V2.getValueType(), StoreMMO, ISD::UNINDEXED);
+
+  SDValue Load;
+  if (Imm >= 0) {
+    StackPtr = TLI.getVectorElementPointer(DAG, StackPtr, VT, N->getOperand(2));
+    Load = DAG.getLoadVP(VT, DL, StoreV2, StackPtr, Mask, EVL2, LoadMMO);
+  } else {
+    uint64_t TrailingElts = -Imm;
+    unsigned EltWidth = VT.getScalarSizeInBits() / 8;
+    SDValue TrailingBytes = DAG.getConstant(TrailingElts * EltWidth, DL, PtrVT);
+
+    // Make sure TrailingBytes doesn't exceed the size of vec1.
+    SDValue OffsetToV2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, StackPtr);
+    TrailingBytes =
+        DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, OffsetToV2);
+
+    // Calculate the start address of the spliced result.
+    StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
+    Load = DAG.getLoadVP(VT, DL, StoreV2, StackPtr2, Mask, EVL2, LoadMMO);
+  }
+
+  EVT LoVT, HiVT;
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+  Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoVT, Load,
+                   DAG.getVectorIdxConstant(0, DL));
+  Hi =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiVT, Load,
+                  DAG.getVectorIdxConstant(LoVT.getVectorMinNumElements(), DL));
+}
+
 void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
                                                       SDValue &Hi) {
   SDLoc DL(N);

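The negative-immediate path is where the new clipping matters. TrailingBytes
starts as |Imm| times the element size in bytes, but EVL1 may be smaller than
|Imm|: for example, with Imm = -2 and i64 elements, TrailingBytes = 16, and
if EVL1 is 1 the V2 region starts only 8 bytes into the slot, so without the
clamp the load would begin 8 bytes before the stack temporary. The UMIN
against OffsetToV2 caps the backward offset at the byte size of the stored V1
region, keeping the load inside the object.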
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll b/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
index a4f91c3e7c99e..ffeb493989103 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
@@ -286,3 +286,144 @@ define <vscale x 2 x float> @test_vp_splice_nxv2f32_masked(<vscale x 2 x float>
   %v = call <vscale x 2 x float> @llvm.experimental.vp.splice.nxv2f32(<vscale x 2 x float> %va, <vscale x 2 x float> %vb, i32 5, <vscale x 2 x i1> %mask, i32 %evla, i32 %evlb)
   ret <vscale x 2 x float> %v
 }
+
+define <vscale x 16 x i64> @test_vp_splice_nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 zeroext %evla, i32 zeroext %evlb) nounwind {
+; CHECK-LABEL: test_vp_splice_nxv16i64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a4, vlenb
+; CHECK-NEXT:    slli a5, a4, 1
+; CHECK-NEXT:    addi a5, a5, -1
+; CHECK-NEXT:    slli a1, a4, 3
+; CHECK-NEXT:    mv a7, a2
+; CHECK-NEXT:    bltu a2, a5, .LBB21_2
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    mv a7, a5
+; CHECK-NEXT:  .LBB21_2:
+; CHECK-NEXT:    addi sp, sp, -80
+; CHECK-NEXT:    sd ra, 72(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s0, 64(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    addi s0, sp, 80
+; CHECK-NEXT:    csrr a5, vlenb
+; CHECK-NEXT:    slli a5, a5, 5
+; CHECK-NEXT:    sub sp, sp, a5
+; CHECK-NEXT:    andi sp, sp, -64
+; CHECK-NEXT:    add a5, a0, a1
+; CHECK-NEXT:    slli a7, a7, 3
+; CHECK-NEXT:    addi a6, sp, 64
+; CHECK-NEXT:    mv t0, a2
+; CHECK-NEXT:    bltu a2, a4, .LBB21_4
+; CHECK-NEXT:  # %bb.3:
+; CHECK-NEXT:    mv t0, a4
+; CHECK-NEXT:  .LBB21_4:
+; CHECK-NEXT:    vl8re64.v v24, (a5)
+; CHECK-NEXT:    add a5, a6, a7
+; CHECK-NEXT:    vl8re64.v v0, (a0)
+; CHECK-NEXT:    vsetvli zero, t0, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v8, (a6)
+; CHECK-NEXT:    sub a0, a2, a4
+; CHECK-NEXT:    sltu a2, a2, a0
+; CHECK-NEXT:    addi a2, a2, -1
+; CHECK-NEXT:    and a0, a2, a0
+; CHECK-NEXT:    add a6, a6, a1
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v16, (a6)
+; CHECK-NEXT:    mv a0, a3
+; CHECK-NEXT:    bltu a3, a4, .LBB21_6
+; CHECK-NEXT:  # %bb.5:
+; CHECK-NEXT:    mv a0, a4
+; CHECK-NEXT:  .LBB21_6:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v0, (a5)
+; CHECK-NEXT:    sub a2, a3, a4
+; CHECK-NEXT:    add a5, a5, a1
+; CHECK-NEXT:    sltu a3, a3, a2
+; CHECK-NEXT:    addi a3, a3, -1
+; CHECK-NEXT:    and a2, a3, a2
+; CHECK-NEXT:    addi a3, sp, 104
+; CHECK-NEXT:    add a1, a3, a1
+; CHECK-NEXT:    vsetvli zero, a2, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v24, (a5)
+; CHECK-NEXT:    vle64.v v16, (a1)
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT:    vle64.v v8, (a3)
+; CHECK-NEXT:    addi sp, s0, -80
+; CHECK-NEXT:    ld ra, 72(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s0, 64(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 80
+; CHECK-NEXT:    ret
+  %v = call <vscale x 16 x i64> @llvm.experimental.vp.splice.nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 5, <vscale x 16 x i1> splat (i1 1), i32 %evla, i32 %evlb)
+  ret <vscale x 16 x i64> %v
+}
+
+define <vscale x 16 x i64> @test_vp_splice_nxv16i64_negative_offset(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 zeroext %evla, i32 zeroext %evlb) nounwind {
+; CHECK-LABEL: test_vp_splice_nxv16i64_negative_offset:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a5, vlenb
+; CHECK-NEXT:    slli a6, a5, 1
+; CHECK-NEXT:    addi a6, a6, -1
+; CHECK-NEXT:    slli a1, a5, 3
+; CHECK-NEXT:    mv a4, a2
+; CHECK-NEXT:    bltu a2, a6, .LBB22_2
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    mv a4, a6
+; CHECK-NEXT:  .LBB22_2:
+; CHECK-NEXT:    addi sp, sp, -80
+; CHECK-NEXT:    sd ra, 72(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    sd s0, 64(sp) # 8-byte Folded Spill
+; CHECK-NEXT:    addi s0, sp, 80
+; CHECK-NEXT:    csrr a6, vlenb
+; CHECK-NEXT:    slli a6, a6, 5
+; CHECK-NEXT:    sub sp, sp, a6
+; CHECK-NEXT:    andi sp, sp, -64
+; CHECK-NEXT:    add a6, a0, a1
+; CHECK-NEXT:    slli a4, a4, 3
+; CHECK-NEXT:    addi a7, sp, 64
+; CHECK-NEXT:    mv t0, a2
+; CHECK-NEXT:    bltu a2, a5, .LBB22_4
+; CHECK-NEXT:  # %bb.3:
+; CHECK-NEXT:    mv t0, a5
+; CHECK-NEXT:  .LBB22_4:
+; CHECK-NEXT:    vl8re64.v v24, (a6)
+; CHECK-NEXT:    add a6, a7, a4
+; CHECK-NEXT:    vl8re64.v v0, (a0)
+; CHECK-NEXT:    vsetvli zero, t0, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v8, (a7)
+; CHECK-NEXT:    sub a0, a2, a5
+; CHECK-NEXT:    sltu a2, a2, a0
+; CHECK-NEXT:    addi a2, a2, -1
+; CHECK-NEXT:    and a0, a2, a0
+; CHECK-NEXT:    add a7, a7, a1
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v16, (a7)
+; CHECK-NEXT:    mv a0, a3
+; CHECK-NEXT:    bltu a3, a5, .LBB22_6
+; CHECK-NEXT:  # %bb.5:
+; CHECK-NEXT:    mv a0, a5
+; CHECK-NEXT:  .LBB22_6:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v0, (a6)
+; CHECK-NEXT:    sub a2, a3, a5
+; CHECK-NEXT:    add a5, a6, a1
+; CHECK-NEXT:    sltu a3, a3, a2
+; CHECK-NEXT:    addi a3, a3, -1
+; CHECK-NEXT:    and a2, a3, a2
+; CHECK-NEXT:    li a3, 8
+; CHECK-NEXT:    vsetvli zero, a2, e64, m8, ta, ma
+; CHECK-NEXT:    vse64.v v24, (a5)
+; CHECK-NEXT:    bltu a4, a3, .LBB22_8
+; CHECK-NEXT:  # %bb.7:
+; CHECK-NEXT:    li a4, 8
+; CHECK-NEXT:  .LBB22_8:
+; CHECK-NEXT:    sub a2, a6, a4
+; CHECK-NEXT:    add a1, a2, a1
+; CHECK-NEXT:    vle64.v v16, (a1)
+; CHECK-NEXT:    vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT:    vle64.v v8, (a2)
+; CHECK-NEXT:    addi sp, s0, -80
+; CHECK-NEXT:    ld ra, 72(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    ld s0, 64(sp) # 8-byte Folded Reload
+; CHECK-NEXT:    addi sp, sp, 80
+; CHECK-NEXT:    ret
+  %v = call <vscale x 16 x i64> @llvm.experimental.vp.splice.nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 -1, <vscale x 16 x i1> splat (i1 1), i32 %evla, i32 %evlb)
+  ret <vscale x 16 x i64> %v
+}
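In the negative-offset test, the new clamp appears to correspond to the
`li a3, 8` / `bltu a4, a3` / `li a4, 8` sequence: for Imm = -1 the trailing
offset is one i64 element (8 bytes), capped at the byte offset of the second
vector's region before the `sub`, so the final `vle64.v` never starts before
the stack temporary.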