[Mlir-commits] [mlir] [mlir][emitc] Arith to EmitC conversion pass (PR #83798)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 4 01:09:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tina Jung (TinaAMD)

<details>
<summary>Changes</summary>

Add a conversion pass from Arith to EmitC. Add an initial conversion from `arith.constant` to `emitc.constant`.

---
Full diff: https://github.com/llvm/llvm-project/pull/83798.diff


8 Files Affected:

- (added) mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h (+22) 
- (modified) mlir/include/mlir/Conversion/Passes.h (+1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+12) 
- (added) mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp (+104) 
- (added) mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt (+17) 
- (modified) mlir/lib/Conversion/CMakeLists.txt (+1) 
- (added) mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir (+15) 
- (added) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+21) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
new file mode 100644
index 00000000000000..43322ac7f51f6c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
@@ -0,0 +1,22 @@
+//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+class RewritePatternSet;
+
+#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Adds the patterns that convert `arith` ops to `emitc` ops to `patterns`.
+void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns);
+} // namespace mlir
+#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 81f69210fade8d..f41400a633ef22 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -13,6 +13,7 @@
 #include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
 #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
 #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 94fc7a7d2194bf..358ac997fba2a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -133,6 +133,18 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ArithToEmitC
+//===----------------------------------------------------------------------===//
+
+def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> {
+  let summary = "Convert Arith ops to EmitC ops";
+  let description = [{
+    Convert `arith` ops to `emitc` ops. Unconvertible `arith` ops are an error.
+  }];
+  let dependentDialects = ["emitc::EmitCDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // ArithToLLVM
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
new file mode 100644
index 00000000000000..648fd2b4af0b70
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
@@ -0,0 +1,104 @@
+//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to convert arith ops into emitc ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Returns true if `type` is index, a 1/8/16/32/64-bit integer, f32, f64, or a
+/// statically shaped tensor of one of these element types.
+static bool isConvertibleToEmitC(Type type) {
+  // Tensors must be ranked and static; hasStaticShape() checks both.
+  if (auto tensorType = dyn_cast<TensorType>(type)) {
+    if (!tensorType.hasStaticShape())
+      return false;
+    type = tensorType.getElementType();
+  }
+
+  if (isa<IndexType>(type))
+    return true;
+
+  if (auto intType = dyn_cast<IntegerType>(type)) {
+    switch (intType.getWidth()) {
+    case 1:
+    case 8:
+    case 16:
+    case 32:
+    case 64:
+      return true;
+    default:
+      return false;
+    }
+  }
+
+  if (auto floatType = dyn_cast<FloatType>(type))
+    return floatType.isF32() || floatType.isF64();
+
+  return false;
+}
+
+/// Converts `arith.constant` into an `emitc.constant` of the same type/value.
+class ArithConstantOpConversionPattern
+    : public OpConversionPattern<arith::ConstantOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ConstantOp arithConst, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type constantType = arithConst.getType();
+    if (!isConvertibleToEmitC(constantType))
+      return rewriter.notifyMatchFailure(arithConst.getLoc(),
+                                         "Type cannot be converted to emitc");
+
+    rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, constantType,
+                                                   arithConst.getValue());
+    return success();
+  }
+};
+
+/// Rewrites `arith` ops into `emitc` ops. Every `arith` op is illegal, so one
+/// that no pattern can convert makes the partial conversion and the pass fail.
+struct ConvertArithToEmitCPass
+    : public impl::ArithToEmitCConversionPassBase<ConvertArithToEmitCPass> {
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    ConversionTarget target(*context);
+    target.addLegalDialect<emitc::EmitCDialect>();
+    target.addIllegalDialect<arith::ArithDialect>();
+
+    RewritePatternSet patterns(context);
+    populateArithToEmitCConversionPatterns(patterns);
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+/// Adds the Arith-to-EmitC patterns to `patterns` (arith.constant so far).
+void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) {
+  patterns.add<ArithConstantOpConversionPattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
new file mode 100644
index 00000000000000..c1bb6d71310edb
--- /dev/null
+++ b/mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_conversion_library(MLIRArithToEmitC
+  ArithToEmitC.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIREmitCDialect
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 9e421f7c49dbc3..8219cf98575f3c 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
 add_subdirectory(ArithCommon)
 add_subdirectory(ArithToAMDGPU)
 add_subdirectory(ArithToArmSME)
+add_subdirectory(ArithToEmitC)
 add_subdirectory(ArithToLLVM)
 add_subdirectory(ArithToSPIRV)
 add_subdirectory(ArmNeon2dToIntr)
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
new file mode 100644
index 00000000000000..b13c6506787c56
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emit-c-failed.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
+
+func.func @arith_constant_complex_tensor() -> tensor<complex<i32>> {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+  %cst = arith.constant dense<(2, 2)> : tensor<complex<i32>>
+  return %cst : tensor<complex<i32>>
+}
+
+// -----
+
+func.func @arith_constant_invalid_int_type() -> i10 {
+  // expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
+  %cst = arith.constant 0 : i10
+  return %cst : i10
+}
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
new file mode 100644
index 00000000000000..2583dd832c314c
--- /dev/null
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s
+
+// CHECK-LABEL: func.func @arith_constants
+func.func @arith_constants() {
+  // CHECK: emitc.constant{{.*}}value = 0 : index
+  %c_index = arith.constant 0 : index
+  // CHECK: emitc.constant{{.*}}value = true
+  %c_bool = arith.constant true
+  // CHECK: emitc.constant{{.*}}value = 0 : i32
+  %c_signless_int_32 = arith.constant 0 : i32
+  // CHECK: emitc.constant{{.*}}value = 0.{{0+}}e+00 : f32
+  %c_float_32 = arith.constant 0.0 : f32
+  // CHECK: emitc.constant{{.*}}value = 1.{{0+}}e+00 : f64
+  %c_float_64 = arith.constant 1.0 : f64
+  // CHECK: emitc.constant{{.*}}value = dense<0> : tensor<i32>
+  %c_tensor_single_value = arith.constant dense<0> : tensor<i32>
+  // CHECK: emitc.constant{{.*}}value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
+  %c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
+  // CHECK-NEXT: return
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/83798


More information about the Mlir-commits mailing list