[Mlir-commits] [mlir] [MLIR][Math] Add optional benefit arg to populate math lowering patterns (PR #127291)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Feb 14 17:23:15 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: William Moses (wsmoses)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/127291.diff


4 Files Affected:

- (modified) mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h (+1) 
- (modified) mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h (+1-1) 
- (modified) mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp (+3-2) 
- (modified) mlir/lib/Conversion/MathToLibm/MathToLibm.cpp (+38-38) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
index 93cd780bba438..b7883fe9a55ff 100644
--- a/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
+++ b/mlir/include/mlir/Conversion/MathToLLVM/MathToLLVM.h
@@ -23,6 +23,7 @@ class Pass;
 
 void populateMathToLLVMConversionPatterns(const LLVMTypeConverter &converter,
                                           RewritePatternSet &patterns,
+                                          PatternBenefit benefit = 1,
                                           bool approximateLog1p = true);
 
 void registerConvertMathToLLVMInterface(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
index ab9a1cef20cab..6db661a7b5748 100644
--- a/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
+++ b/mlir/include/mlir/Conversion/MathToLibm/MathToLibm.h
@@ -19,7 +19,7 @@ class OperationPass;
 
 /// Populate the given list with patterns that convert from Math to Libm calls.
 /// If log1pBenefit is present, use it instead of benefit for the Log1p op.
-void populateMathToLibmConversionPatterns(RewritePatternSet &patterns);
+void populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1);
 
 /// Create a pass to convert Math operations to libm calls.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertMathToLibmPass();
diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
index 98680773e00d2..196fad2d8367b 100644
--- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
+++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp
@@ -304,9 +304,10 @@ struct ConvertMathToLLVMPass
 
 void mlir::populateMathToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    PatternBenefit benefit,
     bool approximateLog1p) {
   if (approximateLog1p)
-    patterns.add<Log1pOpLowering>(converter);
+    patterns.add<Log1pOpLowering>(converter, benefit);
   // clang-format off
   patterns.add<
     AbsFOpLowering,
@@ -337,7 +338,7 @@ void mlir::populateMathToLLVMConversionPatterns(
     FTruncOpLowering,
     TanOpLowering,
     TanhOpLowering
-  >(converter);
+  >(converter, benefit);
   // clang-format on
 }
 
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51..97ec5cf178f5e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -50,9 +50,9 @@ template <typename Op>
 struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 public:
   using OpRewritePattern<Op>::OpRewritePattern;
-  ScalarOpToLibmCall(MLIRContext *context, StringRef floatFunc,
+  ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit, StringRef floatFunc,
                      StringRef doubleFunc)
-      : OpRewritePattern<Op>(context), floatFunc(floatFunc),
+      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
         doubleFunc(doubleFunc){};
 
   LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
@@ -62,10 +62,10 @@ struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
 };
 
 template <typename OpTy>
-void populatePatternsForOp(RewritePatternSet &patterns, MLIRContext *ctx,
+void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit, MLIRContext *ctx,
                            StringRef floatFunc, StringRef doubleFunc) {
-  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx);
-  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, floatFunc, doubleFunc);
+  patterns.add<VecOpToScalarOp<OpTy>, PromoteOpToF32<OpTy>>(ctx, benefit);
+  patterns.add<ScalarOpToLibmCall<OpTy>>(ctx, benefit, floatFunc, doubleFunc);
 }
 
 } // namespace
@@ -159,42 +159,42 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   return success();
 }
 
