[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)

JP Hafer via llvm-commits llvm-commits at lists.llvm.org
Fri May 30 11:21:18 PDT 2025


https://github.com/jph-13 updated https://github.com/llvm/llvm-project/pull/141480

>From ec47ab2da1fff94c38fb520fe7af2e503ca537b7 Mon Sep 17 00:00:00 2001
From: JP Hafer <jhafer at mathworks.com>
Date: Fri, 30 May 2025 14:19:15 -0400
Subject: [PATCH] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) ->
 scvtf(x, 2)

This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability.
See: #91924

This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, N)`, i.e. a fixed-point conversion with `N` fractional bits (for example, multiplying by `2^-4` becomes `scvtf ..., #4`). This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
---
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    |  56 +++++++
 .../Target/AArch64/AArch64ISelLowering.cpp    | 152 ++++++++++++++++++
 .../AArch64/scvtf-div-mul-combine.ll          |  93 +++++++++++
 3 files changed, 301 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 96fa85179d023..fc8df60f5bbeb 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -3907,6 +3907,62 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
                                                unsigned RegWidth,
                                                bool isReciprocal) {
   APFloat FVal(0.0);
+
+  if (N.getOpcode() == ISD::BUILD_VECTOR) {
+    EVT VT = N.getValueType();
+    EVT EltVT = VT.getVectorElementType();
+
+    unsigned NumElts = N.getNumOperands();
+    SDValue FirstOp = N.getOperand(0);
+
+    ConstantFPSDNode *FirstCN = dyn_cast<ConstantFPSDNode>(FirstOp);
+    if (!FirstCN)
+      return false;
+
+    APFloat FirstVal = FirstCN->getValueAPF();
+    if (EltVT == MVT::f16) {
+      bool ignored;
+      FirstVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
+    }
+
+    // Handle reciprocal case if needed
+    if (isReciprocal) {
+      if (!FirstVal.getExactInverse(&FirstVal))
+        return false;
+    }
+
+    bool IsExact;
+    APSInt IntVal(65, true);
+    FirstVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
+
+    if (!IsExact || !IntVal.isPowerOf2())
+      return false;
+
+    unsigned FBits = IntVal.logBase2();
+    if (FBits == 0 || FBits > RegWidth)
+      return false;
+
+    APInt FirstBits = FirstVal.bitcastToAPInt();
+
+    for (unsigned i = 1; i < NumElts; ++i) {
+      ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N.getOperand(i));
+      if (!CN)
+        return false;
+
+      APFloat ElemVal = CN->getValueAPF();
+      if (EltVT == MVT::f16) {
+        bool ignored;
+        ElemVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
+      }
+
+      if (ElemVal.bitcastToAPInt() != FirstBits)
+        return false;
+    }
+
+    FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
+    return true;
+  }
+
   if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
     FVal = CN->getValueAPF();
   else if (LoadSDNode *LN = dyn_cast<LoadSDNode>(N)) {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2800145cc603..be2dd8572adf0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1148,6 +1148,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
                        ISD::FP_TO_UINT_SAT, ISD::FADD});
 
+  // Try to fmul -> scvtf for powers of 2
+  setTargetDAGCombine(ISD::FMUL);
+
   // Try and combine setcc with csel
   setTargetDAGCombine(ISD::SETCC);
 
@@ -19250,6 +19253,153 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
   return FixConv;
 }
 
