[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)

JP Hafer via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 18 08:11:07 PDT 2025


https://github.com/jph-13 updated https://github.com/llvm/llvm-project/pull/141480

>From 2096ae8b01cb7a7296678b70263e79ee9d2c0c8c Mon Sep 17 00:00:00 2001
From: JP Hafer <jhafer at mathworks.com>
Date: Wed, 18 Jun 2025 11:10:46 -0400
Subject: [PATCH] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) ->
 scvtf(x, 2)

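This commit reintroduces an optimization that was previously removed from InstCombine due to limited applicability; it is now implemented in AArch64 instruction selection (AArch64ISelDAGToDAG.cpp and AArch64InstrInfo.td).
See: #91924

This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. Scalar inputs are already matched via `SelectCVTFixedPosRecipOperand`; this patch adds the equivalent NEON vector patterns (v2f32, v4f32, v2f64, and, with full FP16, v4f16 and v8f16). If we have `sitofp(X) * C` where `C` is a splat of `1/2^N`, this can be selected as `scvtf(X, #N)`, a fixed-point conversion with N fractional bits. This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.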
---
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    | 137 ++++++++++++++++++
 llvm/lib/Target/AArch64/AArch64InstrInfo.td   |  52 +++++++
 .../CodeGen/AArch64/scvtf-div-mul-combine.ll  |  96 ++++++++++++
 3 files changed, 285 insertions(+)
 create mode 100644 llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 009d69b2b9433..84fe3d8a34421 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
                                      unsigned Width);
 
+  /// Vector form of SelectCVTFixedPosRecipOperand: matches a splat of 1/2^N.
+  bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
+                                        unsigned FloatWidth);
+  template <unsigned FloatWidth>
+  bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
+    return SelectCVTFixedPosRecipOperandVec(N, FixedPos, FloatWidth);
+  }
+
   bool SelectCMP_SWAP(SDNode *N);
 
   bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,129 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
   return true;
 }
 
+// Extracts the raw bits of one lane of the splat constant that the vector
+// float value N reinterprets (N must be a BITCAST/NVCAST of a splat). On
+// success stores them in Bits (width == N's element width) and returns true.
+// The source vector must have the same lane width as N: otherwise the
+// reinterpreted value is not necessarily a splat, and decoding its constant
+// with N's float semantics would mismatch widths.
+static bool getVectorFPSplatBits(SDValue N, APInt &Bits) {
+  if (N.getOpcode() != ISD::BITCAST && N.getOpcode() != AArch64ISD::NVCAST)
+    return false;
+  EVT VT = N.getValueType();
+  if (!VT.isVector() || !VT.isFloatingPoint())
+    return false;
+
+  unsigned EltBits = VT.getScalarSizeInBits();
+  SDValue Src = N.getOperand(0);
+  if (!Src.getValueType().isVector() ||
+      Src.getValueType().getScalarSizeInBits() != EltBits)
+    return false;
+
+  // Recover one lane's bits from the node that produces the splat.
+  switch (Src.getOpcode()) {
+  case ISD::BUILD_VECTOR: {
+    // Use the computed splat value rather than operand 0, which may be undef.
+    // MinSplatBits = EltBits keeps the splat at lane granularity.
+    APInt SplatValue, SplatUndef;
+    unsigned SplatBitSize;
+    bool HasAnyUndefs;
+    auto *BVN = cast<BuildVectorSDNode>(Src.getNode());
+    if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                              HasAnyUndefs, EltBits) ||
+        SplatBitSize != EltBits)
+      return false;
+    Bits = SplatValue;
+    return true;
+  }
+  case AArch64ISD::DUP:
+  case ISD::SPLAT_VECTOR: {
+    // One scalar operand and no shift. Integer scalars may be wider than the
+    // lane (e.g. i32 feeding i16 lanes), so resize them to EltBits.
+    SDValue Scalar = Src.getOperand(0);
+    if (auto *CFP = dyn_cast<ConstantFPSDNode>(Scalar)) {
+      Bits = CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(EltBits);
+      return true;
+    }
+    if (auto *CI = dyn_cast<ConstantSDNode>(Scalar)) {
+      Bits = CI->getAPIntValue().zextOrTrunc(EltBits);
+      return true;
+    }
+    return false;
+  }
+  case AArch64ISD::MOVIshift: {
+    // Every lane holds Imm8 << Shift whatever the lane type, so the shift
+    // must be applied for f32 lanes too (e.g. 0x3f << 24 is 0.5f).
+    auto *Imm = dyn_cast<ConstantSDNode>(Src.getOperand(0));
+    auto *Shift = dyn_cast<ConstantSDNode>(Src.getOperand(1));
+    if (!Imm || !Shift || Shift->getZExtValue() >= EltBits)
+      return false;
+    Bits = Imm->getAPIntValue().zextOrTrunc(EltBits).shl(
+        Shift->getZExtValue());
+    return true;
+  }
+  default:
+    return false;
+  }
+}
+
+// Checks whether N is a vector splat of a floating-point constant C such that
+// C (or 1/C when isReciprocal is set) is 2^FBits with 1 <= FBits <= FloatWidth
+// (callers pass the vector lane width as FloatWidth). On success, stores FBits
+// in FixedPos as an i32 target constant and returns true.
+static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
+                                                         SDValue N,
+                                                         SDValue &FixedPos,
+                                                         unsigned FloatWidth,
+                                                         bool isReciprocal) {
+  APInt Bits;
+  if (!getVectorFPSplatBits(N, Bits))
+    return false;
+
+  // Decode the lane bits with the element's floating-point semantics.
+  EVT EltVT = N.getValueType().getVectorElementType();
+  const fltSemantics *Sem;
+  if (EltVT == MVT::f16)
+    Sem = &APFloat::IEEEhalf();
+  else if (EltVT == MVT::f32)
+    Sem = &APFloat::IEEEsingle();
+  else if (EltVT == MVT::f64)
+    Sem = &APFloat::IEEEdouble();
+  else
+    return false; // Unsupported element type (e.g. bf16).
+  APFloat FVal(*Sem, Bits);
+
+  // Widen f16 to f32 before inverting: 2^16 (the reciprocal of the largest
+  // valid f16 scale, 1/2^16) overflows half precision.
+  if (EltVT == MVT::f16) {
+    bool Ignored;
+    FVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                 &Ignored);
+  }
+
+  // Reject C unless it has an exact inverse (i.e. is a power of two).
+  if (isReciprocal && !FVal.getExactInverse(&FVal))
+    return false;
+
+  // One extra bit so that 2^EltBits still fits; IntVal is unsigned, so
+  // negative constants fail the exact conversion and are rejected.
+  bool IsExact;
+  APSInt IntVal(Bits.getBitWidth() + 1, /*isUnsigned=*/true);
+  FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
+  if (!IsExact || !IntVal.isPowerOf2())
+    return false;
+
+  // FBits == 0 is a plain conversion, and more than FloatWidth fractional
+  // bits cannot be encoded in the SCVTF fixed-point immediate.
+  unsigned FBits = IntVal.logBase2();
+  if (FBits == 0 || FBits > FloatWidth)
+    return false;
+
+  // The fixed-point position is passed on as an i32 immediate.
+  FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
+  return true;
+}
+
 bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
                                                    unsigned RegWidth) {
   return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4096,12 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
                                             true);
 }
 
+bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(
+    SDValue N, SDValue &FixedPos, unsigned FloatWidth) {
+  return checkCVTFixedPointOperandWithFBitsForVectors(
+      CurDAG, N, FixedPos, FloatWidth, /*isReciprocal=*/true);
+}
+
 // Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
 // of the string and obtains the integer values from them and combines these
 // into a single value to be used in the MRS/MSR instruction.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index f90f12b5ac3c7..5613128d0e9fd 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -8473,6 +8473,58 @@ def : Pat<(v8f16 (sint_to_fp (v8i16 (AArch64vashr_exact v8i16:$Vn, i32:$shift)))
           (SCVTFv8i16_shift $Vn, vecshiftR16:$shift)>;
 }
 
+// Select fmul(sint_to_fp(X), C) on NEON vectors, where C is a splat constant
+// equal to 1/2^N, as SCVTF<vec>_shift X, #N (a conversion with N fractional
+// bits). The ComplexPatterns compute N, capped at the lane width (16/32/64).
+class fixedpoint_recip_vec_i16<ValueType FloatVT>
+    : ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<16>", []>;
+class fixedpoint_recip_vec_i32<ValueType FloatVT>
+    : ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<32>", []>;
+class fixedpoint_recip_vec_i64<ValueType FloatVT>
+    : ComplexPattern<FloatVT, 1, "SelectCVTFixedPosRecipOperandVec<64>", []>;
+def fixedpoint_recip_vec_xform : SDNodeXForm<timm, [{
+  // Identity transform: returns V (the matched fbits operand) unchanged.
+  // NOTE(review): presumably only needed for TableGen operand typing; confirm.
+  (void)N;
+  return V;
+}]>;
+
+def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
+def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
+def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;
+
+def fixedpoint_recip_v4f16_v4i16 : fixedpoint_recip_vec_i16<v4f16>;
+def fixedpoint_recip_v8f16_v8i16 : fixedpoint_recip_vec_i16<v8f16>;
+
+let Predicates = [HasNEON] in {
+  def : Pat<(v2f32(fmul(sint_to_fp(v2i32 V64:$Rn)),
+                fixedpoint_recip_v2f32_v2i32:$scale)),
+            (v2f32(SCVTFv2i32_shift(v2i32 V64:$Rn),
+                (fixedpoint_recip_vec_xform fixedpoint_recip_v2f32_v2i32:$scale)))>;
+
+  def : Pat<(v4f32(fmul(sint_to_fp(v4i32 FPR128:$Rn)),
+                fixedpoint_recip_v4f32_v4i32:$scale)),
+            (v4f32(SCVTFv4i32_shift(v4i32 FPR128:$Rn),
+                (fixedpoint_recip_vec_xform fixedpoint_recip_v4f32_v4i32:$scale)))>;
+
+  def : Pat<(v2f64(fmul(sint_to_fp(v2i64 FPR128:$Rn)),
+                fixedpoint_recip_v2f64_v2i64:$scale)),
+            (v2f64(SCVTFv2i64_shift(v2i64 FPR128:$Rn),
+                (fixedpoint_recip_vec_xform fixedpoint_recip_v2f64_v2i64:$scale)))>;
+}
+
+let Predicates = [HasNEON, HasFullFP16] in {
+  def : Pat<(v4f16(fmul(sint_to_fp(v4i16 V64:$Rn)),
+                fixedpoint_recip_v4f16_v4i16:$scale)),
+            (v4f16(SCVTFv4i16_shift(v4i16 V64:$Rn),
+                (fixedpoint_recip_vec_xform fixedpoint_recip_v4f16_v4i16:$scale)))>;
+
+  def : Pat<(v8f16(fmul(sint_to_fp(v8i16 FPR128:$Rn)),
+                fixedpoint_recip_v8f16_v8i16:$scale)),
+            (v8f16(SCVTFv8i16_shift(v8i16 FPR128:$Rn),
+                (fixedpoint_recip_vec_xform fixedpoint_recip_v8f16_v8i16:$scale)))>;
+}
+
 // X << 1 ==> X + X
 class SHLToADDPat<ValueType ty, RegisterClass regtype>
   : Pat<(ty (AArch64vshl (ty regtype:$Rn), (i32 1))),
diff --git a/llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll b/llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..c7e306aa69766
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,96 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -mattr=+fullfp16 -o - %s | FileCheck %s
+
+; This test file verifies that fdiv(sitofp(X), C), where C is a constant
+; power of two 2^N (splatted for vectors), is selected as scvtf(X, #N).
+; The fdiv is typically first rewritten as an fmul by the exact reciprocal
+; 1/2^N, which is the form the fixed-point conversion patterns match.
+
+; --- Scalar Tests ---
+
+; Scalar f32 from i32: fdiv by 16.0 (2^4) should select scvtf with #4.
+define float @test_f32_div_const(i32 %in) {
+; CHECK-LABEL: test_f32_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf s0, w0, #4
+; CHECK-NEXT:  ret
+entry:
+  %conv = sitofp i32 %in to float
+  %quot = fdiv float %conv, 16.0
+  ret float %quot
+}
+
+; Scalar f64 from i64: fdiv by 16.0 (2^4) should select scvtf with #4.
+define double @test_f64_div_const(i64 %in) {
+; CHECK-LABEL: test_f64_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf d0, x0, #4
+; CHECK-NEXT:  ret
+entry:
+  %conv = sitofp i64 %in to double
+  %quot = fdiv double %conv, 16.0
+  ret double %quot
+}
+
+; --- Vector Tests ---
+
+; v2f32 from v2i32: splat fdiv by 16.0 should select scvtf.2s with #4.
+define <2 x float> @testv_v2f32_div_const(<2 x i32> %in) {
+; CHECK-LABEL: testv_v2f32_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.2s v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %conv = sitofp <2 x i32> %in to <2 x float>
+  %quot = fdiv <2 x float> %conv, <float 16.0, float 16.0>
+  ret <2 x float> %quot
+}
+
+; v4f32 from v4i32: splat fdiv by 16.0 should select scvtf.4s with #4.
+define <4 x float> @testv_v4f32_div_const(<4 x i32> %in) {
+; CHECK-LABEL: testv_v4f32_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.4s v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %conv = sitofp <4 x i32> %in to <4 x float>
+  %quot = fdiv <4 x float> %conv, <float 16.0, float 16.0, float 16.0, float 16.0>
+  ret <4 x float> %quot
+}
+
+; v2f64 from v2i64: splat fdiv by 16.0 should select scvtf.2d with #4.
+define <2 x double> @testv_v2f64_div_const(<2 x i64> %in) {
+; CHECK-LABEL: testv_v2f64_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.2d v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %conv = sitofp <2 x i64> %in to <2 x double>
+  %quot = fdiv <2 x double> %conv, <double 16.0, double 16.0>
+  ret <2 x double> %quot
+}
+
+; --- f16 Tests (fullfp16 is enabled by the RUN line) ---
+
+; v4f16 from v4i16: splat fdiv by half 16.0 should select scvtf.4h with #4.
+define <4 x half> @testv_v4f16_div_const(<4 x i16> %in) {
+; CHECK-LABEL: testv_v4f16_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.4h v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %conv = sitofp <4 x i16> %in to <4 x half>
+  %quot = fdiv <4 x half> %conv, <half 16.0, half 16.0, half 16.0, half 16.0>
+  ret <4 x half> %quot
+}
+
+; v8f16 from v8i16: splat fdiv by half 16.0 should select scvtf.8h with #4.
+define <8 x half> @testv_v8f16_div_const(<8 x i16> %in) {
+; CHECK-LABEL: testv_v8f16_div_const:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.8h v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp <8 x i16> %in to <8 x half>
+  %div.i = fdiv <8 x half> %vcvt.i, <half 16.0, half 16.0, half 16.0, half 16.0, half 16.0, half 16.0, half 16.0, half 16.0> ; 16.0 in half-precision
+  ret <8 x half> %div.i
+}



More information about the llvm-commits mailing list