[llvm] [mlir] [mlir][EmitC] Add Arith to EmitC conversions (PR #84151)

Marius Brehler via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 09:13:17 PST 2024


https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/84151

>From 43c348a8f5f7982341abf1aaaab053f6908b727b Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Wed, 6 Mar 2024 11:03:11 +0000
Subject: [PATCH 1/2] [mlir][EmitC] Add Arith to EmitC conversions

This adds patterns and a pass to convert the Arith dialect to EmitC.
For now, this covers arithmetic binary ops operating on floating-point
types.

The patterns do not check whether the types involved, such as tensor
types, are supported by the respective EmitC operations. If unsupported
types are to be converted, the conversion will fail anyway because no
legal EmitC operation can be created. This can clearly be improved in a
follow-up, which would also result in better error messages. Functions
for such checks should not be used solely in the conversions; they
should also be (re)used in the verifier.
---
 .../Conversion/ArithToEmitC/ArithToEmitC.h    | 18 ++++++
 .../ArithToEmitC/ArithToEmitCPass.h           | 21 +++++++
 mlir/include/mlir/Conversion/Passes.h         |  1 +
 mlir/include/mlir/Conversion/Passes.td        |  9 +++
 .../Conversion/ArithToEmitC/ArithToEmitC.cpp  | 59 +++++++++++++++++++
 .../ArithToEmitC/ArithToEmitCPass.cpp         | 47 +++++++++++++++
 .../Conversion/ArithToEmitC/CMakeLists.txt    | 16 +++++
 mlir/lib/Conversion/CMakeLists.txt            |  1 +
 .../ArithToEmitC/arith-to-emitc.mlir          | 14 +++++
 .../llvm-project-overlay/mlir/BUILD.bazel     | 27 +++++++++
 10 files changed, 213 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
 create mode 100644 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
 create mode 100644 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
 create mode 100644 mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
 create mode 100644 mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
 create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
new file mode 100644
index 00000000000000..1f5e5ef0899b37
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -0,0 +1,18 @@
+//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+
+namespace mlir {
+class RewritePatternSet;
+
+void populateArithToEmitCPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
new file mode 100644
index 00000000000000..6b98fed7185ead
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
@@ -0,0 +1,21 @@
+//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f2aa4fb535402d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..0e76069faf44c0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
+  let summary = "Convert Arith dialect to EmitC dialect";
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..15942f54441424
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,59 @@
+//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Replaces `ArithOp` 1:1 with `EmitCOp`, keeping the result type unchanged.
+template <typename ArithOp, typename EmitCOp>
+class ArithOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+  using OpConversionPattern<ArithOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Take the operands from the adaptor, i.e. as remapped by the driver.
+    Type resType = arithOp.getType();
+    auto operands = adaptor.getOperands();
+    rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, resType, operands);
+    return success();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+
+  // clang-format off
+  patterns.add<
+    ArithOpConversion<arith::AddFOp, emitc::AddOp>,
+    ArithOpConversion<arith::DivFOp, emitc::DivOp>,
+    ArithOpConversion<arith::MulFOp, emitc::MulOp>,
+    ArithOpConversion<arith::SubFOp, emitc::SubOp>
+  >(ctx);
+  // clang-format on
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
new file mode 100644
index 00000000000000..e3ac6965ade22a
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -0,0 +1,47 @@
+//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct ConvertArithToEmitC
+    : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+  void runOnOperation() override; // Runs the partial Arith->EmitC conversion.
+};
+} // namespace
+
+void ConvertArithToEmitC::runOnOperation() {
+  ConversionTarget target(getContext());
+  // Any op of the EmitC dialect is an acceptable conversion result.
+  target.addLegalDialect<emitc::EmitCDialect>();
+
+  RewritePatternSet patterns(&getContext());
+  populateArithToEmitCPatterns(patterns);
+
+  if (failed(
+          applyPartialConversion(getOperation(), target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..a3784f47c3bc2d
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithToEmitC
+  ArithToEmitC.cpp
+  ArithToEmitCPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIREmitCDialect
+  MLIRPass
+  MLIRTransformUtils
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
 add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
new file mode 100644
index 00000000000000..6a56474a5c48b2
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s
+
+func.func @arith_ops(%arg0: f32, %arg1: f32) {
+  // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32
+  %0 = arith.addf %arg0, %arg1 : f32
+  // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32
+  %1 = arith.divf %arg0, %arg1 : f32
+  // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32
+  %2 = arith.mulf %arg0, %arg1 : f32
+  // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32
+  %3 = arith.subf %arg0, %arg1 : f32
+
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8a8dd6e10c48aa..2961b1574c49b7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4011,6 +4011,7 @@ cc_library(
         ":AffineToStandard",
         ":ArithToAMDGPU",
         ":ArithToArmSME",
+        ":ArithToEmitC",
         ":ArithToLLVM",
         ":ArithToSPIRV",
         ":ArmNeon2dToIntr",
@@ -8156,6 +8157,32 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ArithToEmitC",
+    srcs = glob([
+        "lib/Conversion/ArithToEmitC/*.cpp",
+        "lib/Conversion/ArithToEmitC/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/ArithToEmitC/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/ArithToEmitC",
+    ],
+    deps = [
+        ":ArithDialect",
+        ":ConversionPassIncGen",
+        ":EmitCDialect",
+        ":IR",
+        ":Pass",
+        ":Support",
+        ":TransformUtils",
+        ":Transforms",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "ArithToLLVM",
     srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),

>From e2fe2a1e6714e818e188c045a0843044cdfa4611 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Wed, 6 Mar 2024 17:05:35 +0000
Subject: [PATCH 2/2] Address review comments

---
 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h | 6 +++++-
 mlir/include/mlir/Conversion/Passes.td                   | 2 +-
 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp        | 5 +++--
 mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp    | 5 ++++-
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
index 1f5e5ef0899b37..53ba32f1c4dbbc 100644
--- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -11,8 +11,12 @@
 
 namespace mlir {
 class RewritePatternSet;
+class TypeConverter;
 
-void populateArithToEmitCPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns converting Arith ops to EmitC ops. The
+/// patterns refer to `typeConverter`, so it must outlive them.
+void populateArithToEmitCPatterns(TypeConverter &typeConverter,
+                                  RewritePatternSet &patterns);
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 0e76069faf44c0..bd81cc6d5323bf 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -137,7 +137,7 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
 // ArithToEmitC
 //===----------------------------------------------------------------------===//
 
-def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
   let summary = "Convert Arith dialect to EmitC dialect";
   let dependentDialects = ["emitc::EmitCDialect"];
 }
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 15942f54441424..289acc9222a3af 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -45,15 +45,16 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
 // Pattern population
 //===----------------------------------------------------------------------===//
 
-void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
+void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns) {
   MLIRContext *ctx = patterns.getContext();
 
   // clang-format off
   patterns.add<
     ArithOpConversion<arith::AddFOp, emitc::AddOp>,
     ArithOpConversion<arith::DivFOp, emitc::DivOp>,
     ArithOpConversion<arith::MulFOp, emitc::MulOp>,
     ArithOpConversion<arith::SubFOp, emitc::SubOp>
-  >(ctx);
+  >(typeConverter, ctx);
   // clang-format on
 }
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index e3ac6965ade22a..3421bd53b241fc 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -39,7 +39,10 @@ void ConvertArithToEmitC::runOnOperation() {
   target.addLegalDialect<emitc::EmitCDialect>();
 
   RewritePatternSet patterns(&getContext());
-  populateArithToEmitCPatterns(patterns);
+  // This conversion keeps all types as they are.
+  TypeConverter typeConverter;
+  typeConverter.addConversion([](Type type) { return type; });
+  populateArithToEmitCPatterns(typeConverter, patterns);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))



More information about the llvm-commits mailing list