[llvm] [AArch64] Fix #94909: Optimize vector fmul(sitofp(x), 0.5) -> scvtf(x, 2) (PR #141480)

JP Hafer via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 25 07:58:40 PDT 2025


================
@@ -3952,6 +3960,119 @@ static bool checkCVTFixedPointOperandWithFBits(SelectionDAG *CurDAG, SDValue N,
   return true;
 }
 
+static bool checkCVTFixedPointOperandWithFBitsForVectors(SelectionDAG *CurDAG,
+                                                         SDValue N,
+                                                         SDValue &FixedPos,
+                                                         unsigned FloatWidth,
+                                                         bool IsReciprocal) {
+
+  SDValue ImmediateNode;
+  // N must be a bitcast or nvcast
+  if (N.getOpcode() == ISD::BITCAST || N.getOpcode() == AArch64ISD::NVCAST) {
+    ImmediateNode = N.getOperand(0);
+  } else {
+    return false;
+  }
+
+  EVT NodeVT = N.getValueType();
+  EVT InputVT = ImmediateNode.getValueType();
+  // The input must be a vector of int but N must be a floating-point type
+  if (!InputVT.isVector() || !InputVT.isInteger() || !NodeVT.isFloatingPoint())
+    return false;
+
+  bool IsSplatConfirmed = false;
+
+  if (ImmediateNode.getOpcode() == AArch64ISD::DUP ||
+      ImmediateNode.getOpcode() == ISD::SPLAT_VECTOR) {
+    // These opcodes inherently mean a splat.
+    IsSplatConfirmed = true;
+  } else if (ImmediateNode.getOpcode() == ISD::BUILD_VECTOR) {
+    // For BUILD_VECTOR, we must explicitly check if it's a constant splat.
+    BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(ImmediateNode.getNode());
+    APInt SplatValue;
+    APInt SplatUndef;
+    unsigned SplatBitSize;
+    bool HasAnyUndefs;
+    if (BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                             HasAnyUndefs)) {
+      IsSplatConfirmed = true;
+    }
+  } else if (ImmediateNode.getOpcode() == AArch64ISD::MOVIshift) {
+    // This implies that the DAG structure was (DUP (MOVIshift C)) or
+    // (BUILD_VECTOR (MOVIshift C)).
+    IsSplatConfirmed = true;
+  }
+
+  if (!IsSplatConfirmed)
+    return false;
+
+  // --- Extract the actual constant value ---
+  auto ScalarSourceNode = ImmediateNode.getOperand(0);
+  APFloat FVal(0.0);
+  if (auto *CFP = dyn_cast<ConstantFPSDNode>(ScalarSourceNode)) {
+    // Scalar source is a floating-point constant.
+    FVal = CFP->getValueAPF();
+  } else if (auto *CI = dyn_cast<ConstantSDNode>(ScalarSourceNode)) {
+    // Scalar source is an integer constant; interpret its bits as
+    // floating-point.
+    EVT FloatEltVT = N.getValueType().getVectorElementType();
+
+    if (FloatEltVT == MVT::f32) {
+      FVal = APFloat(APFloat::IEEEsingle(), CI->getAPIntValue());
+    } else if (FloatEltVT == MVT::f64) {
+      FVal = APFloat(APFloat::IEEEdouble(), CI->getAPIntValue());
+    } else if (FloatEltVT == MVT::f16) {
+      auto *ShiftAmountConst =
+          dyn_cast<ConstantSDNode>(ImmediateNode.getOperand(1));
----------------
jph-13 wrote:

Oh wow, you are completely correct! My impl is missing much more than it catches. I will expand the tests.

https://github.com/llvm/llvm-project/pull/141480


More information about the llvm-commits mailing list