[Mlir-commits] [mlir] 5c93eb5 - [MLIR][Math] Add optional benefit arg to populate math lowering patterns (#127291)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Feb 14 20:38:14 PST 2025
Author: William Moses
Date: 2025-02-14T22:38:11-06:00
New Revision: 5c93eb56dc9bc0c0210483cdd5d31e6b6580454f
URL: https://github.com/llvm/llvm-project/commit/5c93eb56dc9bc0c0210483cdd5d31e6b6580454f
DIFF: https://github.com/llvm/llvm-project/commit/5c93eb56dc9bc0c0210483cdd5d31e6b6580454f.diff
LOG: [MLIR][Math] Add optional benefit arg to populate math lowering patterns (#127291)
Co-authored-by: Ivan R. Ivanov <ivanov.i.aa at m.titech.ac.jp>
Added:
Modified:
mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index 93cd780bba438..0c1203e1e3c0e 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
#define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
+#include "mlir/IR/PatternMatch.h"
#include <memory>
namespace mlir {
@@ -23,7 +24,8 @@ class Pass;
void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
RewritePatternSet &patterns,
- bool approximateLog1p = true);
+ bool approximateLog1p = true,
+ PatternBenefit benefit = 1);
void registerConvertMathToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index ab9a1cef20cab..8ace53a0fd582 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -19,7 +19,8 @@ class OperationPass;
/// Populate the given list with patterns that convert from Math to Libm calls.
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 98680773e00d2..85ec288268aeb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,9 +304,9 @@ struct ConvertMathToLLVMPass
void mlir::populateMathToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
- bool approximateLog1p) {
+ bool approximateLog1p, PatternBenefit benefit) {
if (approximateLog1p)
- patterns.add<Log1pOpLowering>(converter);
+ patterns.add<Log1pOpLowering>(converter, benefit);
// clang-format off
patterns.add<
AbsFOpLowering,
@@ -337,7 +337,7 @@ void mlir::populateMathToLLVMConversionPatterns(
FTruncOpLowering,
TanOpLowering,
TanhOpLowering
- >(converter);
+ >(converter, benefit);
// clang-format on
}
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51..12a6d9c3452df 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -50,10 +50,10 @@ template <typename Op>
struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
- ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
- StringRef doubleFunc)
- : OpRewritePattern<Op>(context), floatFunc(floatFunc),
- doubleFunc(doubleFunc){};
+ ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
+ StringRef floatFunc, StringRef doubleFunc)
+ : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ doubleFunc(doubleFunc) {};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
};
template <typename OpTy>
-void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
- StringRef floatFunc, StringRef doubleFunc) {
- patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
- patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
+ MLIRContext *ctx, StringRef floatFunc,
+ StringRef doubleFunc) {
+ patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
+ patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
}
} // namespace
@@ -159,42 +160,54 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit) {
MLIRContext *ctx = patterns.getContext();
- populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
- populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
- populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
- populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
- populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
- populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
- populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
- populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
- populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
- populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
- populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
- populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
- populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
- populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
- populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
- populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
- populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
- populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
- populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
- populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
- populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
- populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
- populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
- populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+ populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
+ populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
+ populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf",
+ "acosh");
+ populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
+ populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf",
+ "asinh");
+ populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f",
+ "atan2");
+ populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
+ populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf",
+ "atanh");
+ populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
+ populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
+ populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
+ populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
+ populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
+ populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
+ populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
+ populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",
+ "expm1");
+ populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf",
+ "floor");
+ populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
+ populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
+ populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
+ populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f",
+ "log10");
+ populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf",
+ "log1p");
+ populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
+ populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
"roundeven");
- populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
- populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
- populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
- populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
- populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
- populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
- populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
- populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
+ populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf",
+ "round");
+ populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
+ populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
+ populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
+ populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf",
+ "rsqrt");
+ populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
+ populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
+ populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf",
+ "trunc");
}
namespace {
More information about the Mlir-commits
mailing list