[Mlir-commits] [mlir] [MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. (PR #126103)

Jakub Kuderski llvmlistbot at llvm.org
Fri Feb 7 13:13:54 PST 2025


================
@@ -1667,28 +1667,158 @@ void mlir::populatePolynomialApproximateErfPattern(
   patterns.add<ErfPolynomialApproximation>(patterns.getContext());
 }
 
+void mlir::populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate) {
+  MLIRContext *context = patterns.getContext();
+  if (predicate("acos")) {
+    patterns.add<ReuseF32Expansion<math::AcosOp>>(context);
+  }
+  if (predicate("acosh")) {
+    patterns.add<ReuseF32Expansion<math::AcoshOp>>(context);
+  }
+  if (predicate("asin")) {
+    patterns.add<ReuseF32Expansion<math::AsinOp>>(context);
+  }
+  if (predicate("asinh")) {
+    patterns.add<ReuseF32Expansion<math::AsinhOp>>(context);
+  }
+  if (predicate("atan")) {
+    patterns.add<ReuseF32Expansion<math::AtanOp>>(context);
+  }
+  if (predicate("atan2")) {
+    patterns.add<ReuseF32Expansion<math::Atan2Op>>(context);
+  }
+  if (predicate("atanh")) {
+    patterns.add<ReuseF32Expansion<math::AtanhOp>>(context);
+  }
+  if (predicate("cbrt")) {
+    patterns.add<ReuseF32Expansion<math::CbrtOp>>(context);
+  }
+  if (predicate("cos")) {
+    patterns.add<ReuseF32Expansion<math::CosOp>>(context);
+  }
+  if (predicate("cosh")) {
+    patterns.add<ReuseF32Expansion<math::CoshOp>>(context);
+  }
+  if (predicate("erf")) {
+    patterns.add<ReuseF32Expansion<math::ErfOp>>(context);
+  }
+  if (predicate("exp")) {
+    patterns.add<ReuseF32Expansion<math::ExpOp>>(context);
+  }
+  if (predicate("exp2")) {
+    patterns.add<ReuseF32Expansion<math::Exp2Op>>(context);
+  }
+  if (predicate("expm1")) {
+    patterns.add<ReuseF32Expansion<math::ExpM1Op>>(context);
+  }
+  if (predicate("log")) {
+    patterns.add<ReuseF32Expansion<math::LogOp>>(context);
+  }
+  if (predicate("log10")) {
+    patterns.add<ReuseF32Expansion<math::Log10Op>>(context);
+  }
+  if (predicate("log2")) {
+    patterns.add<ReuseF32Expansion<math::Log2Op>>(context);
+  }
+  if (predicate("log1p")) {
+    patterns.add<ReuseF32Expansion<math::Log1pOp>>(context);
+  }
+  if (predicate("powf")) {
+    patterns.add<ReuseF32Expansion<math::PowFOp>>(context);
+  }
+  if (predicate("rsqrt")) {
+    patterns.add<ReuseF32Expansion<math::RsqrtOp>>(context);
+  }
+  if (predicate("sin")) {
+    patterns.add<ReuseF32Expansion<math::SinOp>>(context);
+  }
+  if (predicate("sinh")) {
+    patterns.add<ReuseF32Expansion<math::SinhOp>>(context);
+  }
+  if (predicate("sqrt")) {
+    patterns.add<ReuseF32Expansion<math::SqrtOp>>(context);
+  }
+  if (predicate("tan")) {
+    patterns.add<ReuseF32Expansion<math::TanOp>>(context);
+  }
+  if (predicate("tanh")) {
+    patterns.add<ReuseF32Expansion<math::TanhOp>>(context);
+  }
+}
+
+void mlir::populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    const std::function<bool(StringRef)> &predicate) {
+  MLIRContext *context = patterns.getContext();
+  if (predicate("acos")) {
+    patterns.add<AcosPolynomialApproximation>(context);
+  }
+  if (predicate("asin")) {
+    patterns.add<AsinPolynomialApproximation>(context);
+  }
+  if (predicate("atan")) {
+    patterns.add<AtanApproximation>(context);
+  }
+  if (predicate("atan2")) {
+    patterns.add<Atan2Approximation>(context);
+  }
+  if (predicate("cbrt")) {
+    patterns.add<CbrtApproximation>(context);
+  }
+  if (predicate("cos")) {
+    patterns.add<SinAndCosApproximation<false, math::CosOp>>(context);
+  }
+  if (predicate("erf")) {
+    patterns.add<ErfPolynomialApproximation>(context);
+  }
+  if (predicate("exp")) {
+    patterns.add<ExpApproximation>(context);
+  }
+  if (predicate("expm1")) {
+    patterns.add<ExpM1Approximation>(context);
+  }
+  if (predicate("log")) {
+    patterns.add<LogApproximation>(context);
+  }
+  if (predicate("log2")) {
+    patterns.add<Log2Approximation>(context);
+  }
+  if (predicate("log1p")) {
+    patterns.add<Log1pApproximation>(context);
+  }
+  if (predicate("rsqrt")) {
+    patterns.add<RsqrtApproximation>(context);
+  }
+  if (predicate("sin")) {
+    patterns.add<SinAndCosApproximation<true, math::SinOp>>(context);
+  }
+  if (predicate("tanh")) {
+    patterns.add<TanhApproximation>(context);
+  }
+}
+
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
-  // Patterns for leveraging existing f32 lowerings on other data types.
-  patterns
-      .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
-           ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
-           ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
-           ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
-           ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
-           ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
-          patterns.getContext());
-
-  patterns
-      .add<AtanApproximation, Atan2Approximation, TanhApproximation,
-           LogApproximation, Log2Approximation, Log1pApproximation,
-           ErfPolynomialApproximation, AsinPolynomialApproximation,
-           AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-           CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-           SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+  mlir::populateMathF32ExpansionPatterns(patterns, [](StringRef name) {
+    return name == "atan" || name == "atan2" || name == "tanh" ||
+           name == "log" || name == "log2" || name == "log1p" ||
+           name == "erf" || name == "exp" || name == "expm1" ||
+           name == "cbrt" || name == "sin" || name == "cos";
----------------
kuhar wrote:

Use `llvm::is_contained({"atan", "atan2", ...}, name)`

https://github.com/llvm/llvm-project/pull/126103


More information about the Mlir-commits mailing list