[llvm] [AArch64] Add vector expansion support for ISD::FCBRT when using ArmPL (PR #183750)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 2 02:39:36 PST 2026
https://github.com/david-arm updated https://github.com/llvm/llvm-project/pull/183750
>From 69d424c655c0430f92b45f6dcb8af7381247c072 Mon Sep 17 00:00:00 2001
From: David Sherwood <david.sherwood at arm.com>
Date: Fri, 27 Feb 2026 15:12:04 +0000
Subject: [PATCH] [AArch64] Add vector expansion support for ISD::FCBRT when
using ArmPL
This patch teaches the backend how to lower the FCBRT DAG node to the
corresponding ArmPL vector math library function (armpl_vcbrtq_f32/f64
for fixed-length vectors, armpl_svcbrt_f32_x/f64_x for scalable vectors)
when using ArmPL. This is similar to what we already do for
llvm.pow/FPOW. However, the only way to expose this is via a DAG combine
that converts
FPOW(<2 x double> %x, <2 x double> <double 1.0/3.0, double 1.0/3.0>)
into
FCBRT(<2 x double> %x)
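when the appropriate fast math flags are present on the node (e.g.
nsz ninf nnan afn, as used in the new tests). I've updated the DAG
combine to handle vector types, and to only perform the transformation
if the target's runtime libcall info provides a cbrt implementation for
the type (in the vector case, a vector math library variant of cbrt).
Note that this check replaces the previous TargetLibraryInfo
(LibFunc_cbrt) check for scalar types as well.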
---
.../include/llvm/CodeGen/RuntimeLibcallUtil.h | 4 ++
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 12 +++-
.../SelectionDAG/LegalizeVectorOps.cpp | 10 +++
llvm/lib/CodeGen/TargetLoweringBase.cpp | 23 +++++++
llvm/lib/IR/RuntimeLibcalls.cpp | 62 ++++++++++++++-----
.../Target/AArch64/AArch64ISelLowering.cpp | 2 +
llvm/test/CodeGen/AArch64/veclib-llvm.pow.ll | 27 +++++++-
llvm/test/CodeGen/ARM/pow.ll | 16 ++++-
8 files changed, 133 insertions(+), 23 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/RuntimeLibcallUtil.h b/llvm/include/llvm/CodeGen/RuntimeLibcallUtil.h
index cc71c3206410a..898d19d1fc292 100644
--- a/llvm/include/llvm/CodeGen/RuntimeLibcallUtil.h
+++ b/llvm/include/llvm/CodeGen/RuntimeLibcallUtil.h
@@ -100,6 +100,10 @@ LLVM_ABI Libcall getPOWI(EVT RetVT);
/// UNKNOWN_LIBCALL if there is none.
LLVM_ABI Libcall getPOW(EVT RetVT);
+/// getCBRT - Return the CBRT_* value for the given (scalar or vector) type,
+/// or UNKNOWN_LIBCALL if there is none.
+LLVM_ABI Libcall getCBRT(EVT RetVT);
+
/// getLDEXP - Return the LDEXP_* value for the given types, or
/// UNKNOWN_LIBCALL if there is none.
LLVM_ABI Libcall getLDEXP(EVT RetVT);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 41e77e044d8a9..2f8d4990de861 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -19343,8 +19343,11 @@ SDValue DAGCombiner::visitFPOW(SDNode *N) {
// TODO: Since we're approximating, we don't need an exact 1/3 exponent.
// Some range near 1/3 should be fine.
EVT VT = N->getValueType(0);
- if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
- (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
+ EVT ScalarVT = VT.getScalarType();
+ if ((ScalarVT == MVT::f32 &&
+ ExponentC->getValueAPF().isExactlyValue(1.0f / 3.0f)) ||
+ (ScalarVT == MVT::f64 &&
+ ExponentC->getValueAPF().isExactlyValue(1.0 / 3.0))) {
// pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
// pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
// pow(-val, 1/3) = nan; cbrt(-val) = -num.
@@ -19358,7 +19361,10 @@ SDValue DAGCombiner::visitFPOW(SDNode *N) {
// Do not create a cbrt() libcall if the target does not have it, and do not
// turn a pow that has lowering support into a cbrt() libcall.
- if (!DAG.getLibInfo().has(LibFunc_cbrt) ||
+ RTLIB::Libcall LC = RTLIB::getCBRT(VT);
+ bool HasLibCall =
+ DAG.getTargetLoweringInfo().getLibcallImpl(LC) != RTLIB::Unsupported;
+ if (!HasLibCall ||
(!DAG.getTargetLoweringInfo().isOperationExpand(ISD::FPOW, VT) &&
DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
return SDValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 0b1d5bfd078d8..6b02828712b63 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -423,6 +423,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::FLDEXP:
case ISD::FPOWI:
case ISD::FPOW:
+ case ISD::FCBRT:
case ISD::FLOG:
case ISD::FLOG2:
case ISD::FLOG10:
@@ -1327,6 +1328,15 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
// scalarizing.
break;
}
+ case ISD::FCBRT: {
+ RTLIB::Libcall LC = RTLIB::getCBRT(Node->getValueType(0));
+ if (tryExpandVecMathCall(Node, LC, Results))
+ return;
+
+ // TODO: Try to see if there's a narrower call available to use before
+ // scalarizing.
+ break;
+ }
case ISD::FMODF: {
EVT VT = Node->getValueType(0);
RTLIB::Libcall LC = RTLIB::getMODF(VT);
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index cc5a4219536ac..7ab6d82c5ccda 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -651,6 +651,29 @@ RTLIB::Libcall RTLIB::getREM(EVT VT) {
return getFPLibCall(VT, REM_F32, REM_F64, REM_F80, REM_F128, REM_PPCF128);
}
+RTLIB::Libcall RTLIB::getCBRT(EVT VT) {
+ // TODO: Tablegen should generate this function
+ if (VT.isVector()) { // Only the four vector types below have CBRT libcalls.
+ if (!VT.isSimple())
+ return RTLIB::UNKNOWN_LIBCALL;
+ switch (VT.getSimpleVT().SimpleTy) {
+ case MVT::v4f32:
+ return RTLIB::CBRT_V4F32;
+ case MVT::v2f64:
+ return RTLIB::CBRT_V2F64;
+ case MVT::nxv4f32:
+ return RTLIB::CBRT_NXV4F32;
+ case MVT::nxv2f64:
+ return RTLIB::CBRT_NXV2F64;
+ default: // No vector CBRT libcall for this vector type.
+ return RTLIB::UNKNOWN_LIBCALL;
+ }
+ }
+ // Scalar types: pick among the f32/f64/f80/f128/ppcf128 CBRT libcalls.
+ return getFPLibCall(VT, CBRT_F32, CBRT_F64, CBRT_F80, CBRT_F128,
+ CBRT_PPCF128);
+}
+
+
RTLIB::Libcall RTLIB::getMODF(EVT RetVT) {
// TODO: Tablegen should generate this function
if (RetVT.isVector()) {
diff --git a/llvm/lib/IR/RuntimeLibcalls.cpp b/llvm/lib/IR/RuntimeLibcalls.cpp
index 15ecd53b5ed13..e9b5a72adb263 100644
--- a/llvm/lib/IR/RuntimeLibcalls.cpp
+++ b/llvm/lib/IR/RuntimeLibcalls.cpp
@@ -56,24 +56,37 @@ RuntimeLibcallsInfo::RuntimeLibcallsInfo(const Triple &TT,
setAvailable(Impl);
break;
case VectorLibrary::ArmPL:
- for (RTLIB::LibcallImpl Impl :
- {RTLIB::impl_armpl_svfmod_f32_x, RTLIB::impl_armpl_svfmod_f64_x,
- RTLIB::impl_armpl_vfmodq_f32, RTLIB::impl_armpl_vfmodq_f64,
- RTLIB::impl_armpl_vmodfq_f64, RTLIB::impl_armpl_vmodfq_f32,
- RTLIB::impl_armpl_svmodf_f64_x, RTLIB::impl_armpl_svmodf_f32_x,
- RTLIB::impl_armpl_vsincosq_f64, RTLIB::impl_armpl_vsincosq_f32,
- RTLIB::impl_armpl_svsincos_f64_x, RTLIB::impl_armpl_svsincos_f32_x,
- RTLIB::impl_armpl_vsincospiq_f32, RTLIB::impl_armpl_vsincospiq_f64,
- RTLIB::impl_armpl_svsincospi_f32_x,
- RTLIB::impl_armpl_svsincospi_f64_x, RTLIB::impl_armpl_svpow_f32_x,
- RTLIB::impl_armpl_svpow_f64_x, RTLIB::impl_armpl_vpowq_f32,
- RTLIB::impl_armpl_vpowq_f64})
+ for (RTLIB::LibcallImpl Impl : {RTLIB::impl_armpl_svfmod_f32_x,
+ RTLIB::impl_armpl_svfmod_f64_x,
+ RTLIB::impl_armpl_vfmodq_f32,
+ RTLIB::impl_armpl_vfmodq_f64,
+ RTLIB::impl_armpl_vmodfq_f64,
+ RTLIB::impl_armpl_vmodfq_f32,
+ RTLIB::impl_armpl_svmodf_f64_x,
+ RTLIB::impl_armpl_svmodf_f32_x,
+ RTLIB::impl_armpl_vsincosq_f64,
+ RTLIB::impl_armpl_vsincosq_f32,
+ RTLIB::impl_armpl_svsincos_f64_x,
+ RTLIB::impl_armpl_svsincos_f32_x,
+ RTLIB::impl_armpl_vsincospiq_f32,
+ RTLIB::impl_armpl_vsincospiq_f64,
+ RTLIB::impl_armpl_svsincospi_f32_x,
+ RTLIB::impl_armpl_svsincospi_f64_x,
+ RTLIB::impl_armpl_svpow_f32_x,
+ RTLIB::impl_armpl_svpow_f64_x,
+ RTLIB::impl_armpl_vpowq_f32,
+ RTLIB::impl_armpl_vpowq_f64,
+ RTLIB::impl_armpl_svcbrt_f32_x,
+ RTLIB::impl_armpl_svcbrt_f64_x,
+ RTLIB::impl_armpl_vcbrtq_f32,
+ RTLIB::impl_armpl_vcbrtq_f64})
setAvailable(Impl);
for (RTLIB::LibcallImpl Impl :
{RTLIB::impl_armpl_vfmodq_f32, RTLIB::impl_armpl_vfmodq_f64,
RTLIB::impl_armpl_vsincosq_f64, RTLIB::impl_armpl_vsincosq_f32,
- RTLIB::impl_armpl_vpowq_f32, RTLIB::impl_armpl_vpowq_f64})
+ RTLIB::impl_armpl_vpowq_f32, RTLIB::impl_armpl_vpowq_f64,
+ RTLIB::impl_armpl_vcbrtq_f32, RTLIB::impl_armpl_vcbrtq_f64})
setLibcallImplCallingConv(Impl, CallingConv::AArch64_VectorCall);
break;
default:
@@ -295,20 +308,33 @@ RuntimeLibcallsInfo::getFunctionTy(LLVMContext &Ctx, const Triple &TT,
case RTLIB::impl_armpl_vpowq_f32:
case RTLIB::impl_armpl_vpowq_f64:
case RTLIB::impl_armpl_svpow_f32_x:
- case RTLIB::impl_armpl_svpow_f64_x: {
+ case RTLIB::impl_armpl_svpow_f64_x:
+ case RTLIB::impl_armpl_vcbrtq_f32:
+ case RTLIB::impl_armpl_vcbrtq_f64:
+ case RTLIB::impl_armpl_svcbrt_f32_x:
+ case RTLIB::impl_armpl_svcbrt_f64_x: {
bool IsF32 = LibcallImpl == RTLIB::impl__ZGVnN4vv_fmodf ||
LibcallImpl == RTLIB::impl__ZGVsMxvv_fmodf ||
LibcallImpl == RTLIB::impl_armpl_svfmod_f32_x ||
LibcallImpl == RTLIB::impl_armpl_vfmodq_f32 ||
LibcallImpl == RTLIB::impl_armpl_vpowq_f32 ||
- LibcallImpl == RTLIB::impl_armpl_svpow_f32_x;
+ LibcallImpl == RTLIB::impl_armpl_svpow_f32_x ||
+ LibcallImpl == RTLIB::impl_armpl_vcbrtq_f32 ||
+ LibcallImpl == RTLIB::impl_armpl_svcbrt_f32_x;
bool IsScalable = LibcallImpl == RTLIB::impl__ZGVsMxvv_fmod ||
LibcallImpl == RTLIB::impl__ZGVsMxvv_fmodf ||
LibcallImpl == RTLIB::impl_armpl_svfmod_f32_x ||
LibcallImpl == RTLIB::impl_armpl_svfmod_f64_x ||
LibcallImpl == RTLIB::impl_armpl_svpow_f32_x ||
- LibcallImpl == RTLIB::impl_armpl_svpow_f64_x;
+ LibcallImpl == RTLIB::impl_armpl_svpow_f64_x ||
+ LibcallImpl == RTLIB::impl_armpl_svcbrt_f32_x ||
+ LibcallImpl == RTLIB::impl_armpl_svcbrt_f64_x;
+
+ bool HasOneArg = LibcallImpl == RTLIB::impl_armpl_vcbrtq_f32 ||
+ LibcallImpl == RTLIB::impl_armpl_vcbrtq_f64 ||
+ LibcallImpl == RTLIB::impl_armpl_svcbrt_f32_x ||
+ LibcallImpl == RTLIB::impl_armpl_svcbrt_f64_x;
AttrBuilder FuncAttrBuilder(Ctx);
@@ -322,7 +348,7 @@ RuntimeLibcallsInfo::getFunctionTy(LLVMContext &Ctx, const Triple &TT,
unsigned EC = IsF32 ? 4 : 2;
VectorType *VecTy = VectorType::get(ScalarTy, EC, IsScalable);
- SmallVector<Type *, 3> ArgTys = {VecTy, VecTy};
+ SmallVector<Type *, 3> ArgTys(HasOneArg ? 1 : 2, VecTy);
if (hasVectorMaskArgument(LibcallImpl))
ArgTys.push_back(VectorType::get(Type::getInt1Ty(Ctx), EC, IsScalable));
@@ -461,6 +487,8 @@ bool RuntimeLibcallsInfo::hasVectorMaskArgument(RTLIB::LibcallImpl Impl) {
case RTLIB::impl__ZGVsMxvv_fmodf:
case RTLIB::impl_armpl_svpow_f32_x:
case RTLIB::impl_armpl_svpow_f64_x:
+ case RTLIB::impl_armpl_svcbrt_f32_x:
+ case RTLIB::impl_armpl_svcbrt_f64_x:
return true;
default:
return false;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index eb6e9146e3839..c29bcd93c216e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1789,6 +1789,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FREM, VT, Expand);
setOperationAction(ISD::FPOW, VT, Expand);
setOperationAction(ISD::FPOWI, VT, Expand);
+ setOperationAction(ISD::FCBRT, VT, Expand);
setOperationAction(ISD::FCOS, VT, Expand);
setOperationAction(ISD::FSIN, VT, Expand);
setOperationAction(ISD::FSINCOS, VT, Expand);
@@ -2175,6 +2176,7 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
setOperationAction(ISD::FCOSH, VT, Expand);
setOperationAction(ISD::FTANH, VT, Expand);
setOperationAction(ISD::FPOW, VT, Expand);
+ setOperationAction(ISD::FCBRT, VT, Expand);
setOperationAction(ISD::FLOG, VT, Expand);
setOperationAction(ISD::FLOG2, VT, Expand);
setOperationAction(ISD::FLOG10, VT, Expand);
diff --git a/llvm/test/CodeGen/AArch64/veclib-llvm.pow.ll b/llvm/test/CodeGen/AArch64/veclib-llvm.pow.ll
index 1f25d41a86ebd..fdea2b5a5a1e3 100644
--- a/llvm/test/CodeGen/AArch64/veclib-llvm.pow.ll
+++ b/llvm/test/CodeGen/AArch64/veclib-llvm.pow.ll
@@ -54,7 +54,7 @@ define <4 x float> @test_pow_v4f32_025(<4 x float> %x) nounwind {
; ARMPL-NEXT: fsqrt v0.4s, v0.4s
; ARMPL-NEXT: fsqrt v0.4s, v0.4s
; ARMPL-NEXT: ret
- %result = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> splat (float 2.5e-01))
+ %result = call nsz ninf afn <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> splat (float 2.5e-01))
ret <4 x float> %result
}
@@ -67,6 +67,29 @@ define <vscale x 2 x double> @test_pow_nxv2f64_075(<vscale x 2 x double> %x) nou
; ARMPL-NEXT: fsqrt z1.d, p0/m, z0.d
; ARMPL-NEXT: fmul z0.d, z0.d, z1.d
; ARMPL-NEXT: ret
- %result = call fast <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %x, <vscale x 2 x double> splat (double 7.5e-01))
+ %result = call ninf afn <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %x, <vscale x 2 x double> splat (double 7.5e-01))
ret <vscale x 2 x double> %result
}
+
+define <4 x float> @test_pow_one_third_v4f32(<4 x float> %x) nounwind {
+; ARMPL-LABEL: test_pow_one_third_v4f32:
+; ARMPL: // %bb.0:
+; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT: bl armpl_vcbrtq_f32
+; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT: ret
+ %r = call nsz ninf nnan afn <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> splat (float 0x3FD5555560000000))
+ ret <4 x float> %r
+}
+
+define <vscale x 2 x double> @test_pow_one_third_nxv2f64(<vscale x 2 x double> %x) nounwind {
+; ARMPL-LABEL: test_pow_one_third_nxv2f64:
+; ARMPL: // %bb.0:
+; ARMPL-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
+; ARMPL-NEXT: ptrue p0.d
+; ARMPL-NEXT: bl armpl_svcbrt_f64_x
+; ARMPL-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
+; ARMPL-NEXT: ret
+ %r = call nsz ninf nnan afn <vscale x 2 x double> @llvm.pow.nxv2f64(<vscale x 2 x double> %x, <vscale x 2 x double> splat (double 0x3FD5555555555555))
+ ret <vscale x 2 x double> %r
+}
diff --git a/llvm/test/CodeGen/ARM/pow.ll b/llvm/test/CodeGen/ARM/pow.ll
index 8abc37de3d49a..afd3e4b44db75 100644
--- a/llvm/test/CodeGen/ARM/pow.ll
+++ b/llvm/test/CodeGen/ARM/pow.ll
@@ -1,4 +1,3 @@
-; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -mtriple=thumbv7m-linux-gnueabi | FileCheck %s --check-prefixes=ANY,SOFTFLOAT
; RUN: llc < %s -mtriple=thumbv8-linux-gnueabihf -mattr=neon | FileCheck %s --check-prefixes=ANY,HARDFLOAT
@@ -106,3 +105,18 @@ define <2 x double> @pow_v2f64_one_fourth_not_enough_fmf(<2 x double> %x) nounwi
ret <2 x double> %r
}
+define <4 x float> @test_pow_one_third_v4f32(<4 x float> %x) nounwind {
+; SOFTFLOAT-LABEL: test_pow_one_third_v4f32:
+; SOFTFLOAT: bl powf
+; SOFTFLOAT: bl powf
+; SOFTFLOAT: bl powf
+; SOFTFLOAT: bl powf
+;
+; HARDFLOAT-LABEL: test_pow_one_third_v4f32:
+; HARDFLOAT: bl cbrtf
+; HARDFLOAT: bl cbrtf
+; HARDFLOAT: bl cbrtf
+; HARDFLOAT: bl cbrtf
+ %r = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> splat (float 0x3FD5555560000000))
+ ret <4 x float> %r
+}
More information about the llvm-commits
mailing list