[llvm] [LLVM][CodeGen] Add lowering for scalable vector bfloat operations. (PR #109803)

Matt Arsenault via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 24 08:05:59 PDT 2024


================
@@ -28466,6 +28520,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
   return convertFromScalableVector(DAG, VT, ScalableRes);
 }
 
+// Lower bfloat16 operations by upcasting to float32, performing the operation
+// and then downcasting the result back to bfloat16.
+SDValue AArch64TargetLowering::LowerBFloatOp(SDValue Op,
+                                             SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VT = Op.getValueType();
+  assert(isTypeLegal(VT) && VT.isScalableVector() && "Unexpected type!");
+
+  // Split the vector and try again.
+  if (VT == MVT::nxv8bf16) {
+    SmallVector<SDValue, 4> LoOps, HiOps;
+    for (const SDValue &V : Op->op_values()) {
+      LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
+      HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
+    }
+
+    unsigned Opc = Op.getOpcode();
+    SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
+    SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
+  }
+
+  // Promote to float and try again.
+  EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
+
+  SmallVector<SDValue, 4> Ops;
+  for (const SDValue &V : Op->op_values())
+    Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
+
+  SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
+  return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
+                     DAG.getIntPtrConstant(0, DL, true));
----------------
arsenm wrote:

This is very generic, and I would hope it doesn't need repeating in target code.

https://github.com/llvm/llvm-project/pull/109803


More information about the llvm-commits mailing list