[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)
JP Hafer via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 27 04:53:47 PDT 2025
================
@@ -3952,6 +3960,132 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
return true;
}
+static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
+ SDValue N,
+ SDValue &FixedPos,
+ unsigned FloatWidth,
+ bool IsReciprocal) {
+
+ SDValue ImmediateNode;
+ // N must be a bitcast, nvcast, or fmov
+ if (N.getOpcode() == ISD::BITCAST || N.getOpcode() == AArch64ISD::NVCAST ||
+ N.getOpcode() == AArch64ISD::FMOV) {
+ ImmediateNode = N.getOperand(0);
+ } else {
+ return false;
+ }
+
+ EVT NodeVT = N.getValueType();
+ // In theory the immediate node value type would be a vector. However,
+ // this is not the case when using 2.0. Thus check N's value type for
+ // vector and floating point instead.
+ if (!NodeVT.isVector() || !NodeVT.isFloatingPoint())
+ return false;
+
+ if (!(ImmediateNode.getOpcode() == AArch64ISD::DUP ||
+ ImmediateNode.getOpcode() == AArch64ISD::MOVIshift ||
+ ImmediateNode.getOpcode() == ISD::Constant ||
+ ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR ||
+ ImmediateNode.getOpcode() == ISD::BUILD_VECTOR)) {
+ return false; // Not a possible splat
+ }
+
+ if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
+ // For BUILD_VECTOR, we must explicitly check if it's a constant splat.
+ BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
+ APInt SplatValue;
+ APInt SplatUndef;
+ unsigned SplatBitSize;
+ bool HasAnyUndefs;
+ if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+ HasAnyUndefs)) {
+ return false;
+ }
+ }
+
+ APInt Imm;
+ bool IsIntConstant = false;
+ if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
+ Imm = APInt(NodeVT.getScalarSizeInBits(),
+ ImmediateNode.getConstantOperandVal(0)
+ << ImmediateNode.getConstantOperandVal(1));
+ IsIntConstant = true;
+ } else if (ImmediateNode.getOpcode() == ISD::Constant) {
+ auto *C = dyn_cast<ConstantSDNode>(ImmediateNode);
+ if (!C)
+ return false;
+ uint8_t EncodedU8 = static_cast<uint8_t>(C->getZExtValue());
+ uint64_t DecodedBits = AArch64_AM::decodeAdvSIMDModImmType11(EncodedU8);
+
+ unsigned BitWidth = N.getValueType().getVectorElementType().getSizeInBits();
+ uint64_t Mask = (BitWidth == 64) ? ~0ULL : ((1ULL << BitWidth) - 1);
+ uint64_t MaskedBits = DecodedBits & Mask;
+
+ Imm = APInt(BitWidth, MaskedBits);
+ IsIntConstant = true;
+ } else if (auto *CI = dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(0))) {
+ Imm = CI->getAPIntValue();
+ IsIntConstant = true;
+ }
+
+ APFloat FVal(0.0);
+ // --- Extract the actual constant value ---
+ if (IsIntConstant) {
+ // Scalar source is an integer constant; interpret its bits as
+ // floating-point.
+ EVT FloatEltVT = N.getValueType().getVectorElementType();
+
+ if (FloatEltVT == MVT::f32) {
+ FVal = APFloat(APFloat::IEEEsingle(), Imm);
+ } else if (FloatEltVT == MVT::f64) {
+ FVal = APFloat(APFloat::IEEEdouble(), Imm);
+ } else if (FloatEltVT == MVT::f16) {
+ FVal = APFloat(APFloat::IEEEhalf(), Imm);
+ } else {
+ // Unsupported floating-point element type.
+ return false;
+ }
+ } else if (auto *CFP =
+ dyn_cast<ConstantFPSDNode>(ImmediateNode.getOperand(0))) {
----------------
jph-13 wrote:
Not in the test cases I have. Another cautionary relic from an AI conversation. I will drop it.
https://github.com/llvm/llvm-project/pull/141480
More information about the llvm-commits
mailing list