[llvm] AMDGPU: Custom lower fptrunc vectors for f32 -> f16 (PR #141883)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 6 06:23:16 PDT 2025


================
@@ -6900,14 +6903,35 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
                            DAG.getTargetConstant(0, DL, MVT::i32));
 }
 
+/// Split a vector rounding node into two half-width nodes of the same opcode
+/// and concatenate the results back to the original destination type.
+///
+/// \p Op must produce a vector whose element count is a power of two greater
+/// than 2 (enforced by the assert below). Operand 0 is split into low/high
+/// halves; operand 1 is forwarded unchanged to both new nodes.
+SDValue SITargetLowering::splitFP_ROUNDVectorOp(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  EVT DstVT = Op.getValueType();
+  unsigned NumElts = DstVT.getVectorNumElements();
+  // Power-of-two element count guarantees NumElts / 2 below is exact.
+  assert(NumElts > 2 && isPowerOf2_32(NumElts));
+
+  // Split the source vector (operand 0) into its low and high halves.
+  // NOTE(review): assumes operand 0 is the vector source, i.e. a non-chained
+  // opcode; a STRICT_* node (chain in operand 0) would need separate handling.
+  auto [Lo, Hi] = DAG.SplitVectorOperand(Op.getNode(), 0);
+
+  SDLoc DL(Op);
+  unsigned Opc = Op.getOpcode();
+  // Operand 1 is reused verbatim for both halves. For ISD::FP_ROUND this is
+  // presumably the TRUNC flag operand rather than node flags -- confirm.
+  SDValue Flags = Op.getOperand(1);
+  // Same scalar type as the destination, half as many elements.
+  EVT HalfDstVT =
+      EVT::getVectorVT(*DAG.getContext(), DstVT.getScalarType(), NumElts / 2);
+  SDValue OpLo = DAG.getNode(Opc, DL, HalfDstVT, Lo, Flags);
+  SDValue OpHi = DAG.getNode(Opc, DL, HalfDstVT, Hi, Flags);
+
+  // Reassemble the full-width result: low half first, then high half.
+  return DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, OpLo, OpHi);
+}
+
 SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
   SDValue Src = Op.getOperand(0);
   EVT SrcVT = Src.getValueType();
   EVT DstVT = Op.getValueType();
 
-  if (DstVT == MVT::v2f16) {
+  if (DstVT.isVector() && DstVT.getScalarType() == MVT::f16) {
----------------
arsenm wrote:

In a follow-up, can you look into extending this to v2bf16? I'm guessing that in the ultimate expansion sequence, this will give a benefit even if the underlying v2 opcode isn't legal.

https://github.com/llvm/llvm-project/pull/141883


More information about the llvm-commits mailing list