[Mlir-commits] [mlir] [MLIR][Math] Add fine-grained populate-patterns functions for math function rewrites. (PR #126103)

Benoit Jacob llvmlistbot at llvm.org
Mon Feb 10 03:57:30 PST 2025


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/126103

>From bd15e6bed1f6631ff3a182853c182403b341a4f2 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 5 Feb 2025 14:51:52 -0600
Subject: [PATCH] polynomial-approx

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 .../mlir/Dialect/Math/Transforms/Passes.h     |  19 +++
 .../Transforms/PolynomialApproximation.cpp    | 135 +++++++++++++++---
 2 files changed, 135 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f0f17c6adcb088e..ea7a556297a76ac 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -48,6 +48,25 @@ struct MathPolynomialApproximationOptions {
 void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
 void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
 
+// Adds patterns that convert to f32 around the math ops for which `predicate`
+// returns true, so that existing f32 lowerings can be reused for other
+// floating-point element types. `predicate` is called with each candidate op's
+// operation name (e.g. "math.tanh"); only ops that have such a pattern are
+// ever queried.
+void populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+
+// Adds polynomial-approximation patterns for the math ops for which `predicate`
+// returns true. `predicate` is called with each candidate op's operation name
+// (e.g. "math.tanh"). Unlike the options-based overload, this adds no
+// f32-expansion patterns; also call populateMathF32ExpansionPatterns if those
+// are wanted.
+// NOTE: selecting "math.rsqrt" here adds the rsqrt approximation
+// unconditionally, whereas the options-based overload only adds it when
+// `enableAvx2` is set. Presumably that approximation needs AVX2 -- confirm
+// before selecting it for other targets.
+void populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+
+// Legacy entry point. Calls both populateMathF32ExpansionPatterns and
+// populateMathPolynomialApproximationPatterns with predicates that select a
+// fixed set of math ops; that set probably cannot change without breaking
+// existing users. Note that, unlike
+// populateMathPolynomialApproximationPatterns(patterns, predicate), this
+// overload also adds the f32-expansion patterns. The math.rsqrt patterns are
+// only added when `options.enableAvx2` is set.
+// New code should prefer calling these functions directly:
+// * populateMathF32ExpansionPatterns(patterns, predicate)
+// * populateMathPolynomialApproximationPatterns(patterns, predicate)
 void populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options = {});
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 24c892f68b50316..777427de9465c5d 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1667,28 +1667,125 @@ void mlir::populatePolynomialApproximateErfPattern(
   patterns.add<ErfPolynomialApproximation>(patterns.getContext());
 }
 
+/// Registers ReuseF32Expansion<OpType> in `patterns`, but only when `predicate`
+/// accepts the operation name of `OpType`.
+template <typename OpType>
+static void
+populateMathF32ExpansionPattern(RewritePatternSet &patterns,
+                                llvm::function_ref<bool(StringRef)> predicate) {
+  StringRef opName = OpType::getOperationName();
+  if (!predicate(opName))
+    return;
+  auto *ctx = patterns.getContext();
+  patterns.add<ReuseF32Expansion<OpType>>(ctx);
+}
+
+/// Applies populateMathF32ExpansionPattern to each op type in `OpTypes`.
+template <typename... OpTypes>
+static void populateMathF32ExpansionPatternsFor(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  // A comma fold evaluates left to right, so patterns are registered in the
+  // order the op types are listed.
+  (populateMathF32ExpansionPattern<OpTypes>(patterns, predicate), ...);
+}
+
+/// Registers f32-expansion patterns for every supported math op whose
+/// operation name is accepted by `predicate`.
+void mlir::populateMathF32ExpansionPatterns(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  populateMathF32ExpansionPatternsFor<
+      math::AcosOp, math::AcoshOp, math::AsinOp, math::AsinhOp, math::AtanOp,
+      math::Atan2Op, math::AtanhOp, math::CbrtOp, math::CosOp, math::CoshOp,
+      math::ErfOp, math::ExpOp, math::Exp2Op, math::ExpM1Op, math::LogOp,
+      math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp, math::RsqrtOp,
+      math::SinOp, math::SinhOp, math::SqrtOp, math::TanOp, math::TanhOp>(
+      patterns, predicate);
+}
+
+/// Registers `PatternType` in `patterns` when `predicate` accepts the operation
+/// name of `OpType`. `OpType` only supplies the name that is tested; callers
+/// pair it with the op the pattern is meant for.
+template <typename OpType, typename PatternType>
+static void populateMathPolynomialApproximationPattern(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  StringRef opName = OpType::getOperationName();
+  if (!predicate(opName))
+    return;
+  auto *ctx = patterns.getContext();
+  patterns.add<PatternType>(ctx);
+}
+
+/// Registers the polynomial-approximation pattern of every supported math op
+/// whose operation name (e.g. "math.tanh") is accepted by `predicate`.
+/// Unlike the options-based overload, this adds no f32-expansion patterns and
+/// does not gate math.rsqrt on `enableAvx2`.
+void mlir::populateMathPolynomialApproximationPatterns(
+    RewritePatternSet &patterns,
+    llvm::function_ref<bool(StringRef)> predicate) {
+  // Op types carry the `math::` qualifier, matching
+  // populateMathF32ExpansionPatterns and the options-based overload.
+  populateMathPolynomialApproximationPattern<
+      math::AcosOp, AcosPolynomialApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::AsinOp, AsinPolynomialApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::AtanOp, AtanApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::Atan2Op, Atan2Approximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::CbrtOp, CbrtApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::CosOp, SinAndCosApproximation<false, math::CosOp>>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::ErfOp, ErfPolynomialApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::ExpOp, ExpApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::ExpM1Op, ExpM1Approximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::LogOp, LogApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::Log2Op, Log2Approximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::Log1pOp, Log1pApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::RsqrtOp, RsqrtApproximation>(patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::SinOp, SinAndCosApproximation<true, math::SinOp>>(
+      patterns, predicate);
+  populateMathPolynomialApproximationPattern<
+      math::TanhOp, TanhApproximation>(patterns, predicate);
+}
+
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
-  // Patterns for leveraging existing f32 lowerings on other data types.
-  patterns
-      .add<ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
-           ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
-           ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
-           ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
-           ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
-           ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
-          patterns.getContext());
-
-  patterns
-      .add<AtanApproximation, Atan2Approximation, TanhApproximation,
-           LogApproximation, Log2Approximation, Log1pApproximation,
-           ErfPolynomialApproximation, AsinPolynomialApproximation,
-           AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-           CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-           SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+  // Ops for which this legacy entry point has always enabled both the f32
+  // expansion and the polynomial approximation. Kept fixed for compatibility.
+  auto isLegacyF32ExpandedOp = [](StringRef name) -> bool {
+    return llvm::is_contained(
+        {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
+         math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
+         math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
+         math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
+         math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
+         math::SinOp::getOperationName(), math::CosOp::getOperationName()},
+        name);
+  };
+  mlir::populateMathF32ExpansionPatterns(patterns, isLegacyF32ExpandedOp);
+
+  // The approximation set is the expansion set plus math.asin and math.acos,
+  // for which this entry point never registered an f32 expansion.
+  mlir::populateMathPolynomialApproximationPatterns(
+      patterns, [&](StringRef name) -> bool {
+        return isLegacyF32ExpandedOp(name) ||
+               name == math::AsinOp::getOperationName() ||
+               name == math::AcosOp::getOperationName();
+      });
+
   if (options.enableAvx2) {
-    patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
-        patterns.getContext());
+    // The rsqrt patterns are only enabled when AVX2 is requested.
+    auto predicateRsqrt = [](StringRef name) {
+      return name == math::RsqrtOp::getOperationName();
+    };
+    mlir::populateMathF32ExpansionPatterns(patterns, predicateRsqrt);
+    mlir::populateMathPolynomialApproximationPatterns(patterns, predicateRsqrt);
   }
 }



More information about the Mlir-commits mailing list