[llvm] [mlir] [mlir][EmitC] Add MathToEmitC pass for math function lowering to EmitC (PR #113799)

Tomer Solomon via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 29 02:54:40 PDT 2024


https://github.com/recursion-man updated https://github.com/llvm/llvm-project/pull/113799

>From d5bd00c0454a2c11db3adc47c1a91a0a318f49af Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomsol2009 at gmail.com>
Date: Tue, 22 Oct 2024 16:05:58 +0300
Subject: [PATCH 1/2] [mlir][EmitC] Add MathToEmitC pass for math function
 lowering to EmitC

This commit introduces a new `MathToEmitC` conversion pass that lowers selected math operations to the `emitc.call_opaque` operation in the EmitC dialect.

The supported math operations include:
- math.floor -> emitc.call_opaque<"floor">
- math.exp -> emitc.call_opaque<"exp">
- math.cos -> emitc.call_opaque<"cos">
- math.sin -> emitc.call_opaque<"sin">
- math.ipowi -> emitc.call_opaque<"pow">

We chose to use `emitc.call_opaque` instead of `emitc.call` to better align with C-style function overloading. Unlike `emitc.call`, which requires unique type signatures, `emitc.call_opaque` allows us to call functions without specifying a unique type-based signature. This flexibility is essential for mimicking function overloading behavior as seen in `<math.h>`.

Additionally, the pass inserts an `emitc.include` operation to generate `#include <math.h>` at the top of the module to ensure the availability of the necessary math functions in the generated code.

This pass enables the use of EmitC as an intermediate layer to generate C/C++ code with opaque calls to standard math functions.
---
 .../mlir/Conversion/MathToEmitC/MathToEmitC.h |  25 ++++
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  19 +++
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../lib/Conversion/MathToEmitC/CMakeLists.txt |  19 +++
 .../Conversion/MathToEmitC/MathToEmitC.cpp    |  99 +++++++++++++
 .../Conversion/MathToEmitC/math-to-emitc.mlir | 140 ++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  22 +++
 8 files changed, 326 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
 create mode 100644 mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
 create mode 100644 mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir

diff --git a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
new file mode 100644
index 00000000000000..f2e8779b057937
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
@@ -0,0 +1,25 @@
+//===- MathToEmitC.h - Math to EmitC Pass -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
+#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+
+#define GEN_PASS_DECL_CONVERTMATHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertMathToEmitCPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 2ab32836c80b1c..54e795cd137d31 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -43,6 +43,7 @@
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 4d272ba219c6f1..09a93439ab898f 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -780,6 +780,25 @@ def ConvertMathToSPIRV : Pass<"convert-math-to-spirv"> {
   let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToEmitC  : Pass<"convert-math-to-emitc", "ModuleOp"> {
+  let summary = "Convert some Math operations to EmitC Call_opaque";
+  let description = [{
+    This pass converts supported Math ops to `emitc.call_opaque` calls to the
+    corresponding `<math.h>` functions (e.g. `floor`, `exp`, `sin`, `pow`).
+    Unlike convert-math-to-funcs pass, this pass uses call_opaque,
+    therefore enables us to overload the same funtion with different argument types
+  }];
+
+  let constructor = "mlir::createConvertMathToEmitCPass()";
+  let dependentDialects = ["emitc::EmitCDialect",
+                           "math::MathDialect"
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToFuncs
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 6651d87162257f..120b4972454d57 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -33,6 +33,7 @@ add_subdirectory(IndexToLLVM)
 add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
+add_subdirectory(MathToEmitC)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
diff --git a/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..7b02a57dff3d4d
--- /dev/null
+++ b/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRMathToEmitC
+  MathToEmitC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIREmitCDialect
+  MLIRIR
+  MLIRMathDialect
+  MLIRPass
+  MLIRTransformUtils
+)
diff --git a/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
new file mode 100644
index 00000000000000..43641a8ad634a7
--- /dev/null
+++ b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
@@ -0,0 +1,99 @@
+
+//===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//  Replaces Math operations with `emitc.call_opaque` operations.
+struct ConvertMathToEmitCPass
+    : public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
+public:
+  void runOnOperation() final;
+};
+
+} // end anonymous namespace
+
+template <typename OpType>
+class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
+  std::string calleeStr;
+
+public:
+  LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr)
+      : OpRewritePattern<OpType>(context), calleeStr(calleeStr) {}
+
+  LogicalResult matchAndRewrite(OpType op,
+                                PatternRewriter &rewriter) const override;
+};
+
+// Populates patterns to replace `math` operations with `emitc.call_opaque`,
+// using function names consistent with those in <math.h>.
+static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
+  auto *context = patterns.getContext();
+  patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor");
+  patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint");
+  patterns.insert<LowerToEmitCCallOpaque<math::ExpOp>>(context, "exp");
+  patterns.insert<LowerToEmitCCallOpaque<math::CosOp>>(context, "cos");
+  patterns.insert<LowerToEmitCCallOpaque<math::SinOp>>(context, "sin");
+  patterns.insert<LowerToEmitCCallOpaque<math::AcosOp>>(context, "acos");
+  patterns.insert<LowerToEmitCCallOpaque<math::AsinOp>>(context, "asin");
+  patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2");
+  patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil");
+  patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs");
+  patterns.insert<LowerToEmitCCallOpaque<math::FPowIOp>>(context, "powf");
+  patterns.insert<LowerToEmitCCallOpaque<math::IPowIOp>>(context, "pow");
+}
+
+template <typename OpType>
+LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
+    OpType op, PatternRewriter &rewriter) const {
+  mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
+  auto actualOp = mlir::cast<OpType>(op);
+  rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
+      actualOp, actualOp.getType(), callee, actualOp->getOperands());
+  return mlir::success();
+}
+
+void ConvertMathToEmitCPass::runOnOperation() {
+  auto moduleOp = getOperation();
+  // Insert #include <math.h> at the beginning of the module
+  OpBuilder builder(moduleOp.getBodyRegion());
+  builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
+  builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
+                                   builder.getStringAttr("math.h"));
+
+  ConversionTarget target(getContext());
+  target.addLegalOp<emitc::CallOpaqueOp>();
+
+  target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
+                      math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
+                      math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
+                      math::FPowIOp, math::IPowIOp>();
+
+  RewritePatternSet patterns(&getContext());
+  populateConvertMathToEmitCPatterns(patterns);
+
+  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<mlir::ModuleOp>>
+mlir::createConvertMathToEmitCPass() {
+  return std::make_unique<ConvertMathToEmitCPass>();
+}
diff --git a/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
new file mode 100644
index 00000000000000..9add25d71ef478
--- /dev/null
+++ b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt --split-input-file -convert-math-to-emitc %s | FileCheck %s
+
+// CHECK-LABEL:   emitc.include "math.h"
+
+// CHECK-LABEL:   func.func @absf_to_call_opaque(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "fabs"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @absf_to_call_opaque(%arg0: f32) {
+    %1 = math.absf %arg0 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @floor_to_call_opaque(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "floor"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @floor_to_call_opaque(%arg0: f32) {
+    %1 = math.floor %arg0 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @sin_to_call_opaque(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "sin"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @sin_to_call_opaque(%arg0: f32) {
+    %1 = math.sin %arg0 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @cos_to_call_opaque(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "cos"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @cos_to_call_opaque(%arg0: f32) {
+    %1 = math.cos %arg0 : f32
+    return
+  }
+
+
+// -----
+
+// CHECK-LABEL:   func.func @asin_to_call_opaque(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "asin"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @asin_to_call_opaque(%arg0: f32) {
+    %1 = math.asin %arg0 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @acos_to_call_opaque(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "acos"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @acos_to_call_opaque(%arg0: f32) {
+    %1 = math.acos %arg0 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @atan2_to_call_opaque(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: f32,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: f32) {
+// CHECK:           %[[VAL_2:.*]] = emitc.call_opaque "atan2"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @atan2_to_call_opaque(%arg0: f32, %arg1: f32) {
+    %1 = math.atan2 %arg0, %arg1 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @ceil_to_call_opaque(
+// CHECK-SAME:                                   %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "ceil"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @ceil_to_call_opaque(%arg0: f32) {
+    %1 = math.ceil %arg0 : f32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @exp_to_call_opaque(
+// CHECK-SAME:                                  %[[VAL_0:.*]]: f32) {
+// CHECK:           %[[VAL_1:.*]] = emitc.call_opaque "exp"(%[[VAL_0]]) : (f32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @exp_to_call_opaque(%arg0: f32) {
+    %1 = math.exp %arg0 : f32
+    return
+  }
+
+
+// -----
+
+// CHECK-LABEL:   func.func @fpowi_to_call_opaque(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: f32,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: i32) {
+// CHECK:           %[[VAL_2:.*]] = emitc.call_opaque "powf"(%[[VAL_0]], %[[VAL_1]]) : (f32, i32) -> f32
+// CHECK:           return
+// CHECK:         }
+func.func @fpowi_to_call_opaque(%arg0: f32, %arg1: i32) {
+    %1 = math.fpowi %arg0, %arg1 : f32, i32
+    return
+  }
+
+// -----
+
+// CHECK-LABEL:   func.func @ipowi_to_call_opaque(
+// CHECK-SAME:                                    %[[VAL_0:.*]]: i32,
+// CHECK-SAME:                                    %[[VAL_1:.*]]: i32) {
+// CHECK:           %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32
+// CHECK:           return
+// CHECK:         }
+func.func @ipowi_to_call_opaque(%arg0: i32, %arg1: i32) {
+    %1 = math.ipowi %arg0, %arg1 : i32
+    return
+  }
+
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 779609340d7224..b193e4295e4759 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4201,6 +4201,7 @@ cc_library(
         ":IndexToLLVM",
         ":IndexToSPIRV",
         ":LinalgToStandard",
+        ":MathToEmitC",
         ":MathToFuncs",
         ":MathToLLVM",
         ":MathToLibm",
@@ -8721,6 +8722,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "MathToEmitC",
+    srcs = glob([
+        "lib/Conversion/MathToEmitC/*.cpp",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/MathToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":IR",
+        ":MathDialect",
+        ":Pass",
+        ":TransformUtils",
+    ],
+)
+
 cc_library(
     name = "MathToFuncs",
     srcs = glob(["lib/Conversion/MathToFuncs/*.cpp"]),

>From ad9af428683e8dba4d17a31b4193c709b15726c6 Mon Sep 17 00:00:00 2001
From: Tomer Solomon <tomsol2009 at gmail.com>
Date: Tue, 29 Oct 2024 11:47:39 +0200
Subject: [PATCH 2/2] [MLIR][MathToEmitC] Ensure scalar type handling and
 refactor

This patch ensures that the MathToEmitC pass only converts scalar `FloatType`s, avoiding invalid conversions of non-scalar types like tensors.

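- **Validation:** Only convert ops whose operands and results are all scalar floats (`FloatType`).
- **Refactoring:** Moved the pass into `MathToEmitCPass.cpp`/`MathToEmitCPass.h` and exposed the rewrite patterns via `populateConvertMathToEmitCPatterns` in `MathToEmitC.h`.
- **Pow lowering:** Replaced the `math.fpowi`/`math.ipowi` patterns, which take integer operands, with a `math.powf` -> `pow` pattern.
- **Testing:** Added a test case checking the diagnostics emitted for non-scalar (tensor) operands.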
---
 .../mlir/Conversion/MathToEmitC/MathToEmitC.h | 10 +--
 .../Conversion/MathToEmitC/MathToEmitCPass.h  | 21 +++++
 mlir/include/mlir/Conversion/Passes.h         |  2 +-
 mlir/include/mlir/Conversion/Passes.td        |  2 -
 .../lib/Conversion/MathToEmitC/CMakeLists.txt |  1 +
 .../Conversion/MathToEmitC/MathToEmitC.cpp    | 84 ++++++-------------
 .../MathToEmitC/MathToEmitCPass.cpp           | 58 +++++++++++++
 .../Conversion/MathToEmitC/math-to-emitc.mlir | 31 +++----
 8 files changed, 120 insertions(+), 89 deletions(-)
 create mode 100644 mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h
 create mode 100644 mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp

diff --git a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
index f2e8779b057937..e2a8d59ffcd6b3 100644
--- a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
+++ b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitC.h
@@ -9,16 +9,12 @@
 #ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
 #define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITC_H
 
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/Pass/Pass.h"
-#include <memory>
-
 namespace mlir {
+class RewritePatternSet;
 
-#define GEN_PASS_DECL_CONVERTMATHTOEMITC
-#include "mlir/Conversion/Passes.h.inc"
-
-std::unique_ptr<OperationPass<mlir::ModuleOp>> createConvertMathToEmitCPass();
+/// Populates `patterns` with rewrites that lower supported math ops on scalar
+/// floats to `emitc.call_opaque` calls of the matching <math.h> functions.
+void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns);
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h
new file mode 100644
index 00000000000000..5e92fba71b5107
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToEmitC/MathToEmitCPass.h
@@ -0,0 +1,21 @@
+//===- MathToEmitCPass.h - Math to EmitC Pass -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
+#define MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOEMITC_MATHTOEMITCPASS_H
\ No newline at end of file
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 54e795cd137d31..6749cee0edccc9 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -43,7 +43,7 @@
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
 #include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
-#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
+#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 09a93439ab898f..20baaf40ead2f6 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -792,8 +792,6 @@ def ConvertMathToEmitC  : Pass<"convert-math-to-emitc", "ModuleOp"> {
     Unlike convert-math-to-funcs pass, this pass uses call_opaque,
     therefore enables us to overload the same funtion with different argument types
   }];
-
-  let constructor = "mlir::createConvertMathToEmitCPass()";
   let dependentDialects = ["emitc::EmitCDialect",
                            "math::MathDialect"
   ];
diff --git a/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt b/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
index 7b02a57dff3d4d..8996869c0e7a54 100644
--- a/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
+++ b/mlir/lib/Conversion/MathToEmitC/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_conversion_library(MLIRMathToEmitC
   MathToEmitC.cpp
+  MathToEmitCPass.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToEmitC
diff --git a/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
index 43641a8ad634a7..c5422c09c6c221 100644
--- a/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
+++ b/mlir/lib/Conversion/MathToEmitC/MathToEmitC.cpp
@@ -1,4 +1,3 @@
-
 //===- MathToEmitC.cpp - Math to EmitC Pass Implementation ----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -8,43 +7,49 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
+
 #include "mlir/Dialect/EmitC/IR/EmitC.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-namespace mlir {
-#define GEN_PASS_DEF_CONVERTMATHTOEMITC
-#include "mlir/Conversion/Passes.h.inc"
-} // namespace mlir
-
 using namespace mlir;
-namespace {
-
-//  Replaces Math operations with `emitc.call_opaque` operations.
-struct ConvertMathToEmitCPass
-    : public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
-public:
-  void runOnOperation() final;
-};
-
-} // end anonymous namespace
 
+namespace {
 template <typename OpType>
 class LowerToEmitCCallOpaque : public mlir::OpRewritePattern<OpType> {
   std::string calleeStr;
 
 public:
   LowerToEmitCCallOpaque(MLIRContext *context, std::string calleeStr)
-      : OpRewritePattern<OpType>(context), calleeStr(calleeStr) {}
+      : OpRewritePattern<OpType>(context), calleeStr(std::move(calleeStr)) {}
 
   LogicalResult matchAndRewrite(OpType op,
                                 PatternRewriter &rewriter) const override;
 };
 
+template <typename OpType>
+LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
+    OpType op, PatternRewriter &rewriter) const {
+  // <math.h> functions only take scalar floating-point values, so reject
+  // tensors, vectors and integers on either the operand or the result side.
+  auto isScalarFloat = [](Type type) { return isa<FloatType>(type); };
+  if (!llvm::all_of(op->getOperandTypes(), isScalarFloat) ||
+      !llvm::all_of(op->getResultTypes(), isScalarFloat)) {
+    op.emitError("non-float types are not supported");
+    return mlir::failure();
+  }
+  // `op` is already an OpType; replace it with an opaque call to calleeStr.
+  rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
+      op, op.getType(), rewriter.getStringAttr(calleeStr),
+      op->getOperands());
+  return mlir::success();
+}
+
+} // namespace
+
 // Populates patterns to replace `math` operations with `emitc.call_opaque`,
 // using function names consistent with those in <math.h>.
-static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
+void mlir::populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.insert<LowerToEmitCCallOpaque<math::FloorOp>>(context, "floor");
   patterns.insert<LowerToEmitCCallOpaque<math::RoundEvenOp>>(context, "rint");
@@ -56,44 +61,5 @@ static void populateConvertMathToEmitCPatterns(RewritePatternSet &patterns) {
   patterns.insert<LowerToEmitCCallOpaque<math::Atan2Op>>(context, "atan2");
   patterns.insert<LowerToEmitCCallOpaque<math::CeilOp>>(context, "ceil");
   patterns.insert<LowerToEmitCCallOpaque<math::AbsFOp>>(context, "fabs");
-  patterns.insert<LowerToEmitCCallOpaque<math::FPowIOp>>(context, "powf");
-  patterns.insert<LowerToEmitCCallOpaque<math::IPowIOp>>(context, "pow");
-}
-
-template <typename OpType>
-LogicalResult LowerToEmitCCallOpaque<OpType>::matchAndRewrite(
-    OpType op, PatternRewriter &rewriter) const {
-  mlir::StringAttr callee = rewriter.getStringAttr(calleeStr);
-  auto actualOp = mlir::cast<OpType>(op);
-  rewriter.replaceOpWithNewOp<mlir::emitc::CallOpaqueOp>(
-      actualOp, actualOp.getType(), callee, actualOp->getOperands());
-  return mlir::success();
-}
-
-void ConvertMathToEmitCPass::runOnOperation() {
-  auto moduleOp = getOperation();
-  // Insert #include <math.h> at the beginning of the module
-  OpBuilder builder(moduleOp.getBodyRegion());
-  builder.setInsertionPointToStart(&moduleOp.getBodyRegion().front());
-  builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
-                                   builder.getStringAttr("math.h"));
-
-  ConversionTarget target(getContext());
-  target.addLegalOp<emitc::CallOpaqueOp>();
-
-  target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
-                      math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
-                      math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp,
-                      math::FPowIOp, math::IPowIOp>();
-
-  RewritePatternSet patterns(&getContext());
-  populateConvertMathToEmitCPatterns(patterns);
-
-  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
-    signalPassFailure();
-}
-
-std::unique_ptr<OperationPass<mlir::ModuleOp>>
-mlir::createConvertMathToEmitCPass() {
-  return std::make_unique<ConvertMathToEmitCPass>();
+  patterns.insert<LowerToEmitCCallOpaque<math::PowFOp>>(context, "pow");
 }
diff --git a/mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp b/mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp
new file mode 100644
index 00000000000000..6e0ea81b34a924
--- /dev/null
+++ b/mlir/lib/Conversion/MathToEmitC/MathToEmitCPass.cpp
@@ -0,0 +1,58 @@
+//===- MathToEmitCPass.cpp - Math to EmitC Pass -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Math dialect to the EmitC dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToEmitC/MathToEmitCPass.h"
+#include "mlir/Conversion/MathToEmitC/MathToEmitC.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+/// Module pass that inserts `#include <math.h>` and lowers the math ops
+/// covered by populateConvertMathToEmitCPatterns to `emitc.call_opaque`.
+struct ConvertMathToEmitCPass
+    : public impl::ConvertMathToEmitCBase<ConvertMathToEmitCPass> {
+  void runOnOperation() final;
+};
+
+} // namespace
+
+void ConvertMathToEmitCPass::runOnOperation() {
+  ModuleOp moduleOp = getOperation();
+  // Emit `#include <math.h>` first in the module body for the opaque calls.
+  OpBuilder builder = OpBuilder::atBlockBegin(moduleOp.getBody());
+  builder.create<emitc::IncludeOp>(moduleOp.getLoc(),
+                                   builder.getStringAttr("math.h"));
+
+  ConversionTarget target(getContext());
+  target.addLegalOp<emitc::CallOpaqueOp>();
+  // Mark illegal exactly the ops populateConvertMathToEmitCPatterns handles.
+  // math.fpowi / math.ipowi have no pattern; declaring them illegal would
+  // make the pass fail on any module that contains them.
+  target.addIllegalOp<math::FloorOp, math::ExpOp, math::RoundEvenOp,
+                      math::CosOp, math::SinOp, math::Atan2Op, math::CeilOp,
+                      math::AcosOp, math::AsinOp, math::AbsFOp, math::PowFOp>();
+
+  RewritePatternSet patterns(&getContext());
+  populateConvertMathToEmitCPatterns(patterns);
+
+  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+    signalPassFailure();
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
index 9add25d71ef478..6cf8b53e73839a 100644
--- a/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
+++ b/mlir/test/Conversion/MathToEmitC/math-to-emitc.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --split-input-file -convert-math-to-emitc %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -convert-math-to-emitc -verify-diagnostics %s | FileCheck %s
 
 // CHECK-LABEL:   emitc.include "math.h"
 
@@ -110,31 +110,24 @@ func.func @exp_to_call_opaque(%arg0: f32) {
     return
   }
 
-
 // -----
 
-// CHECK-LABEL:   func.func @fpowi_to_call_opaque(
+// CHECK-LABEL:   func.func @powf_to_call_opaque(
 // CHECK-SAME:                                    %[[VAL_0:.*]]: f32,
-// CHECK-SAME:                                    %[[VAL_1:.*]]: i32) {
-// CHECK:           %[[VAL_2:.*]] = emitc.call_opaque "powf"(%[[VAL_0]], %[[VAL_1]]) : (f32, i32) -> f32
+// CHECK-SAME:                                    %[[VAL_1:.*]]: f32) {
+// CHECK:           %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (f32, f32) -> f32
 // CHECK:           return
 // CHECK:         }
-func.func @fpowi_to_call_opaque(%arg0: f32, %arg1: i32) {
-    %1 = math.fpowi %arg0, %arg1 : f32, i32
+func.func @powf_to_call_opaque(%arg0: f32, %arg1: f32) {
+    %1 = math.powf %arg0, %arg1 : f32
     return
   }
 
 // -----
 
-// CHECK-LABEL:   func.func @ipowi_to_call_opaque(
-// CHECK-SAME:                                    %[[VAL_0:.*]]: i32,
-// CHECK-SAME:                                    %[[VAL_1:.*]]: i32) {
-// CHECK:           %[[VAL_2:.*]] = emitc.call_opaque "pow"(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32
-// CHECK:           return
-// CHECK:         }
-func.func @ipowi_to_call_opaque(%arg0: i32, %arg1: i32) {
-    %1 = math.ipowi %arg0, %arg1 : i32
-    return
-  }
-
-
+func.func @test(%arg0 : tensor<4xf32>) -> tensor<4xf32> {
+// expected-error @+2 {{failed to legalize operation 'math.absf' that was explicitly marked illegal}}
+// expected-error @+1 {{non-float types are not supported}}
+  %0 = math.absf %arg0 : tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
\ No newline at end of file



More information about the llvm-commits mailing list