[Mlir-commits] [mlir] [mlir][math] add benefit arg to populate math approximations/expansions (PR #130782)
Emilio Cota
llvmlistbot at llvm.org
Tue Mar 11 07:59:49 PDT 2025
https://github.com/cota created https://github.com/llvm/llvm-project/pull/130782
This is a follow-up to #127291, which added the benefit argument to the lowerings to intrinsics and libm.
This change adds the benefit argument to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings: approximations, intrinsics, and libm.
Note that only the new API added in #126103 is updated. The legacy overload of `mlir::populateMathPolynomialApproximationPatterns` is left unmodified to encourage users to migrate away from it.
>From a09c4297da745fba484fb6f9a19169b114068223 Mon Sep 17 00:00:00 2001
From: Emilio Cota <ecg at google.com>
Date: Tue, 11 Mar 2025 10:27:25 -0400
Subject: [PATCH] [mlir][math] add benefit arg to populate math
approximations/expansions
This is a follow-up to #127291, which added the benefit argument to
the lowerings to intrinsics and libm.
This change adds the benefit argument to the math approximation and
expansion lowerings, which allows users to establish a preferred
order among all three math lowerings: approximations, intrinsics,
and libm.
Note that only the new API added in #126103 is updated. The legacy
overload of `mlir::populateMathPolynomialApproximationPatterns` is
left unmodified to encourage users to migrate away from it.
---
.../mlir/Dialect/Math/Transforms/Passes.h | 7 +-
.../Transforms/PolynomialApproximation.cpp | 105 +++++++++---------
2 files changed, 59 insertions(+), 53 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9adc1c6940a15..c0fe5d3be448a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.
void populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Adds patterns to enable polynomial approximations for math functions for
// which `predicate` returns true.
void populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Legacy. Calls both populateMathF32ExpansionPatterns and
// populateMathPolynomialApproximationPatterns with predicates enabling a
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 167eebd786dba..a26e380232a91 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern(
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
+ patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
}
}
void mlir::populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
- populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
}
template <typename OpType, typename PatternType>
static void populateMathPolynomialApproximationPattern(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<PatternType>(patterns.getContext());
+ patterns.add<PatternType>(patterns.getContext(), benefit);
}
}
void mlir::populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
populateMathPolynomialApproximationPattern<AcosOp,
AcosPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AsinOp,
AsinPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
+ CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
+ SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
}
void mlir::populateMathPolynomialApproximationPatterns(
More information about the Mlir-commits
mailing list