[llvm] 6d6314b - [DAGCombiner] Extend `combineFMulOrFDivWithIntPow2` to work for non-splat float vecs

Noah Goldstein via llvm-commits llvm-commits at lists.llvm.org
Wed Sep 20 11:28:46 PDT 2023


Author: Noah Goldstein
Date: 2023-09-20T13:28:24-05:00
New Revision: 6d6314ba644902d3cca7d5e6bd4c0021f82ab55b

URL: https://github.com/llvm/llvm-project/commit/6d6314ba644902d3cca7d5e6bd4c0021f82ab55b
DIFF: https://github.com/llvm/llvm-project/commit/6d6314ba644902d3cca7d5e6bd4c0021f82ab55b.diff

LOG: [DAGCombiner] Extend `combineFMulOrFDivWithIntPow2` to work for non-splat float vecs

Do so by extending `matchUnaryPredicate` to also work for
`ConstantFPSDNode` types, then encapsulating the constant checks in a
lambda and passing it to the new `matchUnaryFpPredicate` hook.
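
A minimal sketch (not part of this commit) of how a combine can use the new
hook; the lambda and its checks are illustrative, and `ConstOp` is the constant
operand as in combineFMulOrFDivWithIntPow2. The predicate is evaluated against
a scalar/splat constant or against every element of a constant BUILD_VECTOR,
which is what enables the non-splat vector case:

    // Illustrative only: require every FP element to be finite and
    // non-negative; undef elements are rejected (AllowUndefs defaults to
    // false).
    auto IsUsableFPConst = [](ConstantFPSDNode *C) {
      return C && C->getValueAPF().isFiniteNonZero() &&
             !C->getValueAPF().isNegative();
    };
    if (ISD::matchUnaryFpPredicate(ConstOp, IsUsableFPConst))
      /* every element passed the predicate */;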

Reviewed By: RKSimon

Differential Revision: https://reviews.llvm.org/D154868

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAGNodes.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 4d4c2673382b16c..59c6feec8bcbfed 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3128,9 +3128,25 @@ namespace ISD {
   /// Attempt to match a unary predicate against a scalar/splat constant or
   /// every element of a constant BUILD_VECTOR.
   /// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
-  bool matchUnaryPredicate(SDValue Op,
-                           std::function<bool(ConstantSDNode *)> Match,
-                           bool AllowUndefs = false);
+  template <typename ConstNodeType>
+  bool matchUnaryPredicateImpl(SDValue Op,
+                               std::function<bool(ConstNodeType *)> Match,
+                               bool AllowUndefs = false);
+
+  /// Hook for matching ConstantSDNode predicate
+  inline bool matchUnaryPredicate(SDValue Op,
+                                  std::function<bool(ConstantSDNode *)> Match,
+                                  bool AllowUndefs = false) {
+    return matchUnaryPredicateImpl<ConstantSDNode>(Op, Match, AllowUndefs);
+  }
+
+  /// Hook for matching ConstantFPSDNode predicate
+  inline bool
+  matchUnaryFpPredicate(SDValue Op,
+                        std::function<bool(ConstantFPSDNode *)> Match,
+                        bool AllowUndefs = false) {
+    return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, Match, AllowUndefs);
+  }
 
   /// Attempt to match a binary predicate against a pair of scalar/splat
   /// constants or every element of a pair of constant BUILD_VECTORs.

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b69badad625d935..693523e737acf66 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16352,7 +16352,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
   EVT VT = N->getValueType(0);
   SDValue ConstOp, Pow2Op;
 
-  int Mantissa = -1;
+  std::optional<int> Mantissa;
   auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
     if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
       return false;
@@ -16366,36 +16366,43 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
 
     Pow2Op = Pow2Op.getOperand(0);
 
-    // TODO(1): We may be able to include undefs.
-    // TODO(2): We could also handle non-splat vector types.
-    ConstantFPSDNode *CFP =
-        isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false);
-    if (CFP == nullptr)
-      return false;
-    const APFloat &APF = CFP->getValueAPF();
-
-    // Make sure we have normal/ieee constant.
-    if (!APF.isNormal() || !APF.isIEEE())
-      return false;
-
     // `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
     // TODO: We could use knownbits to make this bound more precise.
     int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
 
-    // Make sure the floats exponent is within the bounds that this transform
-    // produces bitwise equals value.
-    int CurExp = ilogb(APF);
-    // FMul by pow2 will only increase exponent.
-    int MinExp = N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
-    // FDiv by pow2 will only decrease exponent.
-    int MaxExp = N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
-    if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
-        MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
-      return false;
+    auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
+      if (CFP == nullptr)
+        return false;
+
+      const APFloat &APF = CFP->getValueAPF();
+
+      // Make sure we have normal/ieee constant.
+      if (!APF.isNormal() || !APF.isIEEE())
+        return false;
+
+      // Make sure the floats exponent is within the bounds that this transform
+      // produces bitwise equals value.
+      int CurExp = ilogb(APF);
+      // FMul by pow2 will only increase exponent.
+      int MinExp =
+          N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
+      // FDiv by pow2 will only decrease exponent.
+      int MaxExp =
+          N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
+      if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
+          MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
+        return false;
+
+      // Finally make sure we actually know the mantissa for the float type.
+      int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
+      if (!Mantissa)
+        Mantissa = ThisMantissa;
+
+      return *Mantissa == ThisMantissa && ThisMantissa > 0;
+    };
 
-    // Finally make sure we actually know the mantissa for the float type.
-    Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
-    return Mantissa > 0;
+    // TODO: We may be able to include undefs.
+    return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
   };
 
   if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
@@ -16420,7 +16427,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
 
   // Perform actual transform.
   SDValue MantissaShiftCnt =
-      DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT));
+      DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
   // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
   // `(X << C1) + (C << C1)`, but that isn't always the case because of the
   // cast. We could implement that by handle here to handle the casts.

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7fcd1f4f898911a..c15f056551b93fd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -344,12 +344,13 @@ bool ISD::isFreezeUndef(const SDNode *N) {
   return N->getOpcode() == ISD::FREEZE && N->getOperand(0).isUndef();
 }
 
-bool ISD::matchUnaryPredicate(SDValue Op,
-                              std::function<bool(ConstantSDNode *)> Match,
-                              bool AllowUndefs) {
+template <typename ConstNodeType>
+bool ISD::matchUnaryPredicateImpl(SDValue Op,
+                                  std::function<bool(ConstNodeType *)> Match,
+                                  bool AllowUndefs) {
   // FIXME: Add support for scalar UNDEF cases?
-  if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
-    return Match(Cst);
+  if (auto *C = dyn_cast<ConstNodeType>(Op))
+    return Match(C);
 
   // FIXME: Add support for vector UNDEF cases?
   if (ISD::BUILD_VECTOR != Op.getOpcode() &&
@@ -364,12 +365,17 @@ bool ISD::matchUnaryPredicate(SDValue Op,
       continue;
     }
 
-    auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
+    auto *Cst = dyn_cast<ConstNodeType>(Op.getOperand(i));
     if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
       return false;
   }
   return true;
 }
+// Build used template types.
+template bool ISD::matchUnaryPredicateImpl<ConstantSDNode>(
+    SDValue, std::function<bool(ConstantSDNode *)>, bool);
+template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>(
+    SDValue, std::function<bool(ConstantFPSDNode *)>, bool);
 
 bool ISD::matchBinaryPredicate(
     SDValue LHS, SDValue RHS,

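Because the templated matcher body stays in SelectionDAG.cpp, the commit emits
explicit instantiations for the two node types used by the inline wrappers in
the header. A minimal sketch of that pattern with hypothetical names (IntNode,
FPNode, matchImpl), not LLVM code:

    // matcher.h (hypothetical): declaration plus thin inline wrappers.
    struct IntNode { int V; };
    struct FPNode { double V; };
    template <typename NodeT> bool matchImpl(NodeT *N); // body lives in the .cpp
    inline bool matchInt(IntNode *N) { return matchImpl<IntNode>(N); }
    inline bool matchFP(FPNode *N) { return matchImpl<FPNode>(N); }

    // matcher.cpp (hypothetical, includes matcher.h): definition plus explicit
    // instantiations so other translation units can link against the wrappers.
    template <typename NodeT> bool matchImpl(NodeT *N) { return N != nullptr; }
    template bool matchImpl<IntNode>(IntNode *);
    template bool matchImpl<FPNode>(FPNode *);
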
diff --git a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
index 275afe11f1e0c7e..8d98ec7eaac2a4c 100644
--- a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
+++ b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
@@ -1104,58 +1104,15 @@ define <4 x float> @fmul_pow_shl_cnt_vec_preserve_fma(<4 x i32> %cnt, <4 x float
 define <2 x double> @fmul_pow_shl_cnt_vec_non_splat_todo(<2 x i64> %cnt) nounwind {
 ; CHECK-SSE-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
 ; CHECK-SSE:       # %bb.0:
-; CHECK-SSE-NEXT:    movdqa {{.*#+}} xmm1 = [2,2]
-; CHECK-SSE-NEXT:    movdqa %xmm1, %xmm2
-; CHECK-SSE-NEXT:    psllq %xmm0, %xmm2
-; CHECK-SSE-NEXT:    pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
-; CHECK-SSE-NEXT:    psllq %xmm0, %xmm1
-; CHECK-SSE-NEXT:    movsd {{.*#+}} xmm1 = xmm2[0],xmm1[1]
-; CHECK-SSE-NEXT:    movapd {{.*#+}} xmm0 = [4294967295,4294967295]
-; CHECK-SSE-NEXT:    andpd %xmm1, %xmm0
-; CHECK-SSE-NEXT:    orpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; CHECK-SSE-NEXT:    psrlq $32, %xmm1
-; CHECK-SSE-NEXT:    por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; CHECK-SSE-NEXT:    subpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; CHECK-SSE-NEXT:    addpd %xmm0, %xmm1
-; CHECK-SSE-NEXT:    mulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; CHECK-SSE-NEXT:    movapd %xmm1, %xmm0
+; CHECK-SSE-NEXT:    psllq $52, %xmm0
+; CHECK-SSE-NEXT:    paddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; CHECK-SSE-NEXT:    retq
 ;
-; CHECK-AVX2-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
-; CHECK-AVX2:       # %bb.0:
-; CHECK-AVX2-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2,2]
-; CHECK-AVX2-NEXT:    vpsllvq %xmm0, %xmm1, %xmm0
-; CHECK-AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; CHECK-AVX2-NEXT:    vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
-; CHECK-AVX2-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; CHECK-AVX2-NEXT:    vpsrlq $32, %xmm0, %xmm0
-; CHECK-AVX2-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-AVX2-NEXT:    vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-AVX2-NEXT:    vaddpd %xmm0, %xmm1, %xmm0
-; CHECK-AVX2-NEXT:    vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-AVX2-NEXT:    retq
-;
-; CHECK-NO-FASTFMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
-; CHECK-NO-FASTFMA:       # %bb.0:
-; CHECK-NO-FASTFMA-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2,2]
-; CHECK-NO-FASTFMA-NEXT:    vpsllvq %xmm0, %xmm1, %xmm0
-; CHECK-NO-FASTFMA-NEXT:    vpxor %xmm1, %xmm1, %xmm1
-; CHECK-NO-FASTFMA-NEXT:    vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
-; CHECK-NO-FASTFMA-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; CHECK-NO-FASTFMA-NEXT:    vpsrlq $32, %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT:    vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT:    vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT:    vaddpd %xmm0, %xmm1, %xmm0
-; CHECK-NO-FASTFMA-NEXT:    vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT:    retq
-;
-; CHECK-FMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
-; CHECK-FMA:       # %bb.0:
-; CHECK-FMA-NEXT:    vpbroadcastq {{.*#+}} xmm1 = [2,2]
-; CHECK-FMA-NEXT:    vpsllvq %xmm0, %xmm1, %xmm0
-; CHECK-FMA-NEXT:    vcvtuqq2pd %xmm0, %xmm0
-; CHECK-FMA-NEXT:    vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-FMA-NEXT:    retq
+; CHECK-AVX-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
+; CHECK-AVX:       # %bb.0:
+; CHECK-AVX-NEXT:    vpsllq $52, %xmm0, %xmm0
+; CHECK-AVX-NEXT:    vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-AVX-NEXT:    retq
   %shl = shl nsw nuw <2 x i64> <i64 2, i64 2>, %cnt
   %conv = uitofp <2 x i64> %shl to <2 x double>
   %mul = fmul <2 x double> <double 15.000000e+00, double 14.000000e+00>, %conv

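
The new psllq $52 / paddq codegen above relies on the identity that, for a
normal IEEE-754 binary64 value, adding `k << 52` to its bit pattern multiplies
it by 2^k (52 is the mantissa width, i.e. the `Mantissa` shift count computed
in the combine). A small standalone demonstration of that identity, purely
illustrative and not part of the commit:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Multiply a normal double by 2^K by bumping its biased exponent field.
    static double mulByPow2(double X, uint64_t K) {
      uint64_t Bits;
      std::memcpy(&Bits, &X, sizeof(Bits)); // bit-cast double -> u64
      Bits += K << 52;                      // add K to the exponent field
      std::memcpy(&X, &Bits, sizeof(Bits)); // bit-cast back
      return X;
    }

    int main() {
      std::printf("%f\n", mulByPow2(15.0, 3)); // 120.000000 == 15.0 * 2^3
      std::printf("%f\n", mulByPow2(14.0, 3)); // 112.000000 == 14.0 * 2^3
    }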