[llvm] 6d6314b - [DAGCombiner] Extend `combineFMulOrFDivWithIntPow2` to work for non-splat float vecs
Noah Goldstein via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 20 11:28:46 PDT 2023
Author: Noah Goldstein
Date: 2023-09-20T13:28:24-05:00
New Revision: 6d6314ba644902d3cca7d5e6bd4c0021f82ab55b
URL: https://github.com/llvm/llvm-project/commit/6d6314ba644902d3cca7d5e6bd4c0021f82ab55b
DIFF: https://github.com/llvm/llvm-project/commit/6d6314ba644902d3cca7d5e6bd4c0021f82ab55b.diff
LOG: [DAGCombiner] Extend `combineFMulOrFDivWithIntPow2` to work for non-splat float vecs
This is done by extending `matchUnaryPredicate` to also work for
`ConstantFPSDNode` types, then encapsulating the constant checks in a
lambda and passing it to the new `matchUnaryFpPredicate` hook.
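
For illustration only (not part of the commit), a minimal sketch of how the
new hook reads from inside a DAG combine. The fragment assumes it runs in
DAGCombiner with an SDValue `Op` in scope; `ISD::matchUnaryFpPredicate`,
`ConstantFPSDNode`, and the `APFloat` queries are the ones the patch uses:

  // Accept Op only if it is a scalar/splat FP constant or a constant
  // BUILD_VECTOR whose every element is a normal IEEE value. Undef
  // elements are rejected because AllowUndefs defaults to false.
  auto IsNormalIEEE = [](ConstantFPSDNode *C) {
    return C && C->getValueAPF().isNormal() && C->getValueAPF().isIEEE();
  };
  if (!ISD::matchUnaryFpPredicate(Op, IsNormalIEEE))
    return SDValue();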
Reviewed By: RKSimon
Differential Revision: https://reviews.llvm.org/D154868
Added:
Modified:
llvm/include/llvm/CodeGen/SelectionDAGNodes.h
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 4d4c2673382b16c..59c6feec8bcbfed 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3128,9 +3128,25 @@ namespace ISD {
/// Attempt to match a unary predicate against a scalar/splat constant or
/// every element of a constant BUILD_VECTOR.
/// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
- bool matchUnaryPredicate(SDValue Op,
- std::function<bool(ConstantSDNode *)> Match,
- bool AllowUndefs = false);
+ template <typename ConstNodeType>
+ bool matchUnaryPredicateImpl(SDValue Op,
+ std::function<bool(ConstNodeType *)> Match,
+ bool AllowUndefs = false);
+
+ /// Hook for matching ConstantSDNode predicate
+ inline bool matchUnaryPredicate(SDValue Op,
+ std::function<bool(ConstantSDNode *)> Match,
+ bool AllowUndefs = false) {
+ return matchUnaryPredicateImpl<ConstantSDNode>(Op, Match, AllowUndefs);
+ }
+
+ /// Hook for matching ConstantFPSDNode predicate
+ inline bool
+ matchUnaryFpPredicate(SDValue Op,
+ std::function<bool(ConstantFPSDNode *)> Match,
+ bool AllowUndefs = false) {
+ return matchUnaryPredicateImpl<ConstantFPSDNode>(Op, Match, AllowUndefs);
+ }
/// Attempt to match a binary predicate against a pair of scalar/splat
/// constants or every element of a pair of constant BUILD_VECTORs.
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index b69badad625d935..693523e737acf66 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -16352,7 +16352,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
EVT VT = N->getValueType(0);
SDValue ConstOp, Pow2Op;
- int Mantissa = -1;
+ std::optional<int> Mantissa;
auto GetConstAndPow2Ops = [&](unsigned ConstOpIdx) {
if (ConstOpIdx == 1 && N->getOpcode() == ISD::FDIV)
return false;
@@ -16366,36 +16366,43 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
Pow2Op = Pow2Op.getOperand(0);
- // TODO(1): We may be able to include undefs.
- // TODO(2): We could also handle non-splat vector types.
- ConstantFPSDNode *CFP =
- isConstOrConstSplatFP(ConstOp, /*AllowUndefs*/ false);
- if (CFP == nullptr)
- return false;
- const APFloat &APF = CFP->getValueAPF();
-
- // Make sure we have normal/ieee constant.
- if (!APF.isNormal() || !APF.isIEEE())
- return false;
-
// `Log2(Pow2Op) < Pow2Op.getScalarSizeInBits()`.
// TODO: We could use knownbits to make this bound more precise.
int MaxExpChange = Pow2Op.getValueType().getScalarSizeInBits();
- // Make sure the floats exponent is within the bounds that this transform
- // produces bitwise equals value.
- int CurExp = ilogb(APF);
- // FMul by pow2 will only increase exponent.
- int MinExp = N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
- // FDiv by pow2 will only decrease exponent.
- int MaxExp = N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
- if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
- MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
- return false;
+ auto IsFPConstValid = [N, MaxExpChange, &Mantissa](ConstantFPSDNode *CFP) {
+ if (CFP == nullptr)
+ return false;
+
+ const APFloat &APF = CFP->getValueAPF();
+
+ // Make sure we have normal/ieee constant.
+ if (!APF.isNormal() || !APF.isIEEE())
+ return false;
+
+ // Make sure the floats exponent is within the bounds that this transform
+ // produces bitwise equals value.
+ int CurExp = ilogb(APF);
+ // FMul by pow2 will only increase exponent.
+ int MinExp =
+ N->getOpcode() == ISD::FMUL ? CurExp : (CurExp - MaxExpChange);
+ // FDiv by pow2 will only decrease exponent.
+ int MaxExp =
+ N->getOpcode() == ISD::FDIV ? CurExp : (CurExp + MaxExpChange);
+ if (MinExp <= APFloat::semanticsMinExponent(APF.getSemantics()) ||
+ MaxExp >= APFloat::semanticsMaxExponent(APF.getSemantics()))
+ return false;
+
+ // Finally make sure we actually know the mantissa for the float type.
+ int ThisMantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
+ if (!Mantissa)
+ Mantissa = ThisMantissa;
+
+ return *Mantissa == ThisMantissa && ThisMantissa > 0;
+ };
- // Finally make sure we actually know the mantissa for the float type.
- Mantissa = APFloat::semanticsPrecision(APF.getSemantics()) - 1;
- return Mantissa > 0;
+ // TODO: We may be able to include undefs.
+ return ISD::matchUnaryFpPredicate(ConstOp, IsFPConstValid);
};
if (!GetConstAndPow2Ops(0) && !GetConstAndPow2Ops(1))
@@ -16420,7 +16427,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
// Perform actual transform.
SDValue MantissaShiftCnt =
- DAG.getConstant(Mantissa, DL, getShiftAmountTy(NewIntVT));
+ DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
// TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
// `(X << C1) + (C << C1)`, but that isn't always the case because of the
// cast. We could implement that by handle here to handle the casts.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7fcd1f4f898911a..c15f056551b93fd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -344,12 +344,13 @@ bool ISD::isFreezeUndef(const SDNode *N) {
return N->getOpcode() == ISD::FREEZE && N->getOperand(0).isUndef();
}
-bool ISD::matchUnaryPredicate(SDValue Op,
- std::function<bool(ConstantSDNode *)> Match,
- bool AllowUndefs) {
+template <typename ConstNodeType>
+bool ISD::matchUnaryPredicateImpl(SDValue Op,
+ std::function<bool(ConstNodeType *)> Match,
+ bool AllowUndefs) {
// FIXME: Add support for scalar UNDEF cases?
- if (auto *Cst = dyn_cast<ConstantSDNode>(Op))
- return Match(Cst);
+ if (auto *C = dyn_cast<ConstNodeType>(Op))
+ return Match(C);
// FIXME: Add support for vector UNDEF cases?
if (ISD::BUILD_VECTOR != Op.getOpcode() &&
@@ -364,12 +365,17 @@ bool ISD::matchUnaryPredicate(SDValue Op,
continue;
}
- auto *Cst = dyn_cast<ConstantSDNode>(Op.getOperand(i));
+ auto *Cst = dyn_cast<ConstNodeType>(Op.getOperand(i));
if (!Cst || Cst->getValueType(0) != SVT || !Match(Cst))
return false;
}
return true;
}
+// Build used template types.
+template bool ISD::matchUnaryPredicateImpl<ConstantSDNode>(
+ SDValue, std::function<bool(ConstantSDNode *)>, bool);
+template bool ISD::matchUnaryPredicateImpl<ConstantFPSDNode>(
+ SDValue, std::function<bool(ConstantFPSDNode *)>, bool);
bool ISD::matchBinaryPredicate(
SDValue LHS, SDValue RHS,
diff --git a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
index 275afe11f1e0c7e..8d98ec7eaac2a4c 100644
--- a/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
+++ b/llvm/test/CodeGen/X86/fold-int-pow2-with-fmul-or-fdiv.ll
@@ -1104,58 +1104,15 @@ define <4 x float> @fmul_pow_shl_cnt_vec_preserve_fma(<4 x i32> %cnt, <4 x float
define <2 x double> @fmul_pow_shl_cnt_vec_non_splat_todo(<2 x i64> %cnt) nounwind {
; CHECK-SSE-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
; CHECK-SSE: # %bb.0:
-; CHECK-SSE-NEXT: movdqa {{.*#+}} xmm1 = [2,2]
-; CHECK-SSE-NEXT: movdqa %xmm1, %xmm2
-; CHECK-SSE-NEXT: psllq %xmm0, %xmm2
-; CHECK-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[2,3,2,3]
-; CHECK-SSE-NEXT: psllq %xmm0, %xmm1
-; CHECK-SSE-NEXT: movsd {{.*#+}} xmm1 = xmm2[0],xmm1[1]
-; CHECK-SSE-NEXT: movapd {{.*#+}} xmm0 = [4294967295,4294967295]
-; CHECK-SSE-NEXT: andpd %xmm1, %xmm0
-; CHECK-SSE-NEXT: orpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
-; CHECK-SSE-NEXT: psrlq $32, %xmm1
-; CHECK-SSE-NEXT: por {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; CHECK-SSE-NEXT: subpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; CHECK-SSE-NEXT: addpd %xmm0, %xmm1
-; CHECK-SSE-NEXT: mulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; CHECK-SSE-NEXT: movapd %xmm1, %xmm0
+; CHECK-SSE-NEXT: psllq $52, %xmm0
+; CHECK-SSE-NEXT: paddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; CHECK-SSE-NEXT: retq
;
-; CHECK-AVX2-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
-; CHECK-AVX2: # %bb.0:
-; CHECK-AVX2-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2]
-; CHECK-AVX2-NEXT: vpsllvq %xmm0, %xmm1, %xmm0
-; CHECK-AVX2-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; CHECK-AVX2-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
-; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; CHECK-AVX2-NEXT: vpsrlq $32, %xmm0, %xmm0
-; CHECK-AVX2-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-AVX2-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-AVX2-NEXT: vaddpd %xmm0, %xmm1, %xmm0
-; CHECK-AVX2-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-AVX2-NEXT: retq
-;
-; CHECK-NO-FASTFMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
-; CHECK-NO-FASTFMA: # %bb.0:
-; CHECK-NO-FASTFMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2]
-; CHECK-NO-FASTFMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0
-; CHECK-NO-FASTFMA-NEXT: vpxor %xmm1, %xmm1, %xmm1
-; CHECK-NO-FASTFMA-NEXT: vpblendd {{.*#+}} xmm1 = xmm0[0],xmm1[1],xmm0[2],xmm1[3]
-; CHECK-NO-FASTFMA-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %xmm1
-; CHECK-NO-FASTFMA-NEXT: vpsrlq $32, %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT: vpor {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT: vsubpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT: vaddpd %xmm0, %xmm1, %xmm0
-; CHECK-NO-FASTFMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-NO-FASTFMA-NEXT: retq
-;
-; CHECK-FMA-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
-; CHECK-FMA: # %bb.0:
-; CHECK-FMA-NEXT: vpbroadcastq {{.*#+}} xmm1 = [2,2]
-; CHECK-FMA-NEXT: vpsllvq %xmm0, %xmm1, %xmm0
-; CHECK-FMA-NEXT: vcvtuqq2pd %xmm0, %xmm0
-; CHECK-FMA-NEXT: vmulpd {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
-; CHECK-FMA-NEXT: retq
+; CHECK-AVX-LABEL: fmul_pow_shl_cnt_vec_non_splat_todo:
+; CHECK-AVX: # %bb.0:
+; CHECK-AVX-NEXT: vpsllq $52, %xmm0, %xmm0
+; CHECK-AVX-NEXT: vpaddq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; CHECK-AVX-NEXT: retq
%shl = shl nsw nuw <2 x i64> <i64 2, i64 2>, %cnt
%conv = uitofp <2 x i64> %shl to <2 x double>
%mul = fmul <2 x double> <double 15.000000e+00, double 14.000000e+00>, %conv
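
As a side note on why the folded sequence above (`psllq $52` + `paddq`) is
exact: for a normal IEEE double, adding the shift count into the exponent
field (bit 52 upwards) multiplies the value by the corresponding power of two
bit-for-bit, provided the exponent stays in range, which is what the
MinExp/MaxExp check in the combine guarantees. A small standalone sketch, not
part of the patch; the helper name is made up:

  // mul_pow2_via_exponent_add.cpp -- illustration only.
  #include <cassert>
  #include <cstdint>
  #include <cstring>

  // Multiply a normal double by 2^Cnt by adding Cnt directly into the
  // exponent bits, mirroring the shift/add sequence the combine emits.
  // Only valid while the resulting exponent stays in the normal range.
  static double mulPow2ViaExponentAdd(double C, uint64_t Cnt) {
    uint64_t Bits;
    std::memcpy(&Bits, &C, sizeof(Bits));
    Bits += Cnt << 52; // 52 == mantissa width of an IEEE-754 double
    double Res;
    std::memcpy(&Res, &Bits, sizeof(Res));
    return Res;
  }

  int main() {
    // Element-wise spot checks in the spirit of the non-splat test above.
    assert(mulPow2ViaExponentAdd(15.0, 3) == 15.0 * 8.0);
    assert(mulPow2ViaExponentAdd(14.0, 5) == 14.0 * 32.0);
    return 0;
  }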