[Mlir-commits] [mlir] [mlir][emitc] Arith to EmitC conversion pass (PR #83798)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 4 01:09:03 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Tina Jung (TinaAMD)
<details>
<summary>Changes</summary>
Add a conversion pass from Arith to EmitC. Add an initial conversion from `arith.constant` to `emitc.constant`.
---
Full diff: https://github.com/llvm/llvm-project/pull/83798.diff
8 Files Affected:
- (added) mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h (+22)
- (modified) mlir/include/mlir/Conversion/Passes.h (+1)
- (modified) mlir/include/mlir/Conversion/Passes.td (+12)
- (added) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+104)
- (added) mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt (+17)
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
- (added) mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir (+15)
- (added) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+21)
``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
new file mode 100644
index 00000000000000..43322ac7f51f6c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -0,0 +1,22 @@
+//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+// Pulls in the tablegen-generated declarations for the pass defined as
+// `ArithToEmitCConversionPass` in mlir/Conversion/Passes.td.
+#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Adds to `patterns` the rewrite patterns that convert `arith` ops into
+/// `emitc` ops. At the time of writing this is a single pattern that rewrites
+/// `arith.constant` into `emitc.constant` when the constant's type is
+/// supported: index, integers of width 1/8/16/32/64, f32, f64, or statically
+/// shaped tensors of those element types.
+void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f41400a633ef22 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..358ac997fba2a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,18 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
];
}
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+// Pass exposed as `-convert-arith-to-emitc`. Its driver marks the entire
+// `arith` dialect illegal, so any arith op for which no conversion pattern
+// exists (at the time of writing: everything except `arith.constant` with a
+// supported type) makes the pass fail.
+// NOTE(review): consider stating the supported subset in the description.
+def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> {
+ let summary = "Convert Arith ops to EmitC ops";
+ let description = [{
+ Convert `arith` operations to operations in the `emitc` dialect.
+ }];
+ // The pass creates emitc ops, so the EmitC dialect is declared as a
+ // dependency and is loaded before the pass runs.
+ let dependentDialects = ["emitc::EmitCDialect"];
+}
+
//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..648fd2b4af0b70
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,104 @@
+//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert arith ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if `type` is accepted by this conversion: index, integers of
+/// width 1, 8, 16, 32 or 64, f32, f64, or a ranked, statically shaped tensor
+/// whose element type is one of those.
+static bool isConvertibleToEmitC(Type type) {
+  // For tensors, only the fully static case is handled; the element type then
+  // decides convertibility.
+  Type elementType = type;
+  if (auto tensor = dyn_cast<TensorType>(type)) {
+    if (!tensor.hasRank() || !tensor.hasStaticShape())
+      return false;
+    elementType = tensor.getElementType();
+  }
+
+  // Floating point: only single and double precision are accepted.
+  if (auto fp = dyn_cast<FloatType>(elementType))
+    return fp.isF32() || fp.isF64();
+
+  // Integers (any signedness): only the listed bit widths are accepted.
+  if (auto integer = dyn_cast<IntegerType>(elementType)) {
+    unsigned width = integer.getWidth();
+    return width == 1 || width == 8 || width == 16 || width == 32 ||
+           width == 64;
+  }
+
+  // Anything else is supported only if it is the builtin index type.
+  return isa<IndexType>(elementType);
+}
+
+/// Rewrites `arith.constant` into an `emitc.constant` with the same result
+/// type and value attribute. Reports a match failure, leaving the op in place,
+/// when `isConvertibleToEmitC` rejects the result type.
+struct ArithConstantOpConversionPattern
+    : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp op,
+                                PatternRewriter &rewriter) const override {
+    Type resultType = op.getType();
+    if (isConvertibleToEmitC(resultType)) {
+      rewriter.replaceOpWithNewOp<emitc::ConstantOp>(op, resultType,
+                                                     op.getValue());
+      return success();
+    }
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "Type cannot be converted to emitc");
+  }
+};
+
+/// Driver for `-convert-arith-to-emitc`. Runs a partial dialect conversion in
+/// which every `arith` op is illegal and every `emitc` op is legal, using the
+/// patterns from `populateArithToEmitCConversionPatterns`. Any arith op that
+/// cannot be converted causes the pass to signal failure.
+struct ConvertArithToEmitCPass
+    : public impl::ArithToEmitCConversionPassBase<ConvertArithToEmitCPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateArithToEmitCConversionPatterns(patterns);
+
+    ConversionTarget target(*context);
+    target.addLegalDialect<emitc::EmitCDialect>();
+    target.addIllegalDialect<arith::ArithDialect>();
+
+    LogicalResult result =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(result))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Registers the Arith-to-EmitC rewrite patterns (currently only the
+/// `arith.constant` pattern) in `patterns`, using the set's own context.
+void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<ArithConstantOpConversionPattern>(ctx);
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..c1bb6d71310edb
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,17 @@
+# Library target uses the MLIR prefix, like the other MLIR targets referenced
+# below (MLIRArithDialect, MLIREmitCDialect, MLIRConversionPassIncGen, ...),
+# instead of the unprefixed, collision-prone name `ArithToEmitC`.
+add_mlir_conversion_library(MLIRArithToEmitC
+  ArithToEmitC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIREmitCDialect
+  # The generated pass base class builds on the pass infrastructure directly.
+  MLIRPass
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
new file mode 100644
index 00000000000000..b13c6506787c56
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+
+// Constants whose type the conversion rejects stay as arith.constant. Because
+// the pass marks the whole arith dialect illegal, this surfaces as a
+// legalization error on the constant.
+
+// Complex element types are not accepted by the conversion's type check.
+func.func @arith_constant_complex_tensor() -> (tensor<complex<i32>>) {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+ %c = arith.constant dense<(2, 2)> : tensor<complex<i32>>
+ return %c : tensor<complex<i32>>
+}
+
+// -----
+
+// Only integer widths 1, 8, 16, 32 and 64 are accepted; i10 is not.
+func.func @arith_constant_invalid_int_type() -> (i10) {
+ // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+ %c = arith.constant 0 : i10
+ return %c : i10
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
new file mode 100644
index 00000000000000..2583dd832c314c
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s
+
+// Each supported constant kind (index, signless i32, f32, a rank-0 tensor and
+// a statically shaped 2-D tensor) is rewritten into emitc.constant with its
+// value attribute carried over unchanged.
+// CHECK-LABEL: arith_constants
+func.func @arith_constants() {
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = 0 : index
+ %c_index = arith.constant 0 : index
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = 0 : i32
+ %c_signless_int_32 = arith.constant 0 : i32
+ // The float value is matched with a regex, presumably so the test does not
+ // depend on how many zeros the attribute printer emits.
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = 0.{{0+}}e+00 : f32
+ %c_float_32 = arith.constant 0.0 : f32
+ // CHECK: emitc.constant
+ // CHECK-SAME: value = dense<0> : tensor<i32>
+ %c_tensor_single_value = arith.constant dense<0> : tensor<i32>
+ // CHECK: emitc.constant
+ // CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
+ %c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/83798
More information about the Mlir-commits
mailing list