[llvm] AMDGPU: Custom lower fptrunc vectors for f32 -> f16 (PR #141883)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 6 06:23:16 PDT 2025
================
@@ -6900,14 +6903,35 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
DAG.getTargetConstant(0, DL, MVT::i32));
}
+/// Split a wide vector FP_ROUND-style node into two nodes of the same opcode
+/// that each operate on one half of the source vector (operand 0), then
+/// concatenate the two narrowed halves back into the original result type.
+/// Expects a power-of-two result element count greater than 2.
+SDValue SITargetLowering::splitFP_ROUNDVectorOp(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  const unsigned Opcode = Op.getOpcode();
+  const EVT ResultVT = Op.getValueType();
+  const unsigned EltCount = ResultVT.getVectorNumElements();
+  assert(EltCount > 2 && isPowerOf2_32(EltCount));
+
+  // Operand 0 is the wide source vector; break it into low/high halves.
+  auto [SrcLo, SrcHi] = DAG.SplitVectorOperand(Op.getNode(), 0);
+
+  // Each half yields half as many elements of the same scalar result type.
+  const EVT HalfResultVT = EVT::getVectorVT(
+      *DAG.getContext(), ResultVT.getScalarType(), EltCount / 2);
+
+  // Operand 1 is forwarded unchanged to both half-width nodes.
+  const SDValue TruncOperand = Op.getOperand(1);
+  const SDLoc SL(Op);
+  const SDValue RoundedLo =
+      DAG.getNode(Opcode, SL, HalfResultVT, SrcLo, TruncOperand);
+  const SDValue RoundedHi =
+      DAG.getNode(Opcode, SL, HalfResultVT, SrcHi, TruncOperand);
+  return DAG.getNode(ISD::CONCAT_VECTORS, SL, ResultVT, RoundedLo, RoundedHi);
+}
+
SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
EVT SrcVT = Src.getValueType();
EVT DstVT = Op.getValueType();
- if (DstVT == MVT::v2f16) {
+ if (DstVT.isVector() && DstVT.getScalarType() == MVT::f16) {
----------------
arsenm wrote:
In a follow-up, can you look into extending this to v2bf16? I'm guessing that this will still help the final expansion sequence, even if the underlying v2 opcode isn't legal.
https://github.com/llvm/llvm-project/pull/141883
More information about the llvm-commits mailing list