[Mlir-commits] [mlir] 8a9d489 - [mlir] Clean-up math -> libm/llvm conversion.
Alexander Belyaev
llvmlistbot at llvm.org
Tue Feb 21 10:22:30 PST 2023
Author: Alexander Belyaev
Date: 2023-02-21T19:21:54+01:00
New Revision: 8a9d4895df780231a14a1afc44e18b1f6b7eab93
URL: https://github.com/llvm/llvm-project/commit/8a9d4895df780231a14a1afc44e18b1f6b7eab93
DIFF: https://github.com/llvm/llvm-project/commit/8a9d4895df780231a14a1afc44e18b1f6b7eab93.diff
LOG: [mlir] Clean-up math -> libm/llvm conversion.
At the moment, there is an optional log1pBenefit parameter in
populateMathToLibmConversionPatterns, which is used to raise the priority of
the log1p->libm pattern over the log1p->llvm pattern, whose approximation of
log1p has precision issues. Instead, the MathToLLVM pass can take a flag that
enables or disables this imprecise approximation.
Differential Revision: https://reviews.llvm.org/D144450
Added:
Modified:
mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index d0fc2e390ed79..b2e5db330a64e 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -21,7 +21,8 @@ class Pass;
#include "mlir/Conversion/Passes.h.inc"
void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns);
+ RewritePatternSet &patterns,
+ bool approximateLog1p = true);
} // namespace mlir
#endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index cd79e8e491d98..ab9a1cef20cab 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -8,8 +8,7 @@
#ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
#define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
+#include "mlir/IR/PatternMatch.h"
namespace mlir {
template <typename T>
@@ -20,9 +19,7 @@ class OperationPass;
/// Populate the given list with patterns that convert from Math to Libm calls.
/// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit,
- std::optional<PatternBenefit> log1pBenefit = std::nullopt);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
/// Create a pass to convert Math operations to libm calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 02299f4170687..33502abc0e2a4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -561,10 +561,11 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
let summary = "Convert Math dialect to LLVM dialect";
- let description = [{
- This pass converts supported Math ops to LLVM dialect intrinsics.
- }];
let dependentDialects = ["LLVM::LLVMDialect"];
+ let options = [
+ Option<"approximateLog1p", "approximate-log1p", "bool", "true",
+ "Enable approximation of Log1p.">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 888c51238d063..c331f4f2163bd 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -291,7 +291,7 @@ struct ConvertMathToLLVMPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
- populateMathToLLVMConversionPatterns(converter, patterns);
+ populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -301,7 +301,10 @@ struct ConvertMathToLLVMPass
} // namespace
void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
- RewritePatternSet &patterns) {
+ RewritePatternSet &patterns,
+ bool approximateLog1p) {
+ if (approximateLog1p)
+ patterns.add<Log1pOpLowering>(converter);
// clang-format off
patterns.add<
AbsFOpLowering,
@@ -319,7 +322,6 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
FloorOpLowering,
FmaOpLowering,
Log10OpLowering,
- Log1pOpLowering,
Log2OpLowering,
LogOpLowering,
PowFOpLowering,
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 93b58e2ab7d1d..35ac2b3c2bdfe 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -14,11 +14,10 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
-#include <optional>
+#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOLIBM
@@ -52,8 +51,8 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
public:
using OpRewritePattern<Op>::OpRewritePattern;
ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
- StringRef doubleFunc, PatternBenefit benefit)
- : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+ StringRef doubleFunc)
+ : OpRewritePattern<Op>(context), floatFunc(floatFunc),
doubleFunc(doubleFunc){};
LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -152,53 +151,37 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
return success();
}
-void mlir::populateMathToLibmConversionPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit,
- std::optional<PatternBenefit> log1pBenefit) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+ MLIRContext *ctx = patterns.getContext();
patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
- patterns.getContext(), benefit);
+ ctx);
patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
- PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(
- patterns.getContext(), benefit);
- patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
- "atan", benefit);
- patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
- "atan2f", "atan2", benefit);
- patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(patterns.getContext(), "cbrtf",
- "cbrt", benefit);
- patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
- "erf", benefit);
- patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
- "expm1f", "expm1", benefit);
- patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf",
- "tan", benefit);
- patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
- "tanh", benefit);
- patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(
- patterns.getContext(), "roundevenf", "roundeven", benefit);
- patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
- "roundf", "round", benefit);
- patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
- "cos", benefit);
- patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
- "sin", benefit);
- patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
- patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
- patterns.add<ScalarOpToLibmCall<math::FloorOp>>(patterns.getContext(),
- "floorf", "floor", benefit);
- patterns.add<ScalarOpToLibmCall<math::CeilOp>>(patterns.getContext(), "ceilf",
- "ceil", benefit);
- patterns.add<ScalarOpToLibmCall<math::TruncOp>>(patterns.getContext(),
- "truncf", "trunc", benefit);
+ PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(ctx);
+ patterns.add<ScalarOpToLibmCall<math::AtanOp>>(ctx, "atanf", "atan");
+ patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(ctx, "atan2f", "atan2");
+ patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(ctx, "cbrtf", "cbrt");
+ patterns.add<ScalarOpToLibmCall<math::ErfOp>>(ctx, "erff", "erf");
+ patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(ctx, "expm1f", "expm1");
+ patterns.add<ScalarOpToLibmCall<math::TanOp>>(ctx, "tanf", "tan");
+ patterns.add<ScalarOpToLibmCall<math::TanhOp>>(ctx, "tanhf", "tanh");
+ patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(ctx, "roundevenf",
+ "roundeven");
+ patterns.add<ScalarOpToLibmCall<math::RoundOp>>(ctx, "roundf", "round");
+ patterns.add<ScalarOpToLibmCall<math::CosOp>>(ctx, "cosf", "cos");
+ patterns.add<ScalarOpToLibmCall<math::SinOp>>(ctx, "sinf", "sin");
+ patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(ctx, "log1pf", "log1p");
+ patterns.add<ScalarOpToLibmCall<math::FloorOp>>(ctx, "floorf", "floor");
+ patterns.add<ScalarOpToLibmCall<math::CeilOp>>(ctx, "ceilf", "ceil");
+ patterns.add<ScalarOpToLibmCall<math::TruncOp>>(ctx, "truncf", "trunc");
}
namespace {
@@ -212,7 +195,7 @@ void ConvertMathToLibmPass::runOnOperation() {
auto module = getOperation();
RewritePatternSet patterns(&getContext());
- populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
+ populateMathToLibmConversionPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,
More information about the Mlir-commits
mailing list