[llvm] 576085c - [SelectionDAG][RISCV] Add support for splitting vp.splice (#145184)
Author: Craig Topper
Date: 2025-06-23T08:55:13-07:00
New Revision: 576085c94855fc1536aa6343b272d9e87b7cb3ed
URL: https://github.com/llvm/llvm-project/commit/576085c94855fc1536aa6343b272d9e87b7cb3ed
DIFF: https://github.com/llvm/llvm-project/commit/576085c94855fc1536aa6343b272d9e87b7cb3ed.diff
LOG: [SelectionDAG][RISCV] Add support for splitting vp.splice (#145184)
Use a stack-based expansion similar to the non-VP splice: store both
source vectors to a stack temporary and load the result from the spliced
offset.
This code has been in our downstream for a while, though I don't know how
often it is exercised. Our downstream was missing clamping of the
immediate value to keep it in range of the stack object, so I've added
that here.
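For intuition, here is a minimal scalar C++ model of the same stack-based
strategy (illustrative only: the function name, the use of std::vector, and
the positive-offset guard are assumptions of this sketch, not part of the
patch):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of the vp.splice expansion: write V1's first EVL1 elements
// and V2's first EVL2 elements back to back into one scratch buffer (the
// stack temporary), then read EVL2 elements starting at the splice offset.
std::vector<int64_t> spliceViaBuffer(const std::vector<int64_t> &V1, size_t EVL1,
                                     const std::vector<int64_t> &V2, size_t EVL2,
                                     int64_t Imm) {
  // Buffer sized like the stack temporary: room for both full vectors.
  std::vector<int64_t> Buf(V1.size() + V2.size(), 0);
  std::copy_n(V1.begin(), EVL1, Buf.begin());        // vp.store of V1
  std::copy_n(V2.begin(), EVL2, Buf.begin() + EVL1); // vp.store of V2 at +EVL1
  size_t Start;
  if (Imm >= 0) {
    Start = std::min<size_t>(Imm, V1.size()); // keep the read inside the buffer
  } else {
    // Negative offset: start |Imm| elements before V2's region, clamped so
    // the read never precedes the buffer -- the UMIN added by this patch.
    size_t Trailing = std::min<size_t>(-Imm, EVL1);
    Start = EVL1 - Trailing;
  }
  return {Buf.begin() + Start, Buf.begin() + Start + EVL2};
}

For example, spliceViaBuffer({0, 1, 2, 3}, 4, {10, 11, 12}, 3, 1) lays out
0 1 2 3 10 11 12 in the buffer and returns {1, 2, 3}, while an offset of -1
returns {3, 10, 11}.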
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index a541833684f38..8643ae9d78159 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -985,6 +985,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void SplitVecRes_VECTOR_INTERLEAVE(SDNode *N);
void SplitVecRes_VAARG(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
+ void SplitVecRes_VP_SPLICE(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo, SDValue &Hi);
void SplitVecRes_GET_ACTIVE_LANE_MASK(SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index c56cfec81acdd..32c5961195450 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1382,6 +1382,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::UDIVFIXSAT:
SplitVecRes_FIX(N, Lo, Hi);
break;
+ case ISD::EXPERIMENTAL_VP_SPLICE:
+ SplitVecRes_VP_SPLICE(N, Lo, Hi);
+ break;
case ISD::EXPERIMENTAL_VP_REVERSE:
SplitVecRes_VP_REVERSE(N, Lo, Hi);
break;
@@ -3209,6 +3212,78 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
std::tie(Lo, Hi) = DAG.SplitVector(Load, DL);
}
+void DAGTypeLegalizer::SplitVecRes_VP_SPLICE(SDNode *N, SDValue &Lo,
+ SDValue &Hi) {
+ EVT VT = N->getValueType(0);
+ SDValue V1 = N->getOperand(0);
+ SDValue V2 = N->getOperand(1);
+ int64_t Imm = cast<ConstantSDNode>(N->getOperand(2))->getSExtValue();
+ SDValue Mask = N->getOperand(3);
+ SDValue EVL1 = N->getOperand(4);
+ SDValue EVL2 = N->getOperand(5);
+ SDLoc DL(N);
+
+ // Since EVL2 is considered the real VL it gets promoted during
+ // SelectionDAGBuilder. Promote EVL1 here if needed.
+ if (getTypeAction(EVL1.getValueType()) == TargetLowering::TypePromoteInteger)
+ EVL1 = ZExtPromotedInteger(EVL1);
+
+ Align Alignment = DAG.getReducedAlign(VT, /*UseABI=*/false);
+
+ EVT MemVT = EVT::getVectorVT(*DAG.getContext(), VT.getVectorElementType(),
+ VT.getVectorElementCount() * 2);
+ SDValue StackPtr = DAG.CreateStackTemporary(MemVT.getStoreSize(), Alignment);
+ EVT PtrVT = StackPtr.getValueType();
+ auto &MF = DAG.getMachineFunction();
+ auto FrameIndex = cast<FrameIndexSDNode>(StackPtr.getNode())->getIndex();
+ auto PtrInfo = MachinePointerInfo::getFixedStack(MF, FrameIndex);
+
+ MachineMemOperand *StoreMMO = DAG.getMachineFunction().getMachineMemOperand(
+ PtrInfo, MachineMemOperand::MOStore, LocationSize::beforeOrAfterPointer(),
+ Alignment);
+ MachineMemOperand *LoadMMO = DAG.getMachineFunction().getMachineMemOperand(
+ PtrInfo, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(),
+ Alignment);
+
+ SDValue StackPtr2 = TLI.getVectorElementPointer(DAG, StackPtr, VT, EVL1);
+
+ SDValue TrueMask = DAG.getBoolConstant(true, DL, Mask.getValueType(), VT);
+ SDValue StoreV1 = DAG.getStoreVP(DAG.getEntryNode(), DL, V1, StackPtr,
+ DAG.getUNDEF(PtrVT), TrueMask, EVL1,
+ V1.getValueType(), StoreMMO, ISD::UNINDEXED);
+
+ SDValue StoreV2 =
+ DAG.getStoreVP(StoreV1, DL, V2, StackPtr2, DAG.getUNDEF(PtrVT), TrueMask,
+ EVL2, V2.getValueType(), StoreMMO, ISD::UNINDEXED);
+
+ SDValue Load;
+ if (Imm >= 0) {
+ StackPtr = TLI.getVectorElementPointer(DAG, StackPtr, VT, N->getOperand(2));
+ Load = DAG.getLoadVP(VT, DL, StoreV2, StackPtr, Mask, EVL2, LoadMMO);
+ } else {
+ uint64_t TrailingElts = -Imm;
+ unsigned EltWidth = VT.getScalarSizeInBits() / 8;
+ SDValue TrailingBytes = DAG.getConstant(TrailingElts * EltWidth, DL, PtrVT);
+
+ // Make sure TrailingBytes doesn't exceed the size of vec1.
+ SDValue OffsetToV2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, StackPtr);
+ TrailingBytes =
+ DAG.getNode(ISD::UMIN, DL, PtrVT, TrailingBytes, OffsetToV2);
+
+ // Calculate the start address of the spliced result.
+ StackPtr2 = DAG.getNode(ISD::SUB, DL, PtrVT, StackPtr2, TrailingBytes);
+ Load = DAG.getLoadVP(VT, DL, StoreV2, StackPtr2, Mask, EVL2, LoadMMO);
+ }
+
+ EVT LoVT, HiVT;
+ std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+ Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, LoVT, Load,
+ DAG.getVectorIdxConstant(0, DL));
+ Hi =
+ DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HiVT, Load,
+ DAG.getVectorIdxConstant(LoVT.getVectorMinNumElements(), DL));
+}
+
void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
SDValue &Hi) {
SDLoc DL(N);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll b/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
index a4f91c3e7c99e..ffeb493989103 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vp-splice.ll
@@ -286,3 +286,144 @@ define <vscale x 2 x float> @test_vp_splice_nxv2f32_masked(<vscale x 2 x float>
%v = call <vscale x 2 x float> @llvm.experimental.vp.splice.nxv2f32(<vscale x 2 x float> %va, <vscale x 2 x float> %vb, i32 5, <vscale x 2 x i1> %mask, i32 %evla, i32 %evlb)
ret <vscale x 2 x float> %v
}
+
+define <vscale x 16 x i64> @test_vp_splice_nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 zeroext %evla, i32 zeroext %evlb) nounwind {
+; CHECK-LABEL: test_vp_splice_nxv16i64:
+; CHECK: # %bb.0:
+; CHECK-NEXT: csrr a4, vlenb
+; CHECK-NEXT: slli a5, a4, 1
+; CHECK-NEXT: addi a5, a5, -1
+; CHECK-NEXT: slli a1, a4, 3
+; CHECK-NEXT: mv a7, a2
+; CHECK-NEXT: bltu a2, a5, .LBB21_2
+; CHECK-NEXT: # %bb.1:
+; CHECK-NEXT: mv a7, a5
+; CHECK-NEXT: .LBB21_2:
+; CHECK-NEXT: addi sp, sp, -80
+; CHECK-NEXT: sd ra, 72(sp) # 8-byte Folded Spill
+; CHECK-NEXT: sd s0, 64(sp) # 8-byte Folded Spill
+; CHECK-NEXT: addi s0, sp, 80
+; CHECK-NEXT: csrr a5, vlenb
+; CHECK-NEXT: slli a5, a5, 5
+; CHECK-NEXT: sub sp, sp, a5
+; CHECK-NEXT: andi sp, sp, -64
+; CHECK-NEXT: add a5, a0, a1
+; CHECK-NEXT: slli a7, a7, 3
+; CHECK-NEXT: addi a6, sp, 64
+; CHECK-NEXT: mv t0, a2
+; CHECK-NEXT: bltu a2, a4, .LBB21_4
+; CHECK-NEXT: # %bb.3:
+; CHECK-NEXT: mv t0, a4
+; CHECK-NEXT: .LBB21_4:
+; CHECK-NEXT: vl8re64.v v24, (a5)
+; CHECK-NEXT: add a5, a6, a7
+; CHECK-NEXT: vl8re64.v v0, (a0)
+; CHECK-NEXT: vsetvli zero, t0, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v8, (a6)
+; CHECK-NEXT: sub a0, a2, a4
+; CHECK-NEXT: sltu a2, a2, a0
+; CHECK-NEXT: addi a2, a2, -1
+; CHECK-NEXT: and a0, a2, a0
+; CHECK-NEXT: add a6, a6, a1
+; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v16, (a6)
+; CHECK-NEXT: mv a0, a3
+; CHECK-NEXT: bltu a3, a4, .LBB21_6
+; CHECK-NEXT: # %bb.5:
+; CHECK-NEXT: mv a0, a4
+; CHECK-NEXT: .LBB21_6:
+; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v0, (a5)
+; CHECK-NEXT: sub a2, a3, a4
+; CHECK-NEXT: add a5, a5, a1
+; CHECK-NEXT: sltu a3, a3, a2
+; CHECK-NEXT: addi a3, a3, -1
+; CHECK-NEXT: and a2, a3, a2
+; CHECK-NEXT: addi a3, sp, 104
+; CHECK-NEXT: add a1, a3, a1
+; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v24, (a5)
+; CHECK-NEXT: vle64.v v16, (a1)
+; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT: vle64.v v8, (a3)
+; CHECK-NEXT: addi sp, s0, -80
+; CHECK-NEXT: ld ra, 72(sp) # 8-byte Folded Reload
+; CHECK-NEXT: ld s0, 64(sp) # 8-byte Folded Reload
+; CHECK-NEXT: addi sp, sp, 80
+; CHECK-NEXT: ret
+ %v = call <vscale x 16 x i64> @llvm.experimental.vp.splice.nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 5, <vscale x 16 x i1> splat (i1 1), i32 %evla, i32 %evlb)
+ ret <vscale x 16 x i64> %v
+}
+
+define <vscale x 16 x i64> @test_vp_splice_nxv16i64_negative_offset(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 zeroext %evla, i32 zeroext %evlb) nounwind {
+; CHECK-LABEL: test_vp_splice_nxv16i64_negative_offset:
+; CHECK: # %bb.0:
+; CHECK-NEXT: csrr a5, vlenb
+; CHECK-NEXT: slli a6, a5, 1
+; CHECK-NEXT: addi a6, a6, -1
+; CHECK-NEXT: slli a1, a5, 3
+; CHECK-NEXT: mv a4, a2
+; CHECK-NEXT: bltu a2, a6, .LBB22_2
+; CHECK-NEXT: # %bb.1:
+; CHECK-NEXT: mv a4, a6
+; CHECK-NEXT: .LBB22_2:
+; CHECK-NEXT: addi sp, sp, -80
+; CHECK-NEXT: sd ra, 72(sp) # 8-byte Folded Spill
+; CHECK-NEXT: sd s0, 64(sp) # 8-byte Folded Spill
+; CHECK-NEXT: addi s0, sp, 80
+; CHECK-NEXT: csrr a6, vlenb
+; CHECK-NEXT: slli a6, a6, 5
+; CHECK-NEXT: sub sp, sp, a6
+; CHECK-NEXT: andi sp, sp, -64
+; CHECK-NEXT: add a6, a0, a1
+; CHECK-NEXT: slli a4, a4, 3
+; CHECK-NEXT: addi a7, sp, 64
+; CHECK-NEXT: mv t0, a2
+; CHECK-NEXT: bltu a2, a5, .LBB22_4
+; CHECK-NEXT: # %bb.3:
+; CHECK-NEXT: mv t0, a5
+; CHECK-NEXT: .LBB22_4:
+; CHECK-NEXT: vl8re64.v v24, (a6)
+; CHECK-NEXT: add a6, a7, a4
+; CHECK-NEXT: vl8re64.v v0, (a0)
+; CHECK-NEXT: vsetvli zero, t0, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v8, (a7)
+; CHECK-NEXT: sub a0, a2, a5
+; CHECK-NEXT: sltu a2, a2, a0
+; CHECK-NEXT: addi a2, a2, -1
+; CHECK-NEXT: and a0, a2, a0
+; CHECK-NEXT: add a7, a7, a1
+; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v16, (a7)
+; CHECK-NEXT: mv a0, a3
+; CHECK-NEXT: bltu a3, a5, .LBB22_6
+; CHECK-NEXT: # %bb.5:
+; CHECK-NEXT: mv a0, a5
+; CHECK-NEXT: .LBB22_6:
+; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v0, (a6)
+; CHECK-NEXT: sub a2, a3, a5
+; CHECK-NEXT: add a5, a6, a1
+; CHECK-NEXT: sltu a3, a3, a2
+; CHECK-NEXT: addi a3, a3, -1
+; CHECK-NEXT: and a2, a3, a2
+; CHECK-NEXT: li a3, 8
+; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
+; CHECK-NEXT: vse64.v v24, (a5)
+; CHECK-NEXT: bltu a4, a3, .LBB22_8
+; CHECK-NEXT: # %bb.7:
+; CHECK-NEXT: li a4, 8
+; CHECK-NEXT: .LBB22_8:
+; CHECK-NEXT: sub a2, a6, a4
+; CHECK-NEXT: add a1, a2, a1
+; CHECK-NEXT: vle64.v v16, (a1)
+; CHECK-NEXT: vsetvli zero, a0, e64, m8, ta, ma
+; CHECK-NEXT: vle64.v v8, (a2)
+; CHECK-NEXT: addi sp, s0, -80
+; CHECK-NEXT: ld ra, 72(sp) # 8-byte Folded Reload
+; CHECK-NEXT: ld s0, 64(sp) # 8-byte Folded Reload
+; CHECK-NEXT: addi sp, sp, 80
+; CHECK-NEXT: ret
+ %v = call <vscale x 16 x i64> @llvm.experimental.vp.splice.nxv16i64(<vscale x 16 x i64> %va, <vscale x 16 x i64> %vb, i32 -1, <vscale x 16 x i1> splat (i1 1), i32 %evla, i32 %evlb)
+ ret <vscale x 16 x i64> %v
+}
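An informal reading of the negative-offset test above (not an authoritative
trace): with imm = -1 and 8-byte i64 elements, TrailingBytes is the constant
1 * 8 = 8, materialized by "li a3, 8"; the "bltu a4, a3, .LBB22_8" /
"li a4, 8" pair implements the new UMIN clamp against OffsetToV2 (the byte
offset of V2's region within the stack object, already in a4); and
"sub a2, a6, a4" then forms the load address StackPtr2 - TrailingBytes.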