[llvm] [LLVM][CodeGen] Add lowering for scalable vector bfloat operations. (PR #109803)
Matt Arsenault via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 24 08:05:59 PDT 2024
================
@@ -28466,6 +28520,40 @@ SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
return convertFromScalableVector(DAG, VT, ScalableRes);
}
+// Lower bfloat16 operations by upcasting to float32, performing the operation
+// and then downcasting the result back to bfloat16.
+SDValue AArch64TargetLowering::LowerBFloatOp(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ EVT VT = Op.getValueType();
+ assert(isTypeLegal(VT) && VT.isScalableVector() && "Unexpected type!");
+
+ // Split the vector and try again.
+ if (VT == MVT::nxv8bf16) {
+ SmallVector<SDValue, 4> LoOps, HiOps;
+ for (const SDValue &V : Op->op_values()) {
+ LoOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 0));
+ HiOps.push_back(DAG.getExtractSubvector(DL, MVT::nxv4bf16, V, 4));
+ }
+
+ unsigned Opc = Op.getOpcode();
+ SDValue SplitOpLo = DAG.getNode(Opc, DL, MVT::nxv4bf16, LoOps);
+ SDValue SplitOpHi = DAG.getNode(Opc, DL, MVT::nxv4bf16, HiOps);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
+ }
+
+ // Promote to float and try again.
+ EVT PromoteVT = VT.changeVectorElementType(MVT::f32);
+
+ SmallVector<SDValue, 4> Ops;
+ for (const SDValue &V : Op->op_values())
+ Ops.push_back(DAG.getNode(ISD::FP_EXTEND, DL, PromoteVT, V));
+
+ SDValue PromotedOp = DAG.getNode(Op.getOpcode(), DL, PromoteVT, Ops);
+ return DAG.getNode(ISD::FP_ROUND, DL, VT, PromotedOp,
+ DAG.getIntPtrConstant(0, DL, true));
----------------
arsenm wrote:
This is very generic, and I would hope it doesn't need to be repeated in target code.
https://github.com/llvm/llvm-project/pull/109803
More information about the llvm-commits mailing list