[llvm] [mlir] [mlir][EmitC] Add Arith to EmitC conversions (PR #84151)
Marius Brehler via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 7 00:43:04 PST 2024
https://github.com/marbre updated https://github.com/llvm/llvm-project/pull/84151
>From 43c348a8f5f7982341abf1aaaab053f6908b727b Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Wed, 6 Mar 2024 11:03:11 +0000
Subject: [PATCH 1/3] [mlir][EmitC] Add Arith to EmitC conversions
This adds patterns and a pass to convert the Arith dialect to EmitC.
For now, this covers arithmetic binary ops operating on floating-point
types.
The patterns do not check whether the types involved, such as tensor
types, are supported by the respective EmitC operations. If unsupported
types are encountered, the conversion fails anyway because no legal EmitC
operation can be created. This can be improved in a follow-up, which would
also yield better error messages. Functions implementing such checks should
not be used solely by the conversions; they should also be (re)used by the
verifier.
---
.../Conversion/ArithToEmitC/ArithToEmitC.h | 18 ++++++
.../ArithToEmitC/ArithToEmitCPass.h | 21 +++++++
mlir/include/mlir/Conversion/Passes.h | 1 +
mlir/include/mlir/Conversion/Passes.td | 9 +++
.../Conversion/ArithToEmitC/ArithToEmitC.cpp | 59 +++++++++++++++++++
.../ArithToEmitC/ArithToEmitCPass.cpp | 47 +++++++++++++++
.../Conversion/ArithToEmitC/CMakeLists.txt | 16 +++++
mlir/lib/Conversion/CMakeLists.txt | 1 +
.../ArithToEmitC/arith-to-emitc.mlir | 14 +++++
.../llvm-project-overlay/mlir/BUILD.bazel | 27 +++++++++
10 files changed, 213 insertions(+)
create mode 100644 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
create mode 100644 mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
create mode 100644 mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
create mode 100644 mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
create mode 100644 mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
create mode 100644 mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
new file mode 100644
index 00000000000000..1f5e5ef0899b37
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -0,0 +1,23 @@
+//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the patterns that convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+
+namespace mlir {
+class RewritePatternSet;
+
+void populateArithToEmitCPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
new file mode 100644
index 00000000000000..6b98fed7185ead
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
@@ -0,0 +1,26 @@
+//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the pass that converts the Arith dialect to the EmitC
+// dialect. The declarations are generated by TableGen into Passes.h.inc.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
+
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f2aa4fb535402d 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..0e76069faf44c0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
+ let summary = "Convert Arith dialect to EmitC dialect";
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..15942f54441424
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,61 @@
+//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements patterns to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Conversion Patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Replaces an `ArithOp` with an `EmitCOp` that has the same result type and
+/// uses the operands provided by the conversion adaptor.
+template <typename ArithOp, typename EmitCOp>
+class ArithOpConversion final : public OpConversionPattern<ArithOp> {
+public:
+ using OpConversionPattern<ArithOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
+ adaptor.getOperands());
+
+ return success();
+ }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern population
+//===----------------------------------------------------------------------===//
+
+void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
+ MLIRContext *ctx = patterns.getContext();
+
+ // clang-format off
+ patterns.add<
+ ArithOpConversion<arith::AddFOp, emitc::AddOp>,
+ ArithOpConversion<arith::DivFOp, emitc::DivOp>,
+ ArithOpConversion<arith::MulFOp, emitc::MulOp>,
+ ArithOpConversion<arith::SubFOp, emitc::SubOp>
+ >(ctx);
+ // clang-format on
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
new file mode 100644
index 00000000000000..e3ac6965ade22a
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -0,0 +1,51 @@
+//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert the Arith dialect to the EmitC
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTARITHTOEMITC
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass applying the Arith-to-EmitC conversion patterns; its base class is
+/// generated by TableGen from the ConvertArithToEmitC definition.
+struct ConvertArithToEmitC
+ : public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
+ void runOnOperation() override;
+};
+} // namespace
+
+/// Runs a partial conversion with the EmitC dialect marked legal and signals a
+/// pass failure if the conversion fails.
+void ConvertArithToEmitC::runOnOperation() {
+ ConversionTarget target(getContext());
+
+ target.addLegalDialect<emitc::EmitCDialect>();
+
+ RewritePatternSet patterns(&getContext());
+ populateArithToEmitCPatterns(patterns);
+
+ if (failed(
+ applyPartialConversion(getOperation(), target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..a3784f47c3bc2d
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRArithToEmitC
+ ArithToEmitC.cpp
+ ArithToEmitCPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRArithDialect
+ MLIREmitCDialect
+ MLIRPass
+ MLIRTransformUtils
+ )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
new file mode 100644
index 00000000000000..6a56474a5c48b2
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s
+
+// CHECK-LABEL: func.func @arith_ops
+func.func @arith_ops(%arg0: f32, %arg1: f32) {
+ // CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32
+ %0 = arith.addf %arg0, %arg1 : f32
+ // CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32
+ %1 = arith.divf %arg0, %arg1 : f32
+ // CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32
+ %2 = arith.mulf %arg0, %arg1 : f32
+ // CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32
+ %3 = arith.subf %arg0, %arg1 : f32
+
+ return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 8a8dd6e10c48aa..2961b1574c49b7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -4011,6 +4011,7 @@ cc_library(
":AffineToStandard",
":ArithToAMDGPU",
":ArithToArmSME",
+ ":ArithToEmitC",
":ArithToLLVM",
":ArithToSPIRV",
":ArmNeon2dToIntr",
@@ -8156,6 +8157,32 @@ cc_library(
],
)
+cc_library(
+ name = "ArithToEmitC",
+ srcs = glob([
+ "lib/Conversion/ArithToEmitC/*.cpp",
+ "lib/Conversion/ArithToEmitC/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Conversion/ArithToEmitC/*.h",
+ ]),
+ includes = [
+ "include",
+ "lib/Conversion/ArithToEmitC",
+ ],
+ deps = [
+ ":ArithDialect",
+ ":ConversionPassIncGen",
+ ":EmitCDialect",
+ ":IR",
+ ":Pass",
+ ":Support",
+ ":TransformUtils",
+ ":Transforms",
+ "//llvm:Support",
+ ],
+)
+
cc_library(
name = "ArithToLLVM",
srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),
>From 0b306405dc59e64edf097728ca9d805a14825405 Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Wed, 6 Mar 2024 17:05:35 +0000
Subject: [PATCH 2/3] Address review comments
---
mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h | 6 +++++-
mlir/include/mlir/Conversion/Passes.td | 2 +-
mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp | 5 +++--
mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp | 6 +++++-
4 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
index 1f5e5ef0899b37..c71dfc6aa533f0 100644
--- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -11,8 +11,12 @@
namespace mlir {
class RewritePatternSet;
+class TypeConverter;
-void populateArithToEmitCPatterns(RewritePatternSet &patterns);
+/// Populates `patterns` with patterns converting Arith dialect ops to EmitC.
+/// The patterns are built with `typeConverter`, which must outlive them.
+void populateArithToEmitCPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 0e76069faf44c0..bd81cc6d5323bf 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -137,7 +137,7 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
// ArithToEmitC
//===----------------------------------------------------------------------===//
-def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
+def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
let summary = "Convert Arith dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"];
}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
index 15942f54441424..596556c502dd44 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -45,7 +45,8 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
// Pattern population
//===----------------------------------------------------------------------===//
-void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
+void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
+ RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();
// clang-format off
@@ -54,6 +55,6 @@ void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>
- >(ctx);
+ >(typeConverter, ctx);
// clang-format on
}
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index e3ac6965ade22a..3421bd53b241fc 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -39,7 +39,11 @@ void ConvertArithToEmitC::runOnOperation() {
target.addLegalDialect<emitc::EmitCDialect>();
RewritePatternSet patterns(&getContext());
- populateArithToEmitCPatterns(patterns);
+ TypeConverter typeConverter;
+ // Keep types unchanged; an empty TypeConverter would reject every type.
+ typeConverter.addConversion([](Type type) { return type; });
+
+ populateArithToEmitCPatterns(typeConverter, patterns);
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
>From b70455f2d1f9d408fcca7399b408fe5925c0d2ea Mon Sep 17 00:00:00 2001
From: Marius Brehler <marius.brehler at iml.fraunhofer.de>
Date: Thu, 7 Mar 2024 08:42:03 +0000
Subject: [PATCH 3/3] Mark ArithDialect illegal and ConstantOp legal
---
mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
index 3421bd53b241fc..57df60466115a2 100644
--- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
@@ -37,6 +37,10 @@ void ConvertArithToEmitC::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<emitc::EmitCDialect>();
+ target.addIllegalDialect<arith::ArithDialect>();
+ // No pattern converts arith.constant yet; keep it legal so that inputs
+ // containing constants do not make the conversion fail.
+ target.addLegalOp<arith::ConstantOp>();
RewritePatternSet patterns(&getContext());
TypeConverter typeConverter;
More information about the llvm-commits
mailing list