[Mlir-commits] [mlir] ced23aa - [MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. (#126103)
llvmlistbot at llvm.org
Mon Feb 10 09:52:27 PST 2025
Author: Benoit Jacob
Date: 2025-02-10T09:52:24-08:00
New Revision: ced23aa5403240f26cba4d99b59bf5d31d6035ac
URL: https://github.com/llvm/llvm-project/commit/ced23aa5403240f26cba4d99b59bf5d31d6035ac
DIFF: https://github.com/llvm/llvm-project/commit/ced23aa5403240f26cba4d99b59bf5d31d6035ac.diff
LOG: [MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. (#126103)
The existing `mlir::populateMathPolynomialApproximationPatterns` is
coarse-grained and inflexible:
- It populates 2 distinct classes of patterns: (1) polynomial
approximations, (2) expansions of operands to f32.
- It does not offer knobs to select which math functions to apply the
rewrites to.
This PR adds finer-grained populate-patterns functions, which take a
predicate lambda allowing the caller to control which math functions to
apply rewrites to.
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Math/Transforms/Passes.h
mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f0f17c6adcb088e..ea7a556297a76ac 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -48,6 +48,25 @@ struct MathPolynomialApproximationOptions {
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
+// Adds patterns that rewrite supported math ops on types other than f32 by
+// converting to f32 and reusing the existing f32 lowering. Only ops for which
+// `predicate` returns true are covered; `predicate` receives the op's
+// registered name, as returned by `OpType::getOperationName()`.
+void populateMathF32ExpansionPatterns(
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+
+// Adds polynomial-approximation patterns for the supported math ops for which
+// `predicate` returns true (`predicate` receives the op's registered name).
+// This overload adds no f32-expansion patterns.
+void populateMathPolynomialApproximationPatterns(
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+
+// Legacy. Calls both populateMathF32ExpansionPatterns and
+// populateMathPolynomialApproximationPatterns with fixed predicates that
+// select a specific set of math function rewrites; that set likely cannot be
+// changed for compatibility reasons. Note that, unlike
+// populateMathPolynomialApproximationPatterns(patterns, predicate), this
+// overload also adds f32-expansion patterns. When `options.enableAvx2` is set,
+// the rsqrt rewrites (f32 expansion and approximation) are added as well.
+// NOTE(review): a call that passes a bare `{}` as the second argument is
+// ambiguous between this overload and the predicate overload; omit the
+// argument or spell out `MathPolynomialApproximationOptions{}` instead.
+// Prefer calling these functions directly:
+// * populateMathF32ExpansionPatterns(patterns, predicate)
+// * populateMathPolynomialApproximationPatterns(patterns, predicate)
void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 24c892f68b50316..777427de9465c5d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1667,28 +1667,125 @@ void mlir::populatePolynomialApproximateErfPattern(
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}
+/// Registers the generic f32-expansion rewrite for `OpType`, but only when the
+/// caller's `predicate` accepts that op's registered operation name.
+template <typename OpType>
+static void
+populateMathF32ExpansionPattern(RewritePatternSet &patterns,
+                                llvm::function_ref<bool(StringRef)> predicate) {
+  if (!predicate(OpType::getOperationName()))
+    return;
+  auto *ctx = patterns.getContext();
+  patterns.add<ReuseF32Expansion<OpType>>(ctx);
+}
+
+/// Applies populateMathF32ExpansionPattern to every op type in `OpTypes`,
+/// left to right, so patterns are appended in pack order.
+template <typename... OpTypes>
+static void populateMathF32ExpansionPatternsFor(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  (populateMathF32ExpansionPattern<OpTypes>(patterns, predicate), ...);
+}
+
+/// Adds the f32-expansion rewrite for each math op below whose operation name
+/// is accepted by `predicate`.
+void mlir::populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  populateMathF32ExpansionPatternsFor<
+      math::AcosOp, math::AcoshOp, math::AsinOp, math::AsinhOp, math::AtanOp,
+      math::Atan2Op, math::AtanhOp, math::CbrtOp, math::CosOp, math::CoshOp,
+      math::ErfOp, math::ExpOp, math::Exp2Op, math::ExpM1Op, math::LogOp,
+      math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp, math::RsqrtOp,
+      math::SinOp, math::SinhOp, math::SqrtOp, math::TanOp, math::TanhOp>(
+      patterns, predicate);
+}
+
+/// Adds `PatternType` to `patterns` when `predicate` accepts the operation
+/// name of `OpType`; otherwise leaves `patterns` untouched.
+template <typename OpType, typename PatternType>
+static void populateMathPolynomialApproximationPattern(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  const bool enabled = predicate(OpType::getOperationName());
+  if (!enabled)
+    return;
+  patterns.add<PatternType>(patterns.getContext());
+}
+
+/// Adds a polynomial-approximation pattern for each supported math op whose
+/// operation name is accepted by `predicate`. Unlike the options-based legacy
+/// overload, this adds no f32-expansion patterns.
+void mlir::populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  populateMathPolynomialApproximationPattern<math::AcosOp,
+                                             AcosPolynomialApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::AsinOp,
+                                             AsinPolynomialApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::AtanOp, AtanApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::Atan2Op,
+                                             Atan2Approximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::CbrtOp, CbrtApproximation>(
+      patterns, predicate);
+  // Sine and cosine share one templated pattern, instantiated with `false`
+  // for math::CosOp here and with `true` for math::SinOp below.
+  populateMathPolynomialApproximationPattern<
+      math::CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns,
+                                                               predicate);
+  populateMathPolynomialApproximationPattern<math::ErfOp,
+                                             ErfPolynomialApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::ExpOp, ExpApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::ExpM1Op,
+                                             ExpM1Approximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::LogOp, LogApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::Log2Op, Log2Approximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::Log1pOp,
+                                             Log1pApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<math::RsqrtOp,
+                                             RsqrtApproximation>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns,
+                                                              predicate);
+  populateMathPolynomialApproximationPattern<math::TanhOp, TanhApproximation>(
+      patterns, predicate);
+}
+
+// Legacy options-based entry point: adds a fixed set of f32-expansion and
+// polynomial-approximation patterns via the predicate-based functions, plus
+// the rsqrt rewrites when `options.enableAvx2` is set.
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
- // Patterns for leveraging existing f32 lowerings on other data types.
- patterns
- .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
- patterns.getContext());
-
- patterns
- .add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, AsinPolynomialApproximation,
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+ // Patterns for leveraging existing f32 lowerings on other data types. The op
+ // lists in both predicates below are kept as-is for compatibility.
+ mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) -> bool {
+ return llvm::is_contained(
+ {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
+ math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
+ math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
+ math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
+ math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
+ math::SinOp::getOperationName(), math::CosOp::getOperationName()},
+ name);
+ });
+
+ // Polynomial approximations: the same ops as the f32-expansion list above,
+ // plus math::AsinOp and math::AcosOp (which get no f32 expansion here).
+ populateMathPolynomialApproximationPatterns(
+ patterns, [](StringRef name) -> bool {
+ return llvm::is_contained(
+ {math::AtanOp::getOperationName(),
+ math::Atan2Op::getOperationName(),
+ math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
+ math::Log2Op::getOperationName(),
+ math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
+ math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
+ math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
+ math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
+ math::CosOp::getOperationName()},
+ name);
+ });
+
+ // Both rsqrt rewrites (f32 expansion and approximation) are opt-in, gated on
+ // `options.enableAvx2`.
if (options.enableAvx2) {
- patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
- patterns.getContext());
+ auto predicateRsqrt = [](StringRef name) {
+ return name == math::RsqrtOp::getOperationName();
+ };
+ mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
+ mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
}
}
More information about the Mlir-commits mailing list