+/// Return N if BV is a uniform FP constant splat of 2^-N with N <= MaxExponent
+/// (f32/f64 only; N is 0 for a splat of 1.0). Return -1 for anything else.
+static int getUniformFPSplatLog2(const BuildVectorSDNode *BV,
+                                 unsigned MaxExponent) {
+  SDValue FirstElt = BV->getOperand(0);
+  if (!isa<ConstantFPSDNode>(FirstElt))
+    return -1;
+
+  const ConstantFPSDNode *FirstConst = cast<ConstantFPSDNode>(FirstElt);
+  const APFloat &FirstVal = FirstConst->getValueAPF();
+  const fltSemantics &Sem = FirstVal.getSemantics();
+
+  // Every other element must be an FP constant bit-identical to the first.
+  for (unsigned i = 1, e = BV->getNumOperands(); i != e; ++i) {
+    SDValue Elt = BV->getOperand(i);
+    if (!isa<ConstantFPSDNode>(Elt))
+      return -1;
+    const APFloat &Val = cast<ConstantFPSDNode>(Elt)->getValueAPF();
+    if (!Val.bitwiseIsEqual(FirstVal))
+      return -1;
+  }
+
+  // Reject zero, NaN, and negative values before inspecting the encoding.
+  if (FirstVal.isZero() || FirstVal.isNaN() || FirstVal.isNegative())
+    return -1;
+
+  // Raw encoding, decoded below as: sign | exponent field | mantissa field.
+  APInt Bits = FirstVal.bitcastToAPInt();
+
+  int ExponentBias = 0;
+  unsigned ExponentBits = 0;
+  unsigned MantissaBits = 0;
+
+  if (&Sem == &APFloat::IEEEsingle()) {
+    ExponentBias = 127;
+    ExponentBits = 8;
+    MantissaBits = 23;
+  } else if (&Sem == &APFloat::IEEEdouble()) {
+    ExponentBias = 1023;
+    ExponentBits = 11;
+    MantissaBits = 52;
+  } else {
+    // Only IEEE single and double encodings are decoded here.
+    return -1;
+  }
+
+  // A non-zero mantissa field means the value is not an exact power of two.
+  APInt MantissaMask = APInt::getLowBitsSet(Bits.getBitWidth(), MantissaBits);
+  if ((Bits & MantissaMask) != 0)
+    return -1;
+
+  // Biased exponent field; value == 2^(Exponent - ExponentBias) == 2^-Log2.
+  unsigned ExponentShift = MantissaBits;
+  APInt ExponentMask = APInt::getBitsSet(Bits.getBitWidth(), ExponentShift,
+                                         ExponentShift + ExponentBits);
+  int Exponent = (Bits & ExponentMask).lshr(ExponentShift).getZExtValue();
+  int Log2 = ExponentBias - Exponent;
+
+  if (static_cast<unsigned>(Log2) > MaxExponent)
+    return -1; // Negative Log2 (value > 1) wraps to a huge unsigned value.
+
+  return Log2;
+}
+
+/// Fold a vector floating-point multiply of an integer-to-FP conversion by a
+/// uniform constant 2^-C into one fixed-point to floating-point conversion:
+///   fmul([su]int_to_fp(X), splat(2^-C)) -> vcvtfx[su]2fp(X, C)
+/// \param DAG the DAG in which the replacement nodes are created.
+/// \param N the FMUL node being combined.
+/// \param DCI used to tell whether operation legalization has run yet.
+/// \param Subtarget the fold is only attempted when NEON is available.
+/// \returns the replacement INTRINSIC_WO_CHAIN node, or an empty SDValue when
+/// the pattern does not match.
+static SDValue performFMulCombine(SDNode *N, SelectionDAG &DAG,
+                                  TargetLowering::DAGCombinerInfo &DCI,
+                                  const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasNEON() || N->getOpcode() != ISD::FMUL)
+    return SDValue();
+
+  // Operand 0 must be SINT_TO_FP or UINT_TO_FP. Test the opcode before reading
+  // Op's own operand: a node of another kind need not have an operand 0.
+  SDValue Op = N->getOperand(0);
+  unsigned Opc = Op->getOpcode();
+  if (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP)
+    return SDValue();
+  if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
+      !Op.getOperand(0).getValueType().isSimple())
+    return SDValue();
+
+  // Operand 1 must be a BUILD_VECTOR; its elements are validated below.
+  auto *BV = dyn_cast<BuildVectorSDNode>(N->getOperand(1));
+  if (!BV)
+    return SDValue();
+
+  MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+  int32_t IntBits = IntTy.getSizeInBits();
+  if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+    return SDValue();
+
+  MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+  int32_t FloatBits = FloatTy.getSizeInBits();
+  if (FloatBits != 32 && FloatBits != 64)
+    return SDValue();
+
+  // Integer lanes may be widened to the FP lane width but never narrowed.
+  if (IntBits > FloatBits)
+    return SDValue();
+
+  // The scale C must lie in [1, FloatBits]: -1 means "no match" and 0 is a
+  // multiply by 1.0, which this fold does not handle.
+  int32_t IntrinsicC = getUniformFPSplatLog2(BV, FloatBits + 1);
+  if (IntrinsicC <= 0 || IntrinsicC > FloatBits)
+    return SDValue();
+
+  MVT ResTy;
+  unsigned NumLanes = Op.getValueType().getVectorNumElements();
+  switch (NumLanes) {
+  default:
+    return SDValue();
+  case 2:
+    ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+    break;
+  case 4:
+    ResTy = FloatBits == 32 ? MVT::v4i32 : MVT::v4i64;
+    break;
+  }
+
+  if (ResTy == MVT::v4i64 && DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDLoc DL(N);
+  SDValue ConvInput = Op.getOperand(0);
+  bool IsSigned = Opc == ISD::SINT_TO_FP;
+
+  // Extend narrower integer lanes to ResTy, matching the source signedness.
+  if (IntBits < FloatBits)
+    ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+                            ResTy, ConvInput);
+
+  unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+                                      : Intrinsic::aarch64_neon_vcvtfxu2fp;
+
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+                     DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+                     DAG.getConstant(IntrinsicC, DL, MVT::i32));
+}
+
 static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                                const AArch64TargetLowering &TLI) {
   EVT VT = N->getValueType(0);
@@ -26693,6 +26843,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_TO_SINT_SAT:
   case ISD::FP_TO_UINT_SAT:
     return performFpToIntCombine(N, DAG, DCI, Subtarget);
+  case ISD::FMUL:
+    return performFMulCombine(N, DAG, DCI, Subtarget);
   case ISD::OR:
     return performORCombine(N, DCI, Subtarget, *this);
   case ISD::AND:
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..27f1158f3a0b3
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,93 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+fullfp16 -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s
+
+; Scalar fdiv by 16.0 (f32)
+define float @tests_f32_div(i32 %in) {
+; CHECK-LABEL: tests_f32_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf   s0, w0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fdiv float %vcvt.i, 16.0
+  ret float %div.i
+}
+
+; Scalar fmul by (2^-4) (f32)
+define float @testsmul_f32_mul(i32 %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testsmul_f32_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf   s0, w0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fmul float %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4
+  ret float %div.i
+}
+
+; Vector fdiv by 16.0 (v2f32)
+define <2 x float> @testv_v2f32_div(<2 x i32> %in) {
+; CHECK-LABEL: testv_v2f32_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf.2s        v0, v0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+  ret <2 x float> %div.i
+}
+
+; Vector fmul by 2^-4 (v2f32)
+define <2 x float> @testvmul_v2f32_mul(<2 x i32> %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testvmul_v2f32_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf.2s        v0, v0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fmul <2 x float> %vcvt.i, splat (float 6.250000e-02) ; 0.0625 is 2^-4
+  ret <2 x float> %div.i
+}
+
+; Scalar fdiv by 16.0 (f64)
+define double @tests_f64_div(i64 %in) {
+; CHECK-LABEL: tests_f64_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf   d0, x0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp i64 %in to double
+  %div.i = fdiv double %vcvt.i, 1.600000e+01 ; 16.0 in double-precision
+  ret double %div.i
+}
+
+; Scalar fmul by (2^-4) (f64)
+define double @testsmul_f64_mul(i64 %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testsmul_f64_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf   d0, x0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp i64 %in to double
+  %div.i = fmul double %vcvt.i, 6.250000e-02 ; 0.0625 is 2^-4 in double-precision
+  ret double %div.i
+}
+
+; Vector fdiv by 16.0 (v2f64)
+define <2 x double> @testv_v2f64_div(<2 x i64> %in) {
+; CHECK-LABEL: testv_v2f64_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    scvtf.2d        v0, v0, #4
+; CHECK-NEXT:    ret
+entry:
+  %vcvt.i = sitofp <2 x i64> %in to <2 x double>
+  %div.i = fdiv <2 x double> %vcvt.i, <double 1.600000e+01, double 1.600000e+01>
+  ret <2 x double> %div.i
+}
+
+; Vector fmul by 2^-4 (v2f64)
+define <2 x double> @testvmul_v2f64_mul(<2 x i64> %in) local_unnamed_addr #0 {
+; CHECK-LABEL: testvmul_v2f64_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    scvtf.2d        v0, v0, #4
+; CHECK-NEXT:    ret
+  %vcvt.i = sitofp <2 x i64> %in to <2 x double>
+  %div.i = fmul <2 x double> %vcvt.i, splat (double 6.250000e-02) ; 0.0625 is 2^-4 in double-precision
+  ret <2 x double> %div.i
+}
\ No newline at end of file



More information about the llvm-commits mailing list