[Mlir-commits] [mlir] [MLIR][Math] Add optional benefit arg to populate math lowering patterns (PR #127291)
William Moses
llvmlistbot at llvm.org
Fri Feb 14 18:16:10 PST 2025
https://github.com/wsmoses updated https://github.com/llvm/llvm-project/pull/127291
>From 282d90acb6c0a666fa115eb64ad1d26bb7565551 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh at wsmoses.com>
Date: Fri, 14 Feb 2025 19:22:20 -0600
Subject: [PATCH 1/5] [MLIR][Math] Add optional benefit arg to populate math
lowering patterns
---
.../mlir/Conversion/MathToLLVM/MathToLLVM.h | 1 +
.../mlir/Conversion/MathToLibm/MathToLibm.h | 2 +-
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 5 +-
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 76 +++++++++----------
4 files changed, 43 insertions(+), 41 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index 93cd780bba438..b7883fe9a55ff 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -23,6 +23,7 @@ class Pass;
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
+ PatternBenefit benefit = 1,
bool approximateLog1p = true);
void registerConvertMathToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index ab9a1cef20cab..6db661a7b5748 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -19,7 +19,7 @@ class OperationPass;
/// Populate the given list with patterns that convert from Math to Libm calls.
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1);
/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 98680773e00d2..196fad2d8367b 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,9 +304,10 @@ struct ConvertMathToLLVMPass
void mlir::populateMathToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+ PatternBenefit benefit,
bool approximateLog1p) {
if (approximateLog1p)
- patterns.add<Log1pOpLowering>(converter);
+ patterns.add<Log1pOpLowering>(converter, benefit);
// clang-format off
patterns.add<
AbsFOpLowering,
@@ -337,7 +338,7 @@ void mlir::populateMathToLLVMConversionPatterns(
FTruncOpLowering,
TanOpLowering,
TanhOpLowering
- >(converter);
+ >(converter, benefit);
// clang-format on
}
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51..97ec5cf178f5e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -50,9 +50,9 @@ template <typename Op>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
- ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+ ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, StringRef floatFunc,
StringRef doubleFunc)
- : OpRewritePattern<Op>(context), floatFunc(floatFunc),
+ : OpRewritePattern<Op>(context, benegit, ), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -62,10 +62,10 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
};
template <typename OpTy>
-void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
+void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit, MLIRContext *ctx,
StringRef floatFunc, StringRef doubleFunc) {
- patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
- patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+ patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
+ patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
}
} // namespace
@@ -159,42 +159,42 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1) {
MLIRContext *ctx = patterns.getContext();
- populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
- populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
- populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
- populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
- populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
- populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
- populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
- populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
- populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
- populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
- populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
- populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
- populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
- populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
- populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
- populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
- populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
- populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
- populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
- populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
- populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
- populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
- populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+ populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
+ populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
+ populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf", "acosh");
+ populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
+ populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf", "asinh");
+ populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f", "atan2");
+ populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
+ populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf", "atanh");
+ populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
+ populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
+ populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
+ populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
+ populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
+ populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
+ populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
+ populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f", "expm1");
+ populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf", "floor");
+ populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
+ populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
+ populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
+ populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f", "log10");
+ populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf", "log1p");
+ populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
"roundeven");
- populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
- populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
- populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
- populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
- populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
- populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
- populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
- populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
+ populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf", "round");
+ populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
+ populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
+ populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
+ populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf", "rsqrt");
+ populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
+ populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
+ populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf", "trunc");
}
namespace {
>From 29026273810312c055513ded61ff1f65cf73cbbb Mon Sep 17 00:00:00 2001
From: William Moses <gh at wsmoses.com>
Date: Fri, 14 Feb 2025 19:54:54 -0600
Subject: [PATCH 2/5] Update mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Co-authored-by: Ivan R. Ivanov <ivanov.i.aa at m.titech.ac.jp>
---
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 97ec5cf178f5e..86967fbe914fa 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -52,7 +52,7 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, StringRef floatFunc,
StringRef doubleFunc)
- : OpRewritePattern<Op>(context, benegit, ), floatFunc(floatFunc),
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
>From 79e569add838dbef3b9a9f7db51fd301ac5e5830 Mon Sep 17 00:00:00 2001
From: William Moses <gh at wsmoses.com>
Date: Fri, 14 Feb 2025 20:04:45 -0600
Subject: [PATCH 3/5] Update MathToLLVM.h
---
mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index b7883fe9a55ff..d28ac728c04ed 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -10,6 +10,7 @@
#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
#include <memory>
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
>From c5cef7d47da307bc90125351ceb093cc2ef3492c Mon Sep 17 00:00:00 2001
From: William Moses <gh at wsmoses.com>
Date: Fri, 14 Feb 2025 20:14:27 -0600
Subject: [PATCH 4/5] Update MathToLibm.cpp
---
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 86967fbe914fa..e3a8afd52e10e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -159,7 +159,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) {
MLIRContext *ctx = patterns.getContext();
populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
>From 0aa729aa03fc7e010f287b0159f065ad0ec9f16c Mon Sep 17 00:00:00 2001
From: William Moses <gh at wsmoses.com>
Date: Fri, 14 Feb 2025 20:16:00 -0600
Subject: [PATCH 5/5] fixup arg order
---
mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h | 4 ++--
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp | 3 +--
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index d28ac728c04ed..d1b9dd58bf82d 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -24,8 +24,8 @@ class Pass;
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
- PatternBenefit benefit = 1,
- bool approximateLog1p = true);
+ bool approximateLog1p = true,
+ PatternBenefit benefit = 1);
void registerConvertMathToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 196fad2d8367b..85ec288268aeb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,8 +304,7 @@ struct ConvertMathToLLVMPass
void mlir::populateMathToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- PatternBenefit benefit,
- bool approximateLog1p) {
+ bool approximateLog1p, PatternBenefit benefit) {
if (approximateLog1p)
patterns.add<Log1pOpLowering>(converter, benefit);
// clang-format off
More information about the Mlir-commits
mailing list