[llvm] [mlir] [mlir][EmitC] Add Arith to EmitC conversions (PR #84151)

via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 03:21:16 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Marius Brehler (marbre)

<details>
<summary>Changes</summary>

This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types.

It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.

---
Full diff: https://github.com/llvm/llvm-project/pull/84151.diff


7 Files Affected:

- (modified) mlir/include/mlir/Conversion/Passes.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+9) 
- (added) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+59) 
- (added) mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp (+47) 
- (added) mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt (+16) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1) 
- (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+27) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f2aa4fb535402d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..0e76069faf44c0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
+  let summary = "Convert Arith dialect to EmitC dialect";
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..15942f54441424
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,59 @@
+//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+template <typename ArithOp, typename EmitCOp>
+class ArithOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+                                                  adaptor.getOperands());
+
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+
+  // clang-format off
+  patterns.add<
+    ArithOpConversion<arith::AddFOp, emitc::AddOp>,
+    ArithOpConversion<arith::DivFOp, emitc::DivOp>,
+    ArithOpConversion<arith::MulFOp, emitc::MulOp>,
+    ArithOpConversion<arith::SubFOp, emitc::SubOp>
+  >(ctx);
+  // clang-format on
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
new file mode 100644
index 00000000000000..f6a531eaaca242
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -0,0 +1,47 @@
+//===- ArithToEmitCPass.cpp - Func to EmitC Pass ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct ConvertArithToEmitC
+    : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertArithToEmitC::runOnOperation() {
+  ConversionTarget target(getContext());
+
+  target.addLegalDialect<emitc::EmitCDialect>();
+
+  RewritePatternSet patterns(&getContext());
+  populateArithToEmitCPatterns(patterns);
+
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..a3784f47c3bc2d
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithToEmitC
+  ArithToEmitC.cpp
+  ArithToEmitCPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIREmitCDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
 add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8a8dd6e10c48aa..2961b1574c49b7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4011,6 +4011,7 @@ cc_library(
         ":AffineToStandard",
         ":ArithToAMDGPU",
         ":ArithToArmSME",
+        ":ArithToEmitC",
         ":ArithToLLVM",
         ":ArithToSPIRV",
         ":ArmNeon2dToIntr",
@@ -8156,6 +8157,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArithToEmitC",
+    srcs = glob([
+        "lib/Conversion/ArithToEmitC/*.cpp",
+        "lib/Conversion/ArithToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/ArithToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/ArithToEmitC",
+    ],
+    deps = [
+        ":ArithDialect",
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "ArithToLLVM",
     srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),

``````````

</details>


https://github.com/llvm/llvm-project/pull/84151


More information about the llvm-commits mailing list