[Mlir-commits] [mlir] 8a9d489 - [mlir] Clean-up math -> libm/llvm conversion.

Alexander Belyaev llvmlistbot at llvm.org
Tue Feb 21 10:22:30 PST 2023


Author: Alexander Belyaev
Date: 2023-02-21T19:21:54+01:00
New Revision: 8a9d4895df780231a14a1afc44e18b1f6b7eab93

URL: https://github.com/llvm/llvm-project/commit/8a9d4895df780231a14a1afc44e18b1f6b7eab93
DIFF: https://github.com/llvm/llvm-project/commit/8a9d4895df780231a14a1afc44e18b1f6b7eab93.diff

LOG: [mlir] Clean-up math -> libm/llvm conversion.

At the moment, populateMathToLibmConversionPatterns takes an optional
log1pBenefit parameter, which is used to raise the priority of the
log1p->libm pattern above the log1p->llvm pattern, because the latter
approximates log1p and has precision issues. Instead, we can add a flag to
the MathToLLVM pass that enables or disables this imprecise approximation.

Differential Revision: https://reviews.llvm.org/D144450

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
    mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
    mlir/include/mlir/Conversion/Passes.td
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index d0fc2e390ed79..b2e5db330a64e 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -21,7 +21,8 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 void populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                          RewritePatternSet &patterns);
+                                          RewritePatternSet &patterns,
+                                          bool approximateLog1p = true);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H

diff  --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index cd79e8e491d98..ab9a1cef20cab 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -8,8 +8,7 @@
 #ifndef MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
 #define MLIR_CONVERSION_MATHTOLIBM_MATHTOLIBM_H_
 
-#include "mlir/Transforms/DialectConversion.h"
-#include <optional>
+#include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
 template <typename T>
@@ -20,9 +19,7 @@ class OperationPass;
 
 /// Populate the given list with patterns that convert from Math to Libm calls.
 /// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit,
-    std::optional<PatternBenefit> log1pBenefit = std::nullopt);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
 
 /// Create a pass to convert Math operations to libm calls.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

diff  --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 02299f4170687..33502abc0e2a4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -561,10 +561,11 @@ def ConvertMathToLibm : Pass<"convert-math-to-libm", "ModuleOp"> {
 
 def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
   let summary = "Convert Math dialect to LLVM dialect";
-  let description = [{
-    This pass converts supported Math ops to LLVM dialect intrinsics.
-  }];
   let dependentDialects = ["LLVM::LLVMDialect"];
+  let options = [
+    Option<"approximateLog1p", "approximate-log1p", "bool", "true",
+           "Enable approximation of Log1p.">
+  ];
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 888c51238d063..c331f4f2163bd 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -291,7 +291,7 @@ struct ConvertMathToLLVMPass
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
-    populateMathToLLVMConversionPatterns(converter, patterns);
+    populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
     LLVMConversionTarget target(getContext());
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -301,7 +301,10 @@ struct ConvertMathToLLVMPass
 } // namespace
 
 void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                                RewritePatternSet &patterns) {
+                                                RewritePatternSet &patterns,
+                                                bool approximateLog1p) {
+  if (approximateLog1p)
+    patterns.add<Log1pOpLowering>(converter);
   // clang-format off
   patterns.add<
     AbsFOpLowering,
@@ -319,7 +322,6 @@ void mlir::populateMathToLLVMConversionPatterns(LLVMTypeConverter &converter,
     FloorOpLowering,
     FmaOpLowering,
     Log10OpLowering,
-    Log1pOpLowering,
     Log2OpLowering,
     LogOpLowering,
     PowFOpLowering,

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 93b58e2ab7d1d..35ac2b3c2bdfe 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -14,11 +14,10 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
-#include <optional>
+#include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTMATHTOLIBM
@@ -52,8 +51,8 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 public:
   using OpRewritePattern<Op>::OpRewritePattern;
   ScalarOpToLibmCall<Op>(MLIRContext *context, StringRef floatFunc,
-                         StringRef doubleFunc, PatternBenefit benefit)
-      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+                         StringRef doubleFunc)
+      : OpRewritePattern<Op>(context), floatFunc(floatFunc),
         doubleFunc(doubleFunc){};
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -152,53 +151,37 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   return success();
 }
 
-void mlir::populateMathToLibmConversionPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit,
-    std::optional<PatternBenefit> log1pBenefit) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::CbrtOp>,
                VecOpToScalarOp<math::ExpM1Op>, VecOpToScalarOp<math::TanhOp>,
                VecOpToScalarOp<math::CosOp>, VecOpToScalarOp<math::SinOp>,
                VecOpToScalarOp<math::ErfOp>, VecOpToScalarOp<math::RoundEvenOp>,
                VecOpToScalarOp<math::RoundOp>, VecOpToScalarOp<math::AtanOp>,
                VecOpToScalarOp<math::TanOp>, VecOpToScalarOp<math::TruncOp>>(
-      patterns.getContext(), benefit);
+      ctx);
   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::CbrtOp>,
                PromoteOpToF32<math::ExpM1Op>, PromoteOpToF32<math::TanhOp>,
                PromoteOpToF32<math::CosOp>, PromoteOpToF32<math::SinOp>,
                PromoteOpToF32<math::ErfOp>, PromoteOpToF32<math::RoundEvenOp>,
                PromoteOpToF32<math::RoundOp>, PromoteOpToF32<math::AtanOp>,
-               PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(
-      patterns.getContext(), benefit);
-  patterns.add<ScalarOpToLibmCall<math::AtanOp>>(patterns.getContext(), "atanf",
-                                                 "atan", benefit);
-  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
-                                                  "atan2f", "atan2", benefit);
-  patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(patterns.getContext(), "cbrtf",
-                                                 "cbrt", benefit);
-  patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
-                                                "erf", benefit);
-  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(patterns.getContext(),
-                                                  "expm1f", "expm1", benefit);
-  patterns.add<ScalarOpToLibmCall<math::TanOp>>(patterns.getContext(), "tanf",
-                                                "tan", benefit);
-  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(patterns.getContext(), "tanhf",
-                                                 "tanh", benefit);
-  patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(
-      patterns.getContext(), "roundevenf", "roundeven", benefit);
-  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
-                                                  "roundf", "round", benefit);
-  patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
-                                                "cos", benefit);
-  patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
-                                                "sin", benefit);
-  patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(
-      patterns.getContext(), "log1pf", "log1p", log1pBenefit.value_or(benefit));
-  patterns.add<ScalarOpToLibmCall<math::FloorOp>>(patterns.getContext(),
-                                                  "floorf", "floor", benefit);
-  patterns.add<ScalarOpToLibmCall<math::CeilOp>>(patterns.getContext(), "ceilf",
-                                                 "ceil", benefit);
-  patterns.add<ScalarOpToLibmCall<math::TruncOp>>(patterns.getContext(),
-                                                  "truncf", "trunc", benefit);
+               PromoteOpToF32<math::TanOp>, PromoteOpToF32<math::TruncOp>>(ctx);
+  patterns.add<ScalarOpToLibmCall<math::AtanOp>>(ctx, "atanf", "atan");
+  patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(ctx, "atan2f", "atan2");
+  patterns.add<ScalarOpToLibmCall<math::CbrtOp>>(ctx, "cbrtf", "cbrt");
+  patterns.add<ScalarOpToLibmCall<math::ErfOp>>(ctx, "erff", "erf");
+  patterns.add<ScalarOpToLibmCall<math::ExpM1Op>>(ctx, "expm1f", "expm1");
+  patterns.add<ScalarOpToLibmCall<math::TanOp>>(ctx, "tanf", "tan");
+  patterns.add<ScalarOpToLibmCall<math::TanhOp>>(ctx, "tanhf", "tanh");
+  patterns.add<ScalarOpToLibmCall<math::RoundEvenOp>>(ctx, "roundevenf",
+                                                      "roundeven");
+  patterns.add<ScalarOpToLibmCall<math::RoundOp>>(ctx, "roundf", "round");
+  patterns.add<ScalarOpToLibmCall<math::CosOp>>(ctx, "cosf", "cos");
+  patterns.add<ScalarOpToLibmCall<math::SinOp>>(ctx, "sinf", "sin");
+  patterns.add<ScalarOpToLibmCall<math::Log1pOp>>(ctx, "log1pf", "log1p");
+  patterns.add<ScalarOpToLibmCall<math::FloorOp>>(ctx, "floorf", "floor");
+  patterns.add<ScalarOpToLibmCall<math::CeilOp>>(ctx, "ceilf", "ceil");
+  patterns.add<ScalarOpToLibmCall<math::TruncOp>>(ctx, "truncf", "trunc");
 }
 
 namespace {
@@ -212,7 +195,7 @@ void ConvertMathToLibmPass::runOnOperation() {
   auto module = getOperation();
 
   RewritePatternSet patterns(&getContext());
-  populateMathToLibmConversionPatterns(patterns, /*benefit=*/1);
+  populateMathToLibmConversionPatterns(patterns);
 
   ConversionTarget target(getContext());
   target.addLegalDialect<arith::ArithDialect, BuiltinDialect, func::FuncDialect,


        


More information about the Mlir-commits mailing list