[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)
JP Hafer via llvm-commits
llvm-commits at lists.llvm.org
Tue Jun 24 07:01:00 PDT 2025
================
@@ -3952,6 +3960,129 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
return true;
}
+static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
+ SDValue N,
+ SDValue &FixedPos,
+ unsigned FloatWidth,
+ bool isReciprocal) {
+
+ // N must be a bitcast/nvcast of a vector float type.
+ if (!((N.getOpcode() == ISD::BITCAST ||
+ N.getOpcode() == AArch64ISD::NVCAST) &&
+ N.getValueType().isVector() && N.getValueType().isFloatingPoint())) {
+ return false;
+ }
+
+ if (N.getNumOperands() == 0)
+ return false;
+ SDValue ImmediateNode = N.getOperand(0);
+
+ bool isSplatConfirmed = false;
+
+ if (ImmediateNode.getOpcode() == AArch64ISD::DUP ||
+ ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR) {
+ // These opcodes inherently mean a splat.
+ isSplatConfirmed = true;
+ } else if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
+ // For BUILD_VECTOR, we must explicitly check if it's a constant splat.
+ BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
+ APInt SplatValue;
+ APInt SplatUndef;
+ unsigned SplatBitSize;
+ bool HasAnyUndefs;
+ if (BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+ HasAnyUndefs)) {
+ isSplatConfirmed = true;
+ } else {
+ return false;
+ }
+ } else if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
+ // This implies that the DAG structure was (DUP (MOVIshift C)) or
+ // (BUILD_VECTOR (MOVIshift C)).
+ isSplatConfirmed = true;
+ } else {
+ return false;
+ }
+
+ // If we reached here, isSplatConfirmed should be true and ScalarSourceNode
+ // should be set. But just in case ...
+ if (!isSplatConfirmed)
+ return false;
+
+ // --- Extract the actual constant value ---
+ auto ScalarSourceNode = ImmediateNode.getOperand(0);
+ APFloat FVal(0.0);
+ if (auto *CFP = dyn_cast<ConstantFPSDNode>(ScalarSourceNode)) {
+ // Scalar source is a floating-point constant.
+ FVal = CFP->getValueAPF();
+ } else if (auto *CI = dyn_cast<ConstantSDNode>(ScalarSourceNode)) {
+ // Scalar source is an integer constant; interpret its bits as
+ // floating-point.
+ EVT FloatEltVT = N.getValueType().getVectorElementType();
+
+ if (FloatEltVT == MVT::f32) {
+ FVal = APFloat(APFloat::IEEEsingle(), CI->getAPIntValue());
+ } else if (FloatEltVT == MVT::f64) {
+ FVal = APFloat(APFloat::IEEEdouble(), CI->getAPIntValue());
+ } else if (FloatEltVT == MVT::f16) {
+ auto *ShiftAmountConst =
+ dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(1));
+
+ if (!ShiftAmountConst) {
+ return false;
+ }
+ APInt ImmediateVal = CI->getAPIntValue();
+ unsigned ShiftAmount = ShiftAmountConst->getAPIntValue().getZExtValue();
+ APInt EffectiveBits = ImmediateVal.trunc(16).shl(ShiftAmount);
+ FVal = APFloat(APFloat::IEEEhalf(), EffectiveBits);
+ } else {
+ // Unsupported floating-point element type.
+ return false;
+ }
+ } else {
+ // ScalarSourceNode is not a recognized constant type.
+ return false;
+ }
+
+ // --- Perform fixed-point reciprocal check and power-of-2 validation on FVal
+ // --- Normalize f16 to f32 if needed for consistent APFloat operations.
+ if (N.getValueType().getVectorElementType() == MVT::f16) {
+ bool ignored;
----------------
jph-13 wrote:
This code was a relic from an AI conversation, and I don't think it was needed. Originally I was doing more intermediate math. I have removed the logic.
https://github.com/llvm/llvm-project/pull/141480
More information about the llvm-commits mailing list