[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 26 04:51:14 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: JP Hafer (jph-13)
<details>
<summary>Changes</summary>
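This commit reintroduces an optimization that was previously removed due to limited applicability (see #91924). Despite the new test's location under `Transforms/InstCombine`, the change is an AArch64 SelectionDAG combine in `AArch64ISelLowering.cpp`, not an InstCombine transform.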
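This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, #N)`, where the immediate `N` is the number of fractional bits. `uitofp` is handled the same way and becomes `ucvtf(X, #N)`. Scalar inputs are already lowered this way by existing code; the scalar tests are included for coverage only. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.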
---
Full diff: https://github.com/llvm/llvm-project/pull/141480.diff
2 Files Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+152)
- (added) llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll (+47)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2800145cc603..bb094d9772c47 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
ISD::FP_TO_UINT_SAT, ISD::FADD});
+ // Try to fmul -> scvtf for powers of 2
+ setTargetDAGCombine(ISD::FMUL);
+
// Try and combine setcc with csel
setTargetDAGCombine(ISD::SETCC);
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
return FixConv;
}
+/// Try to extract N from a uniform constant FP splat whose lanes are all
+/// exactly 2^-N, i.e. the reciprocal of a power of two (0.0625 -> 4).
+/// Only IEEE single- and double-precision elements are recognised.
+///
+/// \param BV          The constant build vector to inspect.
+/// \param MaxExponent Largest N that is accepted.
+/// \returns N in the range [0, MaxExponent] (0 for a splat of 1.0), or -1 if
+///          any lane is not a ConstantFP, the lanes are not bitwise identical,
+///          the value is not a positive normal number of the form 2^-N (this
+///          includes every value >= 2.0), or N exceeds MaxExponent.
+static int getUniformFPSplatLog2(const BuildVectorSDNode *BV,
+                                 unsigned MaxExponent) {
+  const auto *FirstConst = dyn_cast<ConstantFPSDNode>(BV->getOperand(0));
+  if (!FirstConst)
+    return -1;
+  const APFloat &FirstVal = FirstConst->getValueAPF();
+
+  // Every lane must hold a bitwise-identical constant.
+  for (unsigned I = 1, E = BV->getNumOperands(); I != E; ++I) {
+    const auto *EltConst = dyn_cast<ConstantFPSDNode>(BV->getOperand(I));
+    if (!EltConst || !EltConst->getValueAPF().bitwiseIsEqual(FirstVal))
+      return -1;
+  }
+
+  // isNormal() rejects zero, denormals, infinities and NaNs explicitly,
+  // instead of relying on the range check below to filter out infinities.
+  if (!FirstVal.isNormal() || FirstVal.isNegative())
+    return -1;
+
+  const fltSemantics &Sem = FirstVal.getSemantics();
+  int ExponentBias;
+  unsigned ExponentBits;
+  unsigned MantissaBits;
+  if (&Sem == &APFloat::IEEEsingle()) {
+    ExponentBias = 127;
+    ExponentBits = 8;
+    MantissaBits = 23;
+  } else if (&Sem == &APFloat::IEEEdouble()) {
+    ExponentBias = 1023;
+    ExponentBits = 11;
+    MantissaBits = 52;
+  } else {
+    // Unsupported element type.
+    return -1;
+  }
+
+  APInt Bits = FirstVal.bitcastToAPInt();
+
+  // A normal number is an exact power of two iff its stored mantissa is 0.
+  if (Bits.extractBitsAsZExtValue(MantissaBits, 0) != 0)
+    return -1;
+
+  // For 2^-N the biased exponent field holds ExponentBias - N.
+  int Exponent = static_cast<int>(
+      Bits.extractBitsAsZExtValue(ExponentBits, MantissaBits));
+  int Log2 = ExponentBias - Exponent;
+
+  // A negative Log2 means the value is >= 2.0, which is not a reciprocal.
+  if (Log2 < 0 || static_cast<unsigned>(Log2) > MaxExponent)
+    return -1;
+
+  return Log2;
+}
+
+/// Fold a vector floating-point multiply by the reciprocal of a power of two
+/// into a NEON fixed-point to floating-point conversion:
+///   fmul (sitofp X), splat(2^-N) -> vcvtfxs2fp X, N   (scvtf with #N)
+///   fmul (uitofp X), splat(2^-N) -> vcvtfxu2fp X, N   (ucvtf with #N)
+/// Integer lanes narrower than the FP lanes are sign/zero-extended first.
+/// Only 2- and 4-lane vectors with f32/f64 elements and i16/i32/i64 sources
+/// are handled, and N must lie in [1, FloatBits]. Scaling by a power of two
+/// commutes with rounding when nothing overflows or underflows, which holds
+/// for integer sources and N <= 64, so the fold does not change results.
+static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasNEON() || N->getOpcode() != ISD::FMUL)
+    return SDValue();
+
+  // Check the opcode before touching Op's operands: an arbitrary node feeding
+  // the multiply need not have an operand 0 at all.
+  SDValue Op = N->getOperand(0);
+  unsigned Opc = Op->getOpcode();
+  if (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP)
+    return SDValue();
+
+  EVT VT = Op.getValueType();
+  SDValue IntSrc = Op.getOperand(0);
+  if (!VT.isVector() || !VT.isSimple() || !IntSrc.getValueType().isSimple())
+    return SDValue();
+
+  // The multiplier must be a constant build vector.
+  auto *BV = dyn_cast<BuildVectorSDNode>(N->getOperand(1));
+  if (!BV)
+    return SDValue();
+
+  MVT IntTy = IntSrc.getSimpleValueType().getVectorElementType();
+  int32_t IntBits = IntTy.getSizeInBits();
+  if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+    return SDValue();
+
+  MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+  int32_t FloatBits = FloatTy.getSizeInBits();
+  if (FloatBits != 32 && FloatBits != 64)
+    return SDValue();
+
+  // Avoid conversions where the integer is wider than the float (e.g. i64 ->
+  // f32); the source is only ever widened below, never narrowed.
+  if (IntBits > FloatBits)
+    return SDValue();
+
+  // The conversion immediate (number of fractional bits) must lie in
+  // [1, FloatBits]. The helper returns -1 on failure and 0 for a splat of
+  // 1.0; both are rejected here.
+  int32_t FracBits = getUniformFPSplatLog2(BV, FloatBits);
+  if (FracBits <= 0)
+    return SDValue();
+
+  // Integer vector type consumed by the conversion: same lane count as the
+  // result, with lanes as wide as the FP elements.
+  MVT ResTy;
+  switch (VT.getVectorNumElements()) {
+  default:
+    return SDValue();
+  case 2:
+    ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+    break;
+  case 4:
+    ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
+    break;
+  }
+
+  // v4i64 is wider than a 128-bit NEON register; do not form it before
+  // operation legalization. NOTE(review): confirm whether a v4f64 FMUL can
+  // reach this point after legalization at all.
+  if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDLoc DL(N);
+  bool IsSigned = Opc == ISD::SINT_TO_FP;
+  SDValue ConvInput = IntSrc;
+  // Widen narrower integer lanes to the FP element width.
+  if (IntBits < FloatBits)
+    ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+                            ResTy, ConvInput);
+
+  unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+                                      : Intrinsic::aarch64_neon_vcvtfxu2fp;
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
+                     DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+                     DAG.getConstant(FracBits, DL, MVT::i32));
+}
+
static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
const AArch64TargetLowering &TLI) {
EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
return performFpToIntCombine(N, DAG, DCI, Subtarget);
+ case ISD::FMUL:
+ return performFMulCombine(N, DAG, DCI, Subtarget);
case ISD::OR:
return performORCombine(N, DCI, Subtarget, *this);
case ISD::AND:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..befddb165fcce
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,47 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
+
+; NOTE(review): this file runs llc rather than opt, so it belongs under
+; llvm/test/CodeGen/AArch64 instead of llvm/test/Transforms/InstCombine.
+
+; Test case 1: Scalar fdiv by 16.0
+define float @tests(i32 %in) {
+; CHECK-LABEL: tests:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf s0, w0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fdiv float %vcvt.i, 16.0
+  ret float %div.i
+}
+
+; Test case 2: Scalar fmul by 2^-4
+define float @testsmul(i32 %in) local_unnamed_addr {
+; CHECK-LABEL: testsmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf s0, w0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp i32 %in to float
+  %mul.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
+  ret float %mul.i
+}
+
+; Test case 3: Vector fdiv by 16.0
+define <2 x float> @testv(<2 x i32> %in) {
+; CHECK-LABEL: testv:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf.2s v0, v0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+  ret <2 x float> %div.i
+}
+
+; Test case 4: Vector fmul by 2^-4
+define <2 x float> @testvmul(<2 x i32> %in) local_unnamed_addr {
+; CHECK-LABEL: testvmul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf.2s v0, v0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %mul.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
+  ret <2 x float> %mul.i
+}
+
+; Test case 5: Unsigned vector fmul by 2^-4 uses ucvtf.
+define <4 x float> @testuvmul(<4 x i32> %in) {
+; CHECK-LABEL: testuvmul:
+; CHECK-NOT:     fmul
+; CHECK:         ucvtf.4s v0, v0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = uitofp <4 x i32> %in to <4 x float>
+  %mul.i = fmul <4 x float> %vcvt.i, splat (float 6.250000e-02)
+  ret <4 x float> %mul.i
+}
+
+; Test case 6: i32 lanes are widened before a v2f64 fixed-point conversion.
+define <2 x double> @testv2f64mul(<2 x i32> %in) {
+; CHECK-LABEL: testv2f64mul:
+; CHECK-NOT:     fmul
+; CHECK:         scvtf.2d v0, v0, #2
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp <2 x i32> %in to <2 x double>
+  %mul.i = fmul <2 x double> %vcvt.i, splat (double 2.500000e-01)
+  ret <2 x double> %mul.i
+}
+
+; Test case 7: Not a reciprocal power of two, so the fmul must remain.
+define <2 x float> @testvmul_nonpow2(<2 x i32> %in) {
+; CHECK-LABEL: testvmul_nonpow2:
+; CHECK:         fmul.2s
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %mul.i = fmul <2 x float> %vcvt.i, splat (float 3.000000e+00)
+  ret <2 x float> %mul.i
+}
\ No newline at end of file
``````````
</details>
https://github.com/llvm/llvm-project/pull/141480
More information about the llvm-commits
mailing list