[Mlir-commits] [mlir] [MLIR][Math] Fix mathtolibm to use conversion patterns (PR #154083)

William Moses llvmlistbot at llvm.org
Mon Aug 18 03:19:34 PDT 2025


https://github.com/wsmoses updated https://github.com/llvm/llvm-project/pull/154083

>From b5a6e88845d9eb53c5bafa8e638ff2cf0fa90981 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh at wsmoses.com>
Date: Mon, 18 Aug 2025 04:21:39 -0500
Subject: [PATCH 1/2] [MLIR][Math] Fix mathtolibm to use conversion patterns

---
 mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index f7c0d4fe3a799..3e9ce6f7d1476 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -29,32 +29,35 @@ namespace {
 // Pattern to convert vector operations to scalar operations. This is needed as
 // libm calls require scalars.
 template <typename Op>
-struct VecOpToScalarOp : public OpRewritePattern<Op> {
+struct VecOpToScalarOp : public OpConversionPattern<Op> {
 public:
-  using OpRewritePattern<Op>::OpRewritePattern;
+  using OpConversionPattern<Op>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+  LogicalResult
+  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
 };
 // Pattern to promote an op of a smaller floating point type to F32.
 template <typename Op>
-struct PromoteOpToF32 : public OpRewritePattern<Op> {
+struct PromoteOpToF32 : public OpConversionPattern<Op> {
 public:
-  using OpRewritePattern<Op>::OpRewritePattern;
+  using OpConversionPattern<Op>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+  LogicalResult
+  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
 };
 // Pattern to convert scalar math operations to calls to libm functions.
 // Additionally the libm function signatures are declared.
 template <typename Op>
-struct ScalarOpToLibmCall : public OpRewritePattern<Op> {
+struct ScalarOpToLibmCall : public OpConversionPattern<Op> {
 public:
   using OpRewritePattern<Op>::OpRewritePattern;
   ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
                      StringRef floatFunc, StringRef doubleFunc)
-      : OpRewritePattern<Op>(context, benefit), floatFunc(floatFunc),
+      : OpConversionPattern<Op>(context, benefit), floatFunc(floatFunc),
         doubleFunc(doubleFunc) {};
 
-  LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const final;
+  LogicalResult
+  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
 
 private:
   std::string floatFunc, doubleFunc;

>From c75561b2d5adbb4e77811c1b1ab2dbaa92904459 Mon Sep 17 00:00:00 2001
From: "William S. Moses" <gh at wsmoses.com>
Date: Mon, 18 Aug 2025 04:58:55 -0500
Subject: [PATCH 2/2] Apply clang-format and fix leftover OpRewritePattern using-declaration

---
 mlir/lib/Conversion/MathToLibm/MathToLibm.cpp | 35 +++++++++++--------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index 3e9ce6f7d1476..3cbe0aa5fe17e 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -34,7 +34,8 @@ struct VecOpToScalarOp : public OpConversionPattern<Op> {
   using OpConversionPattern<Op>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
+  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
 };
 // Pattern to promote an op of a smaller floating point type to F32.
 template <typename Op>
@@ -43,21 +44,23 @@ struct PromoteOpToF32 : public OpConversionPattern<Op> {
   using OpConversionPattern<Op>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
+  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
 };
 // Pattern to convert scalar math operations to calls to libm functions.
 // Additionally the libm function signatures are declared.
 template <typename Op>
 struct ScalarOpToLibmCall : public OpConversionPattern<Op> {
 public:
-  using OpRewritePattern<Op>::OpRewritePattern;
+  using OpConversionPattern<Op>::OpConversionPattern;
   ScalarOpToLibmCall(MLIRContext *context, PatternBenefit benefit,
                      StringRef floatFunc, StringRef doubleFunc)
       : OpConversionPattern<Op>(context, benefit), floatFunc(floatFunc),
-        doubleFunc(doubleFunc) {};
+        doubleFunc(doubleFunc){};
 
   LogicalResult
-  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final;
+  matchAndRewrite(Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final;
 
 private:
   std::string floatFunc, doubleFunc;
@@ -74,8 +77,9 @@ void populatePatternsForOp(RewritePatternSet &patterns, PatternBenefit benefit,
 } // namespace
 
 template <typename Op>
-LogicalResult
-VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+LogicalResult VecOpToScalarOp<Op>::matchAndRewrite(
+    Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto opType = op.getType();
   auto loc = op.getLoc();
   auto vecType = dyn_cast<VectorType>(opType);
@@ -95,7 +99,7 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   for (auto linearIndex = 0; linearIndex < numElements; ++linearIndex) {
     SmallVector<int64_t> positions = delinearize(linearIndex, strides);
     SmallVector<Value> operands;
-    for (auto input : op->getOperands())
+    for (auto input : adaptor.getOperands())
       operands.push_back(
           vector::ExtractOp::create(rewriter, loc, input, positions));
     Value scalarOp =
@@ -108,8 +112,9 @@ VecOpToScalarOp<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
 }
 
 template <typename Op>
-LogicalResult
-PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
+LogicalResult PromoteOpToF32<Op>::matchAndRewrite(
+    Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto opType = op.getType();
   if (!isa<Float16Type, BFloat16Type>(opType))
     return failure();
@@ -117,7 +122,7 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
   auto loc = op.getLoc();
   auto f32 = rewriter.getF32Type();
   auto extendedOperands = llvm::to_vector(
-      llvm::map_range(op->getOperands(), [&](Value operand) -> Value {
+      llvm::map_range(adaptor.getOperands(), [&](Value operand) -> Value {
         return arith::ExtFOp::create(rewriter, loc, f32, operand);
       }));
   auto newOp = Op::create(rewriter, loc, f32, extendedOperands);
@@ -126,9 +131,9 @@ PromoteOpToF32<Op>::matchAndRewrite(Op op, PatternRewriter &rewriter) const {
 }
 
 template <typename Op>
-LogicalResult
-ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
-                                        PatternRewriter &rewriter) const {
+LogicalResult ScalarOpToLibmCall<Op>::matchAndRewrite(
+    Op op, typename OpConversionPattern<Op>::OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
   auto module = SymbolTable::getNearestSymbolTable(op);
   auto type = op.getType();
   if (!isa<Float32Type, Float64Type>(type))
@@ -158,7 +163,7 @@ ScalarOpToLibmCall<Op>::matchAndRewrite(Op op,
   assert(isa<FunctionOpInterface>(SymbolTable::lookupSymbolIn(module, name)));
 
   rewriter.replaceOpWithNewOp<func::CallOp>(op, name, op.getType(),
-                                            op->getOperands());
+                                            adaptor.getOperands());
 
   return success();
 }



More information about the Mlir-commits mailing list