-void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
+void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, PatternBenefit benefit) {
   MLIRContext *ctx = patterns.getContext();
 
-  populatePatternsForOp<math::AbsFOp>(patterns, ctx, "fabsf", "fabs");
-  populatePatternsForOp<math::AcosOp>(patterns, ctx, "acosf", "acos");
-  populatePatternsForOp<math::AcoshOp>(patterns, ctx, "acoshf", "acosh");
-  populatePatternsForOp<math::AsinOp>(patterns, ctx, "asinf", "asin");
-  populatePatternsForOp<math::AsinhOp>(patterns, ctx, "asinhf", "asinh");
-  populatePatternsForOp<math::Atan2Op>(patterns, ctx, "atan2f", "atan2");
-  populatePatternsForOp<math::AtanOp>(patterns, ctx, "atanf", "atan");
-  populatePatternsForOp<math::AtanhOp>(patterns, ctx, "atanhf", "atanh");
-  populatePatternsForOp<math::CbrtOp>(patterns, ctx, "cbrtf", "cbrt");
-  populatePatternsForOp<math::CeilOp>(patterns, ctx, "ceilf", "ceil");
-  populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
-  populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
-  populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
-  populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
-  populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
-  populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
-  populatePatternsForOp<math::FloorOp>(patterns, ctx, "floorf", "floor");
-  populatePatternsForOp<math::FmaOp>(patterns, ctx, "fmaf", "fma");
-  populatePatternsForOp<math::LogOp>(patterns, ctx, "logf", "log");
-  populatePatternsForOp<math::Log2Op>(patterns, ctx, "log2f", "log2");
-  populatePatternsForOp<math::Log10Op>(patterns, ctx, "log10f", "log10");
-  populatePatternsForOp<math::Log1pOp>(patterns, ctx, "log1pf", "log1p");
-  populatePatternsForOp<math::PowFOp>(patterns, ctx, "powf", "pow");
-  populatePatternsForOp<math::RoundEvenOp>(patterns, ctx, "roundevenf",
+  populatePatternsForOp<math::AbsFOp>(patterns, benefit, ctx, "fabsf", "fabs");
+  populatePatternsForOp<math::AcosOp>(patterns, benefit, ctx, "acosf", "acos");
+  populatePatternsForOp<math::AcoshOp>(patterns, benefit, ctx, "acoshf", "acosh");
+  populatePatternsForOp<math::AsinOp>(patterns, benefit, ctx, "asinf", "asin");
+  populatePatternsForOp<math::AsinhOp>(patterns, benefit, ctx, "asinhf", "asinh");
+  populatePatternsForOp<math::Atan2Op>(patterns, benefit, ctx, "atan2f", "atan2");
+  populatePatternsForOp<math::AtanOp>(patterns, benefit, ctx, "atanf", "atan");
+  populatePatternsForOp<math::AtanhOp>(patterns, benefit, ctx, "atanhf", "atanh");
+  populatePatternsForOp<math::CbrtOp>(patterns, benefit, ctx, "cbrtf", "cbrt");
+  populatePatternsForOp<math::CeilOp>(patterns, benefit, ctx, "ceilf", "ceil");
+  populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
+  populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
+  populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
+  populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
+  populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
+  populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f", "expm1");
+  populatePatternsForOp<math::FloorOp>(patterns, benefit, ctx, "floorf", "floor");
+  populatePatternsForOp<math::FmaOp>(patterns, benefit, ctx, "fmaf", "fma");
+  populatePatternsForOp<math::LogOp>(patterns, benefit, ctx, "logf", "log");
+  populatePatternsForOp<math::Log2Op>(patterns, benefit, ctx, "log2f", "log2");
+  populatePatternsForOp<math::Log10Op>(patterns, benefit, ctx, "log10f", "log10");
+  populatePatternsForOp<math::Log1pOp>(patterns, benefit, ctx, "log1pf", "log1p");
+  populatePatternsForOp<math::PowFOp>(patterns, benefit, ctx, "powf", "pow");
+  populatePatternsForOp<math::RoundEvenOp>(patterns, benefit, ctx, "roundevenf",
                                            "roundeven");
-  populatePatternsForOp<math::RoundOp>(patterns, ctx, "roundf", "round");
-  populatePatternsForOp<math::SinOp>(patterns, ctx, "sinf", "sin");
-  populatePatternsForOp<math::SinhOp>(patterns, ctx, "sinhf", "sinh");
-  populatePatternsForOp<math::SqrtOp>(patterns, ctx, "sqrtf", "sqrt");
-  populatePatternsForOp<math::RsqrtOp>(patterns, ctx, "rsqrtf", "rsqrt");
-  populatePatternsForOp<math::TanOp>(patterns, ctx, "tanf", "tan");
-  populatePatternsForOp<math::TanhOp>(patterns, ctx, "tanhf", "tanh");
-  populatePatternsForOp<math::TruncOp>(patterns, ctx, "truncf", "trunc");
+  populatePatternsForOp<math::RoundOp>(patterns, benefit, ctx, "roundf", "round");
+  populatePatternsForOp<math::SinOp>(patterns, benefit, ctx, "sinf", "sin");
+  populatePatternsForOp<math::SinhOp>(patterns, benefit, ctx, "sinhf", "sinh");
+  populatePatternsForOp<math::SqrtOp>(patterns, benefit, ctx, "sqrtf", "sqrt");
+  populatePatternsForOp<math::RsqrtOp>(patterns, benefit, ctx, "rsqrtf", "rsqrt");
+  populatePatternsForOp<math::TanOp>(patterns, benefit, ctx, "tanf", "tan");
+  populatePatternsForOp<math::TanhOp>(patterns, benefit, ctx, "tanhf", "tanh");
+  populatePatternsForOp<math::TruncOp>(patterns, benefit, ctx, "truncf", "trunc");
 }
 
 namespace {

``````````

</details>


https://github.com/llvm/llvm-project/pull/127291


More information about the Mlir-commits mailing list