[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)
JP Hafer via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 16 07:24:26 PDT 2025
https://github.com/jph-13 updated https://github.com/llvm/llvm-project/pull/141480
>From ef6dbc60db790df522ceb6201dd92701d85c79f2 Mon Sep 17 00:00:00 2001
From: JP Hafer <jhafer at mathworks.com>
Date: Tue, 10 Jun 2025 13:53:00 -0400
Subject: [PATCH] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) ->
scvtf(x, 2)
This commit reintroduces the optimization in InstCombine that was previously removed due to limited applicability.
See: #91924
This update targets `fmul(sitofp(x), C)` where `C` is a constant reciprocal of a power of two. For both scalar and vector inputs, if we have `sitofp(X) * C` (where `C` is `1/2^N`), this can be optimized to `scvtf(X, #N)`, i.e. a fixed-point convert with N fractional bits (for example, `sitofp(X) * 1/16` becomes `scvtf X, #4`). This eliminates the floating-point multiply by directly converting the integer to a scaled floating-point value.
---
.../Target/AArch64/AArch64ISelDAGToDAG.cpp | 165 ++++++++++++++++++
.../lib/Target/AArch64/AArch64InstrFormats.td | 122 +++++++++++++
.../AArch64/scvtf-div-mul-combine.ll | 71 ++++++++
3 files changed, 358 insertions(+)
create mode 100644 llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index 11cb91fbe02d4..c9401eba6cd4f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -487,6 +487,14 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
bool SelectCVTFixedPosRecipOperand(SDValue N, SDValue &FixedPos,
unsigned Width);
+ // Template entry point for the vector reciprocal fixed-point operand; the
+ // TableGen ComplexPatterns name it as
+ // "SelectCVTFixedPosRecipOperandVec<32>" / "<64>". It simply forwards
+ // RegWidth to the non-template overload declared below.
+ template <unsigned RegWidth>
+ bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos) {
+ return SelectCVTFixedPosRecipOperandVec(N, FixedPos, RegWidth);
+ }
+
+ // Out-of-line overload taking the width as a runtime argument. Returns true
+ // and sets FixedPos when N is accepted as a vector splat constant usable as
+ // a reciprocal fixed-point scale; returns false otherwise.
+ bool SelectCVTFixedPosRecipOperandVec(SDValue N, SDValue &FixedPos,
+ unsigned Width);
+
bool SelectCMP_SWAP(SDNode *N);
bool SelectSVEAddSubImm(SDValue N, MVT VT, SDValue &Imm, SDValue &Shift);
@@ -3952,6 +3960,156 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
return true;
}
+/// Shared tail of checkCVTFixedPointOperandWithFBitsForVectors: decides whether
+/// the splatted value \p FVal is usable as a fixed-point scale and, if so,
+/// stores the number of fractional bits in \p FixedPos.
+///
+/// \p N is the operand being matched; it supplies the element type (f16 values
+/// are widened to single precision before any arithmetic) and the debug
+/// location of the resulting constant. When \p isReciprocal is set, \p FVal is
+/// first replaced by its exact inverse. The (possibly inverted) value must
+/// convert exactly to a power of two 2^FBits with 1 <= FBits <= RegWidth.
+static bool getFBitsFromSplatFPValue(SelectionDAG *CurDAG, SDValue N,
+                                     APFloat FVal, SDValue &FixedPos,
+                                     unsigned RegWidth, bool isReciprocal) {
+  // Normalize f16 to f32 so the APFloat operations below behave consistently.
+  if (N.getValueType().getVectorElementType() == MVT::f16) {
+    bool LosesInfo;
+    FVal.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
+                 &LosesInfo);
+  }
+
+  // For the reciprocal form the constant is 1/2^N; recover 2^N exactly or
+  // give up.
+  if (isReciprocal && !FVal.getExactInverse(&FVal))
+    return false;
+
+  // RegWidth + 1 bits so that 2^RegWidth itself is still representable.
+  bool IsExact;
+  APSInt IntVal(RegWidth + 1, /*isUnsigned=*/true);
+  FVal.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact);
+  if (!IsExact || !IntVal.isPowerOf2())
+    return false;
+
+  // A scale of 2^0 is a plain conversion, and anything above RegWidth is
+  // rejected as out of range.
+  unsigned FBits = IntVal.logBase2();
+  if (FBits == 0 || FBits > RegWidth)
+    return false;
+
+  // The scale operand of the conversion is an i32 immediate.
+  FixedPos = CurDAG->getTargetConstant(FBits, SDLoc(N), MVT::i32);
+  return true;
+}
+
+/// Vector analogue of checkCVTFixedPointOperandWithFBits.
+///
+/// Recognizes \p N as a splat of a floating-point constant in one of two
+/// shapes:
+///   * build_vector <C, C, ...> of identical ConstantFP operands, or
+///   * bitcast to an FP vector of an integer splat produced by AArch64ISD::DUP,
+///     ISD::SPLAT_VECTOR or a constant-splat ISD::BUILD_VECTOR.
+///
+/// On success \p FixedPos receives the fractional-bit count as an i32 target
+/// constant and the function returns true. \p RegWidth bounds the accepted
+/// bit count; \p isReciprocal selects matching 1/2^N instead of 2^N.
+static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
+                                                         SDValue N,
+                                                         SDValue &FixedPos,
+                                                         unsigned RegWidth,
+                                                         bool isReciprocal) {
+  // Fast path: build_vector <float C, float C, ...>.
+  if (N.getOpcode() == ISD::BUILD_VECTOR) {
+    auto *First = dyn_cast<ConstantFPSDNode>(N.getOperand(0));
+    if (!First)
+      return false;
+
+    const APFloat &SplatVal = First->getValueAPF();
+    for (unsigned I = 1, E = N.getNumOperands(); I != E; ++I) {
+      auto *CFP = dyn_cast<ConstantFPSDNode>(N.getOperand(I));
+      if (!CFP || !CFP->isExactlyValue(SplatVal))
+        return false;
+    }
+
+    return getFBitsFromSplatFPValue(CurDAG, N, SplatVal, FixedPos, RegWidth,
+                                    isReciprocal);
+  }
+
+  // Otherwise N must be a bitcast producing a floating-point vector.
+  EVT VT = N.getValueType();
+  if (N.getOpcode() != ISD::BITCAST || !VT.isVector() || !VT.isFloatingPoint())
+    return false;
+
+  SDValue VectorIntNode = N.getOperand(0);
+
+  // The splatted scalar is reinterpreted as one FP element below, which is
+  // only meaningful when source and destination elements have the same size.
+  if (VectorIntNode.getValueType().getScalarSizeInBits() !=
+      VT.getScalarSizeInBits())
+    return false;
+
+  // The source of the bitcast must be a splat-forming operation.
+  SDValue ScalarSourceNode;
+  switch (VectorIntNode.getOpcode()) {
+  case AArch64ISD::DUP:
+  case ISD::SPLAT_VECTOR:
+    // Both inherently splat their scalar operand.
+    ScalarSourceNode = VectorIntNode.getOperand(0);
+    break;
+  case ISD::BUILD_VECTOR: {
+    // A BUILD_VECTOR only qualifies if it is a constant splat.
+    auto *BVN = cast<BuildVectorSDNode>(VectorIntNode.getNode());
+    APInt SplatValue;
+    APInt SplatUndef;
+    unsigned SplatBitSize;
+    bool HasAnyUndefs;
+    if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                              HasAnyUndefs))
+      return false; // BUILD_VECTOR was not a splat
+    ScalarSourceNode = VectorIntNode.getOperand(0);
+    break;
+  }
+  default:
+    // The node below the bitcast is not a recognized splat-forming node.
+    return false;
+  }
+
+  // ScalarSourceNode must be a constant (ISD::Constant or ISD::ConstantFP).
+  APFloat FVal(0.0);
+  if (auto *CFP = dyn_cast<ConstantFPSDNode>(ScalarSourceNode)) {
+    FVal = CFP->getValueAPF();
+  } else if (auto *CI = dyn_cast<ConstantSDNode>(ScalarSourceNode)) {
+    // Reinterpret the integer's bits as a value of the FP element type. The
+    // scalar operand may be wider than the vector element, so resize it to
+    // the element width; APFloat requires an exact width match.
+    EVT FloatEltVT = VT.getVectorElementType();
+    APInt Bits = CI->getAPIntValue().zextOrTrunc(VT.getScalarSizeInBits());
+
+    if (FloatEltVT == MVT::f32)
+      FVal = APFloat(APFloat::IEEEsingle(), Bits);
+    else if (FloatEltVT == MVT::f64)
+      FVal = APFloat(APFloat::IEEEdouble(), Bits);
+    else if (FloatEltVT == MVT::f16)
+      FVal = APFloat(APFloat::IEEEhalf(), Bits);
+    else
+      return false;
+  } else {
+    return false;
+  }
+
+  return getFBitsFromSplatFPValue(CurDAG, N, FVal, FixedPos, RegWidth,
+                                  isReciprocal);
+}
+
bool AArch64DAGToDAGISel::SelectCVTFixedPosOperand(SDValue N, SDValue &FixedPos,
unsigned RegWidth) {
return checkCVTFixedPointOperandWithFBits(CurDAG, N, FixedPos, RegWidth,
@@ -3965,6 +4123,13 @@ bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperand(SDValue N,
true);
}
+// Vector counterpart of SelectCVTFixedPosRecipOperand: delegates to
+// checkCVTFixedPointOperandWithFBitsForVectors with isReciprocal set to true,
+// so the matched constant is inverted before the power-of-two check.
+bool AArch64DAGToDAGISel::SelectCVTFixedPosRecipOperandVec(SDValue N,
+ SDValue &FixedPos,
+ unsigned RegWidth) {
+ return checkCVTFixedPointOperandWithFBitsForVectors(CurDAG, N, FixedPos,
+ RegWidth, true);
+}
+
// Inspects a register string of the form o0:op1:CRn:CRm:op2 gets the fields
// of the string and obtains the integer values from them and combines these
// into a single value to be used in the MRS/MSR instruction.
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 9078748c14834..d8fa916beda6a 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -799,6 +799,22 @@ class fixedpoint_recip_i64<ValueType FloatVT>
let DecoderMethod = "DecodeFixedPointScaleImm64";
}
+// Operand + ComplexPattern for a vector floating-point scale constant matched
+// by SelectCVTFixedPosRecipOperandVec<32> (width argument 32). VecFloatVT is
+// the floating-point vector type of the operand being matched.
+class fixedpoint_recip_vec_i32<ValueType VecFloatVT>
+ : Operand<VecFloatVT>,
+ ComplexPattern<VecFloatVT, 1,
+ "SelectCVTFixedPosRecipOperandVec<32>", [build_vector]> {
+ let EncoderMethod = "getFixedPointScaleOpValue";
+ let DecoderMethod = "DecodeFixedPointScaleImm32";
+}
+
+// Operand + ComplexPattern for a vector floating-point scale constant matched
+// by SelectCVTFixedPosRecipOperandVec<64> (width argument 64). VecFloatVT is
+// the floating-point vector type of the operand being matched.
+//
+// The decoder is the 64-bit variant, mirroring the scalar
+// fixedpoint_recip_i64 class; the 32-bit decoder would not cover the full
+// scale range accepted for a width of 64.
+class fixedpoint_recip_vec_i64<ValueType VecFloatVT>
+    : Operand<VecFloatVT>,
+      ComplexPattern<VecFloatVT, 1,
+                     "SelectCVTFixedPosRecipOperandVec<64>", [build_vector]> {
+  let EncoderMethod = "getFixedPointScaleOpValue";
+  let DecoderMethod = "DecodeFixedPointScaleImm64";
+}
+
def fixedpoint_recip_f16_i32 : fixedpoint_recip_i32<f16>;
def fixedpoint_recip_f32_i32 : fixedpoint_recip_i32<f32>;
def fixedpoint_recip_f64_i32 : fixedpoint_recip_i32<f64>;
@@ -807,6 +823,16 @@ def fixedpoint_recip_f16_i64 : fixedpoint_recip_i64<f16>;
def fixedpoint_recip_f32_i64 : fixedpoint_recip_i64<f32>;
def fixedpoint_recip_f64_i64 : fixedpoint_recip_i64<f64>;
+def fixedpoint_recip_v2f16_v2i32 : fixedpoint_recip_vec_i32<v2f16>;
+def fixedpoint_recip_v4f16_v4i32 : fixedpoint_recip_vec_i32<v4f16>;
+def fixedpoint_recip_v8f16_v8i32 : fixedpoint_recip_vec_i32<v8f16>;
+def fixedpoint_recip_v2f32_v2i32 : fixedpoint_recip_vec_i32<v2f32>;
+def fixedpoint_recip_v4f32_v4i32 : fixedpoint_recip_vec_i32<v4f32>;
+
+def fixedpoint_recip_v2f16_v2i64 : fixedpoint_recip_vec_i64<v2f16>;
+def fixedpoint_recip_v2f32_v2i64 : fixedpoint_recip_vec_i64<v2f32>;
+def fixedpoint_recip_v2f64_v2i64 : fixedpoint_recip_vec_i64<v2f64>;
+
def vecshiftR8 : Operand<i32>, ImmLeaf<i32, [{
return (((uint32_t)Imm) > 0) && (((uint32_t)Imm) < 9);
}]> {
@@ -5407,6 +5433,102 @@ class BaseIntegerToFPUnscaled<bits<2> rmode, bits<3> opcode,
let Inst{4-0} = Rd;
}
+// Defines one instruction record, <name>_V, derived from BaseIntegerToFP with
+// an empty pattern list (selection is done by the separate Pat records).
+// rmode, opcode, the register classes, imm_op and asm are forwarded to
+// BaseIntegerToFP unchanged; the remaining parameters override encoding bits:
+//   q               -> Inst{30}
+//   size            -> Inst{23-22}
+//   srcElemTypeBits -> Inst{11-10}
+// Inst{18-16} is fixed to 0b001, and preds becomes the record's Predicates.
+multiclass IntegerToFPVector<
+ bits<2> rmode, bits<3> opcode, string asm, RegisterClass srcRegClass,
+ RegisterClass dstRegClass, Operand imm_op, bits<1> q, bits<2> size,
+ bits<2> srcElemTypeBits, list<Predicate> preds> {
+
+ def _V : BaseIntegerToFP<rmode, opcode, srcRegClass, dstRegClass, imm_op,
+ asm, []> {
+ let Inst{30} = q;
+ let Inst{23 -22} = size;
+ let Inst{18 -16} = 0b001;
+ let Inst{11 -10} = srcElemTypeBits;
+ let Predicates = preds;
+ }
+}
+
+// SCVTF (Signed Convert To Floating-Point) from Vector 32-bit Integer (vNi32)
+// defm SCVTFv2f16_v2i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
+// FPR64, FPR64,
+// fixedpoint_recip_v2f16_v2i32,
+// 0, 0b00, 0b10, [HasFullFP16]>;
+
+// defm SCVTFv4f16_v4i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
+// FPR128, FPR128,
+// fixedpoint_recip_v4f16_v4i32,
+// 1, 0b00, 0b10, [HasFullFP16]>;
+
+// defm SCVTFv8f16_v8i32 : IntegerToFPVector<0b00, 0b010, "scvtf",
+// FPR128, FPR128,
+// fixedpoint_recip_v8f16_v8i32,
+// 1, 0b00, 0b10, [HasFullFP16]>;
+
+defm SCVTFv2f32_v2i32
+ : IntegerToFPVector<0b00, 0b010, "scvtf", FPR64, FPR64,
+ fixedpoint_recip_v2f32_v2i32, 0, 0b01, 0b10, []>;
+
+defm SCVTFv4f32_v4i32
+ : IntegerToFPVector<0b00, 0b010, "scvtf", FPR128, FPR128,
+ fixedpoint_recip_v4f32_v4i32, 1, 0b01, 0b10, []>;
+
+// SCVTF (Signed Convert To Floating-Point) from Vector 64-bit Integer (vNi64)
+// defm SCVTFv2f16_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
+// FPR128, FPR128,
+// fixedpoint_recip_v2f16_v2i64,
+// 1, 0b00, 0b11, [HasFullFP16]>;
+
+// defm SCVTFv2f32_v2i64 : IntegerToFPVector<0b00, 0b010, "scvtf",
+// FPR128, FPR128,
+// fixedpoint_recip_v2f32_v2i64,
+// 1, 0b01, 0b11, []>;
+
+defm SCVTFv2f64_v2i64
+ : IntegerToFPVector<0b00, 0b010, "scvtf", FPR128, FPR128,
+ fixedpoint_recip_v2f64_v2i64, 1, 0b10, 0b11, []>;
+
+// def : Pat<
+// (fmul (sint_to_fp (v2i32 V64:$Rn)),
+// fixedpoint_recip_v2f32_v2i32:$scale),
+// (SCVTFv2f16_v2i32_V V64:$Rn, fixedpoint_recip_v2f32_v2i32:$scale)
+// >;
+
+// def : Pat<
+// (fmul (sint_to_fp (v4i32 FPR128:$Rn)),
+// fixedpoint_recip_v4f16_v4i32:$scale),
+// (SCVTFv4f16_v4i32_V FPR128:$Rn, fixedpoint_recip_v4f16_v4i32:$scale)
+// >;
+
+// def : Pat<
+// (fmul (sint_to_fp (v8i32 FPR128:$Rn)),
+// fixedpoint_recip_v8f16_v8i32:$scale),
+// (SCVTFv8f16_v8i32_V FPR128:$Rn, fixedpoint_recip_v8f16_v8i32:$scale)
+// >;
+
+def : Pat<(fmul(sint_to_fp(v2i32 V64:$Rn)),
+ fixedpoint_recip_v2f32_v2i32:$scale),
+ (SCVTFv2f32_v2i32_V V64:$Rn, fixedpoint_recip_v2f32_v2i32:$scale)>;
+
+def : Pat<(fmul(sint_to_fp(v4i32 FPR128:$Rn)),
+ fixedpoint_recip_v4f32_v4i32:$scale),
+ (SCVTFv4f32_v4i32_V FPR128:$Rn, fixedpoint_recip_v4f32_v4i32:$scale)>;
+
+// def : Pat<
+// (fmul (sint_to_fp (v2i64 FPR128:$Rn)),
+// fixedpoint_recip_v2f16_v2i64:$scale),
+// (SCVTFv2f16_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f16_v2i64:$scale)
+// >;
+
+// def : Pat<
+// (fmul (sint_to_fp (v2i64 FPR128:$Rn)),
+// fixedpoint_recip_v2f32_v2i64:$scale),
+// (SCVTFv2f32_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f32_v2i64:$scale)
+// >;
+
+def : Pat<(fmul(sint_to_fp(v2i64 FPR128:$Rn)),
+ fixedpoint_recip_v2f64_v2i64:$scale),
+ (SCVTFv2f64_v2i64_V FPR128:$Rn, fixedpoint_recip_v2f64_v2i64:$scale)>;
+
multiclass IntegerToFP<bits<2> rmode, bits<3> opcode, string asm, SDPatternOperator node> {
// Unscaled
def UWHri: BaseIntegerToFPUnscaled<rmode, opcode, GPR32, FPR16, f16, asm, node> {
diff --git a/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
new file mode 100644
index 0000000000000..b31831559fb51
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/AArch64/scvtf-div-mul-combine.ll
@@ -0,0 +1,71 @@
+; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -mattr=+fullfp16 -o - %s | FileCheck %s
+
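+; This test file verifies the optimization for fmul(sitofp(x), C)
+; where C is a constant reciprocal of a power of two (C = 1/2^N),
+; converting it to a fixed-point scvtf with N fractional bits (scvtf X, #N).
+; The functions below express the operation as fdiv by 2^N.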
+
+; --- Scalar Tests ---
+
+; Scalar f32 (from i32)
+define float @test_f32_div(i32 %in) {
+; CHECK-LABEL: test_f32_div:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf s0, w0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp i32 %in to float
+ %div.i = fdiv float %vcvt.i, 16.0
+ ret float %div.i
+}
+
+; Scalar f64 (from i64)
+define double @test_f64_div(i64 %in) {
+; CHECK-LABEL: test_f64_div:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf d0, x0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp i64 %in to double
+ %div.i = fdiv double %vcvt.i, 16.0
+ ret double %div.i
+}
+
+; --- Multi-Element Vector F32 Tests ---
+
+; Vector v2f32 (from v2i32)
+define <2 x float> @testv_v2f32_div(<2 x i32> %in) {
+; CHECK-LABEL: testv_v2f32_div:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf d0, d0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp <2 x i32> %in to <2 x float>
+ %div.i = fdiv <2 x float> %vcvt.i, <float 16.0, float 16.0>
+ ret <2 x float> %div.i
+}
+
+; Vector v4f32 (from v4i32)
+define <4 x float> @testv_v4f32_div(<4 x i32> %in) {
+; CHECK-LABEL: testv_v4f32_div:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf q0, q0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp <4 x i32> %in to <4 x float>
+ %div.i = fdiv <4 x float> %vcvt.i, <float 16.0, float 16.0, float 16.0, float 16.0>
+ ret <4 x float> %div.i
+}
+
+; --- Multi-Element Vector F64 Tests ---
+
+; Vector v2f64 (from v2i64)
+define <2 x double> @testv_v2f64_div(<2 x i64> %in) {
+; CHECK-LABEL: testv_v2f64_div:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: scvtf q0, q0, #4
+; CHECK-NEXT: ret
+entry:
+ %vcvt.i = sitofp <2 x i64> %in to <2 x double>
+ %div.i = fdiv <2 x double> %vcvt.i, <double 16.0, double 16.0>
+ ret <2 x double> %div.i
+}
\ No newline at end of file
More information about the llvm-commits
mailing list