[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)

JP Hafer via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 10 10:57:11 PDT 2025


https://github.com/jph-13 updated https://github.com/llvm/llvm-project/pull/141480

>From ef6dbc60db790df522ceb6201dd92701d85c79f2 Mon Sep 17 00:00:00 2001
From: JP Hafer <jhafer at mathworks.com>
Date: Tue, 10 Jun 2025 13:53:00 -0400
Subject: [PATCH] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) ->
 scvtf(x, 2)

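This commit reintroduces an optimization that was previously removed due to limited applicability. It is implemented in AArch64 instruction selection (AArch64ISelDAGToDAG.cpp and AArch64InstrFormats.td), not in InstCombine.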
See: #91924

This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two (for vectors, a splat of such a constant). For both scalar and vector inputs, `sitofp(X) * C` with `C = 1/2^N` can be selected as `scvtf(X, #N)`, a fixed-point conversion with N fractional bits. This eliminates the floating-point multiply by converting the integer directly to a scaled floating-point value.
---
 .../Target/AArch64/AArch64ISelDAGToDAG.cpp    | 165 ++++++++++++++++++
 .../lib/Target/AArch64/AArch64InstrFormats.td | 122 +++++++++++++
 .../AArch64/scvtf-div-mul-combine.ll          |  71 ++++++++
 3 files changed, 358 insertions(+)
 create mode 100644 llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 11cb91fbe02d4..c9401eba6cd4f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
   bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
                                      unsigned Width);
 
+  // Vector variant; TableGen calls e.g. SelectCVTFixedPosRecipOperandVec<32>.
+  bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
+                                        unsigned Width);
+  template <unsigned Width>
+  bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
+    return SelectCVTFixedPosRecipOperandVec(N, FixedPos, Width);
+  }
+
   bool SelectCMP_SWAP(SDNode *N);
 
   bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,140 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
   return true;
 }
 
+/// Checks whether the splatted FP constant \p FVal (or, when \p IsReciprocal
+/// is set, its exact inverse) equals 2^FBits with
+/// 1 <= FBits <= min(RegWidth, bit width of \p FloatEltVT). On success, stores
+/// FBits in \p FixedPos as an i32 target constant and returns true.
+static bool setFixedPosFromSplatFPValue(SelectionDAG *CurDAG, const SDLoc &DL,
+                                        APFloat FVal, EVT FloatEltVT,
+                                        unsigned RegWidth, bool IsReciprocal,
+                                        SDValue &FixedPos) {
+  // Normalize f16 constants to f32 before the inverse/integer checks.
+  if (FloatEltVT == MVT::f16) {
+    bool Ignored;
+    FVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &Ignored);
+  }
+
+  if (IsReciprocal && !FVal.getExactInverse(&FVal))
+    return false;
+
+  // RegWidth + 1 bits so that 2^RegWidth (FBits == RegWidth) still fits.
+  bool IsExact;
+  APSInt IntVal(RegWidth + 1, true);
+  FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
+  if (!IsExact || !IntVal.isPowerOf2())
+    return false;
+
+  // The integer and FP lanes of a vector SCVTF have the same width, so FBits
+  // is bounded by the lane width as well as by RegWidth.
+  unsigned EltBits = FloatEltVT.getScalarSizeInBits();
+  unsigned MaxFBits = RegWidth < EltBits ? RegWidth : EltBits;
+  unsigned FBits = IntVal.logBase2();
+  if (FBits == 0 || FBits > MaxFBits)
+    return false;
+
+  FixedPos = CurDAG->getTargetConstant(FBits, DL, MVT::i32);
+  return true;
+}
+
+/// Vector counterpart of checkCVTFixedPointOperandWithFBits. Matches \p N if
+/// it is a splat of a floating-point constant C, given either as
+///   - a BUILD_VECTOR whose operands are all the same ConstantFP, or
+///   - a BITCAST of a splat (AArch64ISD::DUP, SPLAT_VECTOR or a constant-splat
+///     BUILD_VECTOR) whose lane width equals the FP lane width,
+/// and C (or 1/C when \p isReciprocal is set) is 2^FBits; FBits is returned in
+/// \p FixedPos.
+static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
+                                                         SDValue N,
+                                                         SDValue &FixedPos,
+                                                         unsigned RegWidth,
+                                                         bool isReciprocal) {
+  EVT VT = N.getValueType();
+  if (!VT.isVector() || !VT.isFloatingPoint())
+    return false;
+  EVT FloatEltVT = VT.getVectorElementType();
+  SDLoc DL(N);
+
+  // Fast path: build_vector <float C, float C, ...>.
+  if (N.getOpcode() == ISD::BUILD_VECTOR) {
+    auto *First = dyn_cast<ConstantFPSDNode>(N.getOperand(0));
+    if (!First)
+      return false;
+    const APFloat &FVal = First->getValueAPF();
+    for (unsigned I = 1, E = N.getNumOperands(); I != E; ++I) {
+      auto *CFP = dyn_cast<ConstantFPSDNode>(N.getOperand(I));
+      if (!CFP || !CFP->isExactlyValue(FVal))
+        return false;
+    }
+    return setFixedPosFromSplatFPValue(CurDAG, DL, FVal, FloatEltVT, RegWidth,
+                                       isReciprocal, FixedPos);
+  }
+
+  // Otherwise N must be a BITCAST of a splat. Requiring equal lane widths
+  // guarantees that every FP lane receives the same bit pattern.
+  if (N.getOpcode() != ISD::BITCAST)
+    return false;
+  SDValue SplatVec = N.getOperand(0);
+  EVT SplatVT = SplatVec.getValueType();
+  unsigned EltBits = FloatEltVT.getScalarSizeInBits();
+  if (!SplatVT.isVector() || SplatVT.getScalarSizeInBits() != EltBits)
+    return false;
+
+  APInt SplatBits;
+  switch (SplatVec.getOpcode()) {
+  case AArch64ISD::DUP:
+  case ISD::SPLAT_VECTOR: {
+    // Both nodes broadcast their scalar operand to every lane.
+    SDValue Scalar = SplatVec.getOperand(0);
+    if (auto *CI = dyn_cast<ConstantSDNode>(Scalar)) {
+      // The scalar may be wider than a lane (e.g. an i32 splatted into i16
+      // lanes); only its low EltBits bits end up in each lane.
+      const APInt &ScalarBits = CI->getAPIntValue();
+      if (ScalarBits.getBitWidth() < EltBits)
+        return false;
+      SplatBits = ScalarBits.trunc(EltBits);
+    } else if (auto *CFP = dyn_cast<ConstantFPSDNode>(Scalar)) {
+      SplatBits = CFP->getValueAPF().bitcastToAPInt();
+    } else {
+      return false;
+    }
+    break;
+  }
+  case ISD::BUILD_VECTOR: {
+    // Accept only a constant splat whose repeating unit is exactly one lane.
+    auto *BVN = cast<BuildVectorSDNode>(SplatVec.getNode());
+    APInt SplatUndef;
+    unsigned SplatBitSize;
+    bool HasAnyUndefs;
+    if (!BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
+                              HasAnyUndefs, EltBits) ||
+        SplatBitSize != EltBits)
+      return false;
+    break;
+  }
+  default:
+    // Not a recognized splat-forming node.
+    return false;
+  }
+
+  if (SplatBits.getBitWidth() != EltBits)
+    return false;
+
+  // Reinterpret the lane bits as a value of the FP element type.
+  APFloat FVal(0.0);
+  if (FloatEltVT == MVT::f16)
+    FVal = APFloat(APFloat::IEEEhalf(), SplatBits);
+  else if (FloatEltVT == MVT::f32)
+    FVal = APFloat(APFloat::IEEEsingle(), SplatBits);
+  else if (FloatEltVT == MVT::f64)
+    FVal = APFloat(APFloat::IEEEdouble(), SplatBits);
+  else
+    return false;
+
+  return setFixedPosFromSplatFPValue(CurDAG, DL, FVal, FloatEltVT, RegWidth,
+                                     isReciprocal, FixedPos);
+}
+
 bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
                                                    unsigned RegWidth) {
   return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4123,13 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
                                             true);
 }
 
+// Vector counterpart of SelectCVTFixedPosRecipOperand (reciprocal matching).
+bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(
+    SDValue N, SDValue &FixedPos, unsigned RegWidth) {
+  return checkCVTFixedPointOperandWithFBitsForVectors(
+      CurDAG, N, FixedPos, RegWidth, /*isReciprocal=*/true);
+}
+
 // Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
 // of the string and obtains the integer values from them and combines these
 // into a single value to be used in the MRS/MSR instruction.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 9078748c14834..d8fa916beda6a 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -799,6 +799,22 @@ class fixedpoint_recip_i64<ValueType FloatVT>
   let DecoderMethod = "DecodeFixedPointScaleImm64";
 }
 
+// Vector analogues of fixedpoint_recip_i32/i64: match a splat of 1/2^FBits.
+class fixedpoint_recip_vec_i32<ValueType VecFloatVT>
+    : Operand<VecFloatVT>,
+      ComplexPattern<VecFloatVT, 1,
+                     "SelectCVTFixedPosRecipOperandVec<32>", [build_vector]> {
+  let EncoderMethod = "getFixedPointScaleOpValue";
+  let DecoderMethod = "DecodeFixedPointScaleImm32";
+}
+class fixedpoint_recip_vec_i64<ValueType VecFloatVT>
+    : Operand<VecFloatVT>,
+      ComplexPattern<VecFloatVT, 1,
+                     "SelectCVTFixedPosRecipOperandVec<64>", [build_vector]> {
+  let EncoderMethod = "getFixedPointScaleOpValue";
+  let DecoderMethod = "DecodeFixedPointScaleImm64"; // as fixedpoint_recip_i64
+}
+
 def fixedpoint_recip_f16_i32 : fixedpoint_recip_i32<f16>;
 def fixedpoint_recip_f32_i32 : fixedpoint_recip_i32<f32>;
 def fixedpoint_recip_f64_i32 : fixedpoint_recip_i32<f64>;
@@ -807,6 +823,16 @@ def fixedpoint_recip_f16_i64 : fixedpoint_recip_i64<f16>;
 def fixedpoint_recip_f32_i64 : fixedpoint_recip_i64<f32>;
 def fixedpoint_recip_f64_i64 : fixedpoint_recip_i64<f64>;
 
+// Vector SCVTF (fixed-point) converts integer lanes to FP lanes of the same
+// width, so only same-width pairs are defined. Half-precision variants would
+// pair v4f16/v8f16 with v4i16/v8i16 rather than with 32-bit integer lanes.
+
+// 32-bit lanes.
+def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
+def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
+// 64-bit lanes.
+def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;
+
 def vecshiftR8 : Operand<i32>, ImmLeaf<i32, [{
   return (((uint32_t)Imm) > 0) && (((uint32_t)Imm) < 9);
 }]> {
@@ -5407,6 +5433,102 @@ class BaseIntegerToFPUnscaled<bits<2> rmode, bits<3> opcode,
   let Inst{4-0}   = Rd;
 }
 
+multiclass IntegerToFPVector<
+    bits<2> rmode, bits<3> opcode, string asm, RegisterClass srcRegClass,
+    RegisterClass dstRegClass, Operand imm_op, bits<1> q, bits<2> size,
+    bits<2> srcElemTypeBits, list<Predicate> preds> {
+  // Defines <NAME>_V by reusing BaseIntegerToFP and overriding encoding bits.
+  def _V : BaseIntegerToFP<rmode, opcode, srcRegClass, dstRegClass, imm_op,
+                           asm, []> {
+    let Inst{30} = q;                   // 0 for 64-bit, 1 for 128-bit defms
+    let Inst{23 -22} = size;            // 0b01 for f32, 0b10 for f64 defms
+    let Inst{18 -16} = 0b001;           // NOTE(review): clobbers opcode?
+    let Inst{11 -10} = srcElemTypeBits; // NOTE(review): clobbers scale?
+    let Predicates = preds;
+  }
+}
+// NOTE(review): confirm the two flagged lets above against BaseIntegerToFP.
+// SCVTF (Signed Convert To Floating-Point) from Vector 32-bit Integer (vNi32)
+// defm SCVTFv2f16_v2i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
+//                                     FPR64, FPR64,
+//                                     fixedpoint_recip_v2f16_v2i32,
+//                                     0, 0b00, 0b10, [HasFullFP16]>;
+
+// defm SCVTFv4f16_v4i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
+//                                     FPR128, FPR128,
+//                                     fixedpoint_recip_v4f16_v4i32,
+//                                     1, 0b00, 0b10, [HasFullFP16]>;
+
+// defm SCVTFv8f16_v8i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
+//                                     FPR128, FPR128,
+//                                     fixedpoint_recip_v8f16_v8i32,
+//                                     1, 0b00, 0b10, [HasFullFP16]>;
+
+defm SCVTFv2f32_v2i32
+    : IntegerToFPVector<0b00, 0b010, "scvtf", FPR64, FPR64,
+                        fixedpoint_recip_v2f32_v2i32, 0, 0b01, 0b10, []>;
+
+defm SCVTFv4f32_v4i32
+    : IntegerToFPVector<0b00, 0b010, "scvtf", FPR128, FPR128,
+                        fixedpoint_recip_v4f32_v4i32, 1, 0b01, 0b10, []>;
+
+// SCVTF (Signed Convert To Floating-Point) from Vector 64-bit Integer (vNi64)
+// defm SCVTFv2f16_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
+//                                     FPR128, FPR128,
+//                                     fixedpoint_recip_v2f16_v2i64,
+//                                     1, 0b00, 0b11, [HasFullFP16]>;
+
+// defm SCVTFv2f32_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
+//                                     FPR128, FPR128,
+//                                     fixedpoint_recip_v2f32_v2i64,
+//                                     1, 0b01, 0b11, []>;
+
+defm SCVTFv2f64_v2i64
+    : IntegerToFPVector<0b00, 0b010, "scvtf", FPR128, FPR128,
+                        fixedpoint_recip_v2f64_v2i64, 1, 0b10, 0b11, []>;
+// NOTE(review): per the test, *_V prints as "scvtf d0, d0, #4"; confirm vector form.
+// def : Pat<
+//   (fmul (sint_to_fp (v2i32 V64:$Rn)),
+//         fixedpoint_recip_v2f32_v2i32:$scale),
+//   (SCVTFv2f16_v2i32_V V64:$Rn, fixedpoint_recip_v2f32_v2i32:$scale)
+// >;
+
+// def : Pat<
+//   (fmul (sint_to_fp (v4i32 FPR128:$Rn)),
+//         fixedpoint_recip_v4f16_v4i32:$scale),
+//   (SCVTFv4f16_v4i32_V FPR128:$Rn, fixedpoint_recip_v4f16_v4i32:$scale)
+// >;
+
+// def : Pat<
+//   (fmul (sint_to_fp (v8i32 FPR128:$Rn)),
+//         fixedpoint_recip_v8f16_v8i32:$scale),
+//   (SCVTFv8f16_v8i32_V FPR128:$Rn, fixedpoint_recip_v8f16_v8i32:$scale)
+// >;
+// NOTE(review): V64 below, but SCVTFv2f32_v2i32_V was defined with FPR64.
+def : Pat<(fmul(sint_to_fp(v2i32 V64:$Rn)),
+              fixedpoint_recip_v2f32_v2i32:$scale),
+          (SCVTFv2f32_v2i32_V V64:$Rn, fixedpoint_recip_v2f32_v2i32:$scale)>;
+
+def : Pat<(fmul(sint_to_fp(v4i32 FPR128:$Rn)),
+              fixedpoint_recip_v4f32_v4i32:$scale),
+          (SCVTFv4f32_v4i32_V FPR128:$Rn, fixedpoint_recip_v4f32_v4i32:$scale)>;
+
+// def : Pat<
+//   (fmul (sint_to_fp (v2i64 FPR128:$Rn)),
+//         fixedpoint_recip_v2f16_v2i64:$scale),
+//   (SCVTFv2f16_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f16_v2i64:$scale)
+// >;
+
+// def : Pat<
+//   (fmul (sint_to_fp (v2i64 FPR128:$Rn)),
+//         fixedpoint_recip_v2f32_v2i64:$scale),
+//   (SCVTFv2f32_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f32_v2i64:$scale)
+// >;
+
+def : Pat<(fmul(sint_to_fp(v2i64 FPR128:$Rn)),
+              fixedpoint_recip_v2f64_v2i64:$scale),
+          (SCVTFv2f64_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f64_v2i64:$scale)>;
+
 multiclass IntegerToFP<bits<2> rmode, bits<3> opcode, string asm, SDPatternOperator node> {
   // Unscaled
   def UWHri: BaseIntegerToFPUnscaled<rmode, opcode, GPR32, FPR16, f16, asm, node> {
diff --git a/llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll b/llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..b31831559fb51
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,83 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -mattr=+fullfp16 -o - %s | FileCheck %s
+
+; This test verifies that fdiv(sitofp(x), 2^N) and fmul(sitofp(x), 1/2^N)
+; select a single fixed-point conversion, scvtf(x, #N), for scalar and
+; vector types (vector constants are splats).
+
+; --- Scalar Tests ---
+
+; Scalar f32 (from i32)
+define float @test_f32_div(i32 %in) {
+; CHECK-LABEL: test_f32_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf s0, w0, #4
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp i32 %in to float
+  %div.i = fdiv float %vcvt.i, 16.0
+  ret float %div.i
+}
+
+; Scalar f64 (from i64)
+define double @test_f64_div(i64 %in) {
+; CHECK-LABEL: test_f64_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf d0, x0, #4
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp i64 %in to double
+  %div.i = fdiv double %vcvt.i, 16.0
+  ret double %div.i
+}
+
+; --- Multi-Element Vector F32 Tests ---
+
+; Vector v2f32 (from v2i32)
+define <2 x float> @testv_v2f32_div(<2 x i32> %in) {
+; CHECK-LABEL: testv_v2f32_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.2s v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+  %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+  ret <2 x float> %div.i
+}
+
+; Vector v4f32 (from v4i32)
+define <4 x float> @testv_v4f32_div(<4 x i32> %in) {
+; CHECK-LABEL: testv_v4f32_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.4s v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp <4 x i32> %in to <4 x float>
+  %div.i = fdiv <4 x float> %vcvt.i, <float 16.0, float 16.0, float 16.0, float 16.0>
+  ret <4 x float> %div.i
+}
+
+; Vector v4f32 (from v4i32), expressed directly as an fmul by 0.5
+define <4 x float> @testv_v4f32_mul(<4 x i32> %in) {
+; CHECK-LABEL: testv_v4f32_mul:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.4s v0, v0, #1
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp <4 x i32> %in to <4 x float>
+  %mul.i = fmul <4 x float> %vcvt.i, <float 0.5, float 0.5, float 0.5, float 0.5>
+  ret <4 x float> %mul.i
+}
+
+; --- Multi-Element Vector F64 Tests ---
+
+; Vector v2f64 (from v2i64)
+define <2 x double> @testv_v2f64_div(<2 x i64> %in) {
+; CHECK-LABEL: testv_v2f64_div:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:  scvtf.2d v0, v0, #4
+; CHECK-NEXT:  ret
+entry:
+  %vcvt.i = sitofp <2 x i64> %in to <2 x double>
+  %div.i = fdiv <2 x double> %vcvt.i, <double 16.0, double 16.0>
+  ret <2 x double> %div.i
+}



More information about the llvm-commits mailing list