[Mlir-commits] [mlir] [mlir][math] add benefit arg to populate math approximations/expansions (PR #130782)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 11 08:00:51 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-math
Author: Emilio Cota (cota)
<details>
<summary>Changes</summary>
This is a follow-up to #127291, which added the benefit argument to the lowerings to intrinsics and to libm.
In this change we add the benefit arg to the math approximation and expansion lowerings, which allows users to establish a preferred order among all three math lowerings, namely approximations, intrinsics and libm.
Note that we're only updating the new API added in #126103. The legacy overload of `mlir::populateMathPolynomialApproximationPatterns` is left unmodified to encourage users to migrate away from it.
---
Full diff: https://github.com/llvm/llvm-project/pull/130782.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+5-2)
- (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+54-51)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 9adc1c6940a15..c0fe5d3be448a 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
@@ -52,12 +53,14 @@ void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.
void populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Adds patterns to enable polynomial approximations for math functions for
// which `predicate` returns true.
void populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit = 1);
// Legacy. Calls both populateMathF32ExpansionPatterns and
// populateMathPolynomialApproximationPatterns with predicates enabling a
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 167eebd786dba..a26e380232a91 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1776,90 +1776,93 @@ void mlir::populatePolynomialApproximateErfcPattern(
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext());
+ patterns.add<ReuseF32Expansion<OpType>>(patterns.getContext(), benefit);
}
}
void mlir::populateMathF32ExpansionPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
- populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate);
- populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate);
- populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate);
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
+ populateMathF32ExpansionPattern<math::AcosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AcoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AsinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Atan2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::AtanhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CbrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::LogOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log10Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log1pOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::Log2Op>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::PowFOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::RsqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SinhOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::SqrtOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanOp>(patterns, predicate, benefit);
+ populateMathF32ExpansionPattern<math::TanhOp>(patterns, predicate, benefit);
}
template <typename OpType, typename PatternType>
static void populateMathPolynomialApproximationPattern(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
if (predicate(OpType::getOperationName())) {
- patterns.add<PatternType>(patterns.getContext());
+ patterns.add<PatternType>(patterns.getContext(), benefit);
}
}
void mlir::populateMathPolynomialApproximationPatterns(
- RewritePatternSet &patterns,
- llvm::function_ref<bool(StringRef)> predicate) {
+ RewritePatternSet &patterns, llvm::function_ref<bool(StringRef)> predicate,
+ PatternBenefit benefit) {
populateMathPolynomialApproximationPattern<AcosOp,
AcosPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AsinOp,
AsinPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<AtanOp, AtanApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Atan2Op, Atan2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<CbrtOp, CbrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
+ CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<LogOp, LogApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log2Op, Log2Approximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<Log1pOp, Log1pApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<RsqrtOp, RsqrtApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
populateMathPolynomialApproximationPattern<
- SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate);
+ SinOp, SinAndCosApproximation<true, math::SinOp>>(patterns, predicate,
+ benefit);
populateMathPolynomialApproximationPattern<TanhOp, TanhApproximation>(
- patterns, predicate);
+ patterns, predicate, benefit);
}
void mlir::populateMathPolynomialApproximationPatterns(
``````````
</details>
https://github.com/llvm/llvm-project/pull/130782
More information about the Mlir-commits
mailing list