[Mlir-commits] [mlir] [MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. (PR #126103)
Jakub Kuderski
llvmlistbot at llvm.org
Fri Feb 7 13:13:54 PST 2025
================
@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}
+// Registers a ReuseF32Expansion<Op> pattern for every math op whose short name
+// (e.g. "acos", "powf", "tanh") is accepted by `predicate`; ops whose name is
+// rejected get no pattern. These patterns leverage the existing f32 lowerings
+// on other data types.
+void mlir::populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate) {
+  MLIRContext *ctx = patterns.getContext();
+  if (predicate("acos"))
+    patterns.add<ReuseF32Expansion<math::AcosOp>>(ctx);
+  if (predicate("acosh"))
+    patterns.add<ReuseF32Expansion<math::AcoshOp>>(ctx);
+  if (predicate("asin"))
+    patterns.add<ReuseF32Expansion<math::AsinOp>>(ctx);
+  if (predicate("asinh"))
+    patterns.add<ReuseF32Expansion<math::AsinhOp>>(ctx);
+  if (predicate("atan"))
+    patterns.add<ReuseF32Expansion<math::AtanOp>>(ctx);
+  if (predicate("atan2"))
+    patterns.add<ReuseF32Expansion<math::Atan2Op>>(ctx);
+  if (predicate("atanh"))
+    patterns.add<ReuseF32Expansion<math::AtanhOp>>(ctx);
+  if (predicate("cbrt"))
+    patterns.add<ReuseF32Expansion<math::CbrtOp>>(ctx);
+  if (predicate("cos"))
+    patterns.add<ReuseF32Expansion<math::CosOp>>(ctx);
+  if (predicate("cosh"))
+    patterns.add<ReuseF32Expansion<math::CoshOp>>(ctx);
+  if (predicate("erf"))
+    patterns.add<ReuseF32Expansion<math::ErfOp>>(ctx);
+  if (predicate("exp"))
+    patterns.add<ReuseF32Expansion<math::ExpOp>>(ctx);
+  if (predicate("exp2"))
+    patterns.add<ReuseF32Expansion<math::Exp2Op>>(ctx);
+  if (predicate("expm1"))
+    patterns.add<ReuseF32Expansion<math::ExpM1Op>>(ctx);
+  if (predicate("log"))
+    patterns.add<ReuseF32Expansion<math::LogOp>>(ctx);
+  if (predicate("log10"))
+    patterns.add<ReuseF32Expansion<math::Log10Op>>(ctx);
+  if (predicate("log2"))
+    patterns.add<ReuseF32Expansion<math::Log2Op>>(ctx);
+  if (predicate("log1p"))
+    patterns.add<ReuseF32Expansion<math::Log1pOp>>(ctx);
+  if (predicate("powf"))
+    patterns.add<ReuseF32Expansion<math::PowFOp>>(ctx);
+  if (predicate("rsqrt"))
+    patterns.add<ReuseF32Expansion<math::RsqrtOp>>(ctx);
+  if (predicate("sin"))
+    patterns.add<ReuseF32Expansion<math::SinOp>>(ctx);
+  if (predicate("sinh"))
+    patterns.add<ReuseF32Expansion<math::SinhOp>>(ctx);
+  if (predicate("sqrt"))
+    patterns.add<ReuseF32Expansion<math::SqrtOp>>(ctx);
+  if (predicate("tan"))
+    patterns.add<ReuseF32Expansion<math::TanOp>>(ctx);
+  if (predicate("tanh"))
+    patterns.add<ReuseF32Expansion<math::TanhOp>>(ctx);
+}
+
+// Registers the polynomial-approximation rewrite for every math op whose short
+// name (e.g. "acos", "expm1", "tanh") is accepted by `predicate`. Sin and cos
+// share one SinAndCosApproximation template, distinguished by its bool
+// argument (true for sin, false for cos).
+void mlir::populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate) {
+  MLIRContext *ctx = patterns.getContext();
+  if (predicate("acos"))
+    patterns.add<AcosPolynomialApproximation>(ctx);
+  if (predicate("asin"))
+    patterns.add<AsinPolynomialApproximation>(ctx);
+  if (predicate("atan"))
+    patterns.add<AtanApproximation>(ctx);
+  if (predicate("atan2"))
+    patterns.add<Atan2Approximation>(ctx);
+  if (predicate("cbrt"))
+    patterns.add<CbrtApproximation>(ctx);
+  if (predicate("cos"))
+    patterns.add<SinAndCosApproximation<false, math::CosOp>>(ctx);
+  if (predicate("erf"))
+    patterns.add<ErfPolynomialApproximation>(ctx);
+  if (predicate("exp"))
+    patterns.add<ExpApproximation>(ctx);
+  if (predicate("expm1"))
+    patterns.add<ExpM1Approximation>(ctx);
+  if (predicate("log"))
+    patterns.add<LogApproximation>(ctx);
+  if (predicate("log2"))
+    patterns.add<Log2Approximation>(ctx);
+  if (predicate("log1p"))
+    patterns.add<Log1pApproximation>(ctx);
+  if (predicate("rsqrt"))
+    patterns.add<RsqrtApproximation>(ctx);
+  if (predicate("sin"))
+    patterns.add<SinAndCosApproximation<true, math::SinOp>>(ctx);
+  if (predicate("tanh"))
+    patterns.add<TanhApproximation>(ctx);
+}
+
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
- // Patterns for leveraging existing f32 lowerings on other data types.
- patterns
- .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
- ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
- ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
- ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
- ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
- ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
- patterns.getContext());
-
- patterns
- .add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, AsinPolynomialApproximation,
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+ mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) {
+ return name == "atan" || name == "atan2" || name == "tanh" ||
+ name == "log" || name == "log2" || name == "log1p" ||
+ name == "erf" || name == "exp" || name == "expm1" ||
+ name == "cbrt" || name == "sin" || name == "cos";
----------------
kuhar wrote:
Use `llvm::is_contained({"atan", "atan2", ...}, name)`.
https://github.com/llvm/llvm-project/pull/126103
More information about the Mlir-commits
mailing list