[Mlir-commits] [mlir] 5c93eb5 - [MLIR][Math] Add optional benefit arg to populate math lowering patterns (#127291)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 14 20:38:14 PST 2025


Author: William Moses
Date: 2025-02-14T22:38:11-06:00
New Revision: 5c93eb56dc9bc0c0210483cdd5d31e6b6580454f

URL: https://github.com/llvm/llvm-project/commit/5c93eb56dc9bc0c0210483cdd5d31e6b6580454f
DIFF: https://github.com/llvm/llvm-project/commit/5c93eb56dc9bc0c0210483cdd5d31e6b6580454f.diff

LOG: [MLIR][Math] Add optional benefit arg to populate math lowering patterns (#127291)

Co-authored-by: Ivan R. Ivanov <ivanov.i.aa at m.titech.ac.jp>

Added: 
    

Modified: 
    mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
    mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
    mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
    mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index 93cd780bba438..0c1203e1e3c0e 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
 #define MLIR_CONVERSION_MATHTOLLVM_MATHTOLLVM_H
 
+#include "mlir/IR/PatternMatch.h"
 #include <memory>
 
 namespace mlir {
@@ -23,7 +24,8 @@ class Pass;
 
 void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
                                           RewritePatternSet &patterns,
-                                          bool approximateLog1p = true);
+                                          bool approximateLog1p = true,
+                                          PatternBenefit benefit = 1);
 
 void registerConvertMathToLLVMInterface(DialectRegistry &registry);
 

diff  --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index ab9a1cef20cab..8ace53a0fd582 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -19,7 +19,8 @@ class OperationPass;
 
 /// Populate the given list with patterns that convert from Math to Libm calls.
 /// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1);
 
 /// Create a pass to convert Math operations to libm calls.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();

diff  --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 98680773e00d2..85ec288268aeb 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,9 +304,9 @@ struct ConvertMathToLLVMPass
 
 void mlir::populateMathToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
-    bool approximateLog1p) {
+    bool approximateLog1p, PatternBenefit benefit) {
   if (approximateLog1p)
-    patterns.add<Log1pOpLowering>(converter);
+    patterns.add<Log1pOpLowering>(converter, benefit);
   // clang-format off
   patterns.add<
     AbsFOpLowering,
@@ -337,7 +337,7 @@ void mlir::populateMathToLLVMConversionPatterns(
     FTruncOpLowering,
     TanOpLowering,
     TanhOpLowering
-  >(converter);
+  >(converter, benefit);
   // clang-format on
 }
 

diff  --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51..12a6d9c3452df 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -50,10 +50,10 @@ template <typename Op>
 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 public:
   using OpRewritePattern<Op>::OpRewritePattern;
-  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
-                     StringRef doubleFunc)
-      : OpRewritePattern<Op>(context), floatFunc(floatFunc),
-        doubleFunc(doubleFunc){};
+  ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
+                     StringRef floatFunc, StringRef doubleFunc)
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+        doubleFunc(doubleFunc) {};
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
 
@@ -62,10 +62,11 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 };
 
 template <typename OpTy>
-void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
-                           StringRef floatFunc, StringRef doubleFunc) {
-  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
-  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
+                           MLIRContext *ctx, StringRef floatFunc,
+                           StringRef doubleFunc) {
+  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
+  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
 }
 
 } // namespace
@@ -159,42 +160,54 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   return success();
 }
 
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
+                                                PatternBenefit benefit) {
   MLIRContext *ctx = patterns.getContext();
 
-  populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
-  populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
-  populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
-  populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
-  populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
-  populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
-  populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
-  populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
-  populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
-  populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
-  populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
-  populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
-  populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
-  populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
-  populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
-  populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
-  populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
-  populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
-  populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
-  populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
-  populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
-  populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
-  populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
-  populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+  populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
+  populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
+  populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf",
+                                       "acosh");
+  populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
+  populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf",
+                                       "asinh");
+  populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f",
+                                       "atan2");
+  populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
+  populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf",
+                                       "atanh");
+  populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
+  populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
+  populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
+  populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
+  populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
+  populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
+  populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
+  populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",
+                                       "expm1");
+  populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf",
+                                       "floor");
+  populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
+  populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
+  populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
+  populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f",
+                                       "log10");
+  populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf",
+                                       "log1p");
+  populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
+  populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
                                            "roundeven");
-  populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
-  populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
-  populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
-  populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
-  populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
-  populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
-  populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
-  populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
+  populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf",
+                                       "round");
+  populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
+  populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
+  populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
+  populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf",
+                                       "rsqrt");
+  populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
+  populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
+  populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf",
+                                       "trunc");
 }
 
 namespace {


        


More information about the Mlir-commits mailing list