[Mlir-commits] [mlir] [mlir][emitc] Add ArrayType (PR #83386)

Matthias Gehre llvmlistbot at llvm.org
Mon Mar 4 03:21:31 PST 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/83386

>From d46d2b294881782126257020e4563e286c86ff05 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <93204396+mgehre-amd at users.noreply.github.com>
Date: Wed, 28 Feb 2024 14:42:29 +0100
Subject: [PATCH 1/2] [mlir][emitc] Add ArrayType

This models a one- or multi-dimensional C/C++ array.

The type implements the ShapedTypeInterface and is printed similarly to memref/tensor:
```
  %arg0: !emitc.array<1xf32>,
  %arg1: !emitc.array<10x20x30xi32>,
  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
  %arg3: !emitc.array<30x!emitc.opaque<"int">>
```

It can be translated to C++ when used as a function parameter type or as the result type of an emitc.variable.
---
 .../mlir/Dialect/EmitC/IR/EmitCTypes.td       | 52 ++++++++++++++-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 63 +++++++++++++++++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 32 ++++++++--
 mlir/test/Dialect/EmitC/invalid_types.mlir    | 54 ++++++++++++++++
 mlir/test/Dialect/EmitC/types.mlir            | 14 +++++
 mlir/test/Target/Cpp/common-cpp.mlir          |  5 ++
 mlir/test/Target/Cpp/variable.mlir            |  3 +
 7 files changed, 215 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index 8818c049ed77132..5ab729df67882a4 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -16,16 +16,64 @@
 
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/Dialect/EmitC/IR/EmitCBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // EmitC type definitions
 //===----------------------------------------------------------------------===//
 
-class EmitC_Type<string name, string typeMnemonic>
-    : TypeDef<EmitC_Dialect, name> {
+class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<EmitC_Dialect, name, traits> {
   let mnemonic = typeMnemonic;
 }
 
+def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
+  let summary = "EmitC array type";
+
+  let description = [{
+    A statically shaped, one- or multi-dimensional C/C++ array.
+
+    Example:
+
+    ```mlir
+    // Array emitted as `int32_t[10]`
+    !emitc.array<10xi32>
+    // Array emitted as `float[10][20]`
+    !emitc.array<10x20xf32>
+    ```
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "Type":$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape,
+      "Type":$elementType
+    ), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    /// Returns whether this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this array type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    ArrayType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                        Type elementType) const;
+
+    static bool isValidElementType(Type type) {
+      return type.isIntOrIndexOrFloat() ||
+             llvm::isa<PointerType, OpaqueType>(type);
+    }
+  }];
+  let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
   let summary = "EmitC opaque type";
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4df8149b94c95fa..3d74737495c6b87 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -762,6 +762,69 @@ LogicalResult emitc::YieldOp::verify() {
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+Type emitc::ArrayType::parse(AsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+  // Dimensions must be static (no '?') and each is followed by an 'x'.
+  SmallVector<int64_t, 4> dimensions;
+  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
+                                /*withTrailingX=*/true))
+    return Type();
+  // Parse the element type that follows the dimension list.
+  auto typeLoc = parser.getCurrentLocation();
+  Type elementType;
+  if (parser.parseType(elementType))
+    return Type();
+
+  // Check that memref is formed from allowed types.
+  if (!isValidElementType(elementType))
+    return parser.emitError(typeLoc, "invalid array element type"), Type();
+  if (parser.parseGreater())
+    return Type();
+  return parser.getChecked<ArrayType>(dimensions, elementType);
+}
+
+void emitc::ArrayType::print(AsmPrinter &printer) const {
+  printer << "<";
+  // Each extent is followed by 'x' (e.g. `<10x20xf32>`), matching the parser.
+  for (int64_t dim : getShape())
+    printer << dim << 'x';
+  printer.printType(getElementType());
+  printer << ">";
+}
+
+LogicalResult emitc::ArrayType::verify(
+    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    ::llvm::ArrayRef<int64_t> shape, Type elementType) {
+  // An array needs at least one extent.
+  if (shape.empty())
+    return emitError() << "shape must not be empty";
+  // Every extent must be strictly positive.
+  if (llvm::any_of(shape, [](int64_t dim) { return dim <= 0; }))
+    return emitError() << "dimensions must have positive size";
+
+  if (!elementType)
+    return emitError() << "element type must not be none";
+
+  // Same element-type rule the custom parser applies.
+  if (!isValidElementType(elementType))
+    return emitError() << "invalid array element type";
+
+  return success();
+}
+
+emitc::ArrayType
+emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                            Type elementType) const {
+  // Without an explicit shape the clone keeps this type's shape, so only the
+  // element type changes.
+  return emitc::ArrayType::get(shape.value_or(getShape()), elementType);
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 4bc707c43ad92f0..849f7e47934069b 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -128,6 +128,10 @@ struct CppEmitter {
   LogicalResult emitVariableDeclaration(OpResult result,
                                         bool trailingSemicolon);
 
+  /// Emits a declaration of a variable with the given type and name.
+  LogicalResult emitVariableDeclaration(Location loc, Type type,
+                                        StringRef name);
+
   /// Emits the variable declaration and assignment prefix for 'op'.
   /// - emits separate variable followed by std::tie for multi-valued operation;
   /// - emits single type followed by variable for single result;
@@ -855,10 +859,8 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
 
   return (interleaveCommaWithError(
       arguments, os, [&](BlockArgument arg) -> LogicalResult {
-        if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
-          return failure();
-        os << " " << emitter.getOrCreateName(arg);
-        return success();
+        return emitter.emitVariableDeclaration(
+            functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
       }));
 }
 
@@ -1291,9 +1293,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
     return result.getDefiningOp()->emitError(
         "result variable for the operation already declared");
   }
-  if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
+  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
+                                     result.getType(),
+                                     getOrCreateName(result))))
     return failure();
-  os << " " << getOrCreateName(result);
   if (trailingSemicolon)
     os << ";\n";
   return success();
@@ -1390,6 +1393,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   return success();
 }
 
+LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
+                                                  StringRef name) {
+  // C arrays are declared as `T name[d0][d1]...`: extents follow the name.
+  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
+    if (failed(emitType(loc, arrType.getElementType())))
+      return failure();
+    os << " " << name;
+    for (auto dim : arrType.getShape())
+      os << "[" << dim << "]";
+    return success();
+  }
+  if (failed(emitType(loc, type)))
+    return failure();
+  os << " " << name;
+  return success();
+}
+
 LogicalResult CppEmitter::emitType(Location loc, Type type) {
   if (auto iType = dyn_cast<IntegerType>(type)) {
     switch (iType.getWidth()) {
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index 54e3775ddb8ed11..4c526aa93dffb06 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -11,3 +11,57 @@ func.func @illegal_opaque_type_2() {
     // expected-error @+1 {{pointer not allowed as outer type with !emitc.opaque, use !emitc.ptr instead}}
     %1 = "emitc.variable"(){value = "nullptr" : !emitc.opaque<"int32_t*">} : () -> !emitc.opaque<"int32_t*">
 }
+
+// -----
+
+func.func @illegal_array_missing_spec(
+    // expected-error @+1 {{expected non-function type}}
+    %arg0: !emitc.array<>) {
+}
+
+// -----
+
+func.func @illegal_array_missing_shape(
+    // expected-error @+1 {{shape must not be empty}}
+    %arg0: !emitc.array<i32>) {
+}
+
+// -----
+
+func.func @illegal_array_missing_x(
+    // expected-error @+1 {{expected 'x' in dimension list}}
+    %arg0: !emitc.array<10>
+) {
+}
+
+// -----
+
+func.func @illegal_array_non_positive_dimension(
+    // expected-error @+1 {{dimensions must have positive size}}
+    %arg0: !emitc.array<0xi32>
+) {
+}
+
+// -----
+
+func.func @illegal_array_missing_type(
+    // expected-error @+1 {{expected non-function type}}
+    %arg0: !emitc.array<10x>
+) {
+}
+
+// -----
+
+func.func @illegal_array_dynamic_shape(
+    // expected-error @+1 {{expected static shape}}
+    %arg0: !emitc.array<10x?xi32>
+) {
+}
+
+// -----
+
+func.func @illegal_array_unranked(
+    // expected-error @+1 {{expected non-function type}}
+    %arg0: !emitc.array<*xi32>
+) {
+}
diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir
index 26d6f43a5824e85..8477b0ed0597742 100644
--- a/mlir/test/Dialect/EmitC/types.mlir
+++ b/mlir/test/Dialect/EmitC/types.mlir
@@ -39,3 +39,17 @@ func.func @pointer_types() {
 
   return
 }
+
+// CHECK-LABEL: func @array_types(
+func.func @array_types(
+  // CHECK-SAME: !emitc.array<1xf32>,
+  %arg0: !emitc.array<1xf32>,
+  // CHECK-SAME: !emitc.array<10x20x30xi32>,
+  %arg1: !emitc.array<10x20x30xi32>,
+  // CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
+  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
+  // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
+  %arg3: !emitc.array<30x!emitc.opaque<"int">>
+) {
+  return
+}
diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir
index b537e7098deb515..a87b33a10844d31 100644
--- a/mlir/test/Target/Cpp/common-cpp.mlir
+++ b/mlir/test/Target/Cpp/common-cpp.mlir
@@ -89,3 +89,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
   %1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
   return %0 : !emitc.ptr<i32>
 }
+
+// CHECK: void array_type(int32_t v1[3], float v2[10][20])
+func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
+  return
+}
diff --git a/mlir/test/Target/Cpp/variable.mlir b/mlir/test/Target/Cpp/variable.mlir
index 77a060a32f9d45c..5d061a6c87505f0 100644
--- a/mlir/test/Target/Cpp/variable.mlir
+++ b/mlir/test/Target/Cpp/variable.mlir
@@ -9,6 +9,7 @@ func.func @emitc_variable() {
   %c4 = "emitc.variable"(){value = 255 : ui8} : () -> ui8
   %c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i32>
   %c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<i32>
+  %c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32>
   return
 }
 // CPP-DEFAULT: void emitc_variable() {
@@ -19,6 +20,7 @@ func.func @emitc_variable() {
 // CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255;
 // CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]];
 // CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL;
+// CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7];
 
 // CPP-DECLTOP: void emitc_variable() {
 // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -28,6 +30,7 @@ func.func @emitc_variable() {
 // CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]];
+// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7];
 // CPP-DECLTOP-NEXT: ;
 // CPP-DECLTOP-NEXT: [[V1]] = 42;
 // CPP-DECLTOP-NEXT: [[V2]] = -1;

>From 4ded2c6c076e55ae9718a291c4316d71dde376b3 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 4 Mar 2024 12:20:42 +0100
Subject: [PATCH 2/2] Address comments

---
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 12 ++++++-
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 22 +++++++++++++
 mlir/test/Dialect/EmitC/invalid_ops.mlir      | 32 +++++++++++++++++++
 mlir/test/Dialect/EmitC/invalid_types.mlir    | 16 ++++++++++
 mlir/test/Dialect/EmitC/types.mlir            | 28 ++++++++--------
 mlir/test/Target/Cpp/declare_func.mlir        |  8 +++++
 mlir/test/Target/Cpp/func.mlir                |  3 ++
 mlir/test/Target/Cpp/invalid.mlir             | 28 ++++++++++++++++
 .../Cpp/invalid_declare_variables_at_top.mlir |  9 ++++++
 mlir/test/Target/Cpp/variable.mlir            |  3 ++
 10 files changed, 146 insertions(+), 15 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/invalid_declare_variables_at_top.mlir

diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 3d74737495c6b87..f97db7dcd507328 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -140,6 +140,8 @@ LogicalResult emitc::AssignOp::verify() {
     return emitOpError() << "requires value's type (" << value.getType()
                          << ") to match variable's type (" << variable.getType()
                          << ")";
+  if (isa<ArrayType>(variable.getType()))
+    return emitOpError() << "cannot assign to array type";
   return success();
 }
 
@@ -191,6 +193,11 @@ LogicalResult emitc::CallOpaqueOp::verify() {
     }
   }
 
+  if (llvm::any_of(getResultTypes(),
+                   [](Type type) { return isa<ArrayType>(type); })) {
+    return emitOpError() << "cannot return array type";
+  }
+
   return success();
 }
 
@@ -455,6 +462,9 @@ LogicalResult FuncOp::verify() {
     return emitOpError("requires zero or exactly one result, but has ")
            << getNumResults();
 
+  if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
+    return emitOpError("cannot return array type");
+
   return success();
 }
 
@@ -780,7 +790,7 @@ Type emitc::ArrayType::parse(AsmParser &parser) {
   if (parser.parseType(elementType))
     return Type();
 
-  // Check that memref is formed from allowed types.
+  // Check that array is formed from allowed types.
   if (!isValidElementType(elementType))
     return parser.emitError(typeLoc, "invalid array element type"), Type();
   if (parser.parseGreater())
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 849f7e47934069b..58fa91277fcc8e1 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -904,6 +904,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
       if (emitter.hasValueInScope(arg))
         return functionOp->emitOpError(" block argument #")
                << arg.getArgNumber() << " is out of scope";
+      if (isa<ArrayType>(arg.getType()))
+        return functionOp->emitOpError("cannot emit block argument #")
+               << arg.getArgNumber() << " with array type";
       if (failed(
               emitter.emitType(block.getParentOp()->getLoc(), arg.getType()))) {
         return failure();
@@ -947,6 +950,11 @@ static LogicalResult printOperation(CppEmitter &emitter,
         "with multiple blocks needs variables declared at top");
   }
 
+  if (llvm::any_of(functionOp.getResultTypes(),
+                   [](Type type) { return isa<ArrayType>(type); })) {
+    return functionOp.emitOpError() << "cannot emit array type as result type";
+  }
+
   CppEmitter::Scope scope(emitter);
   raw_indented_ostream &os = emitter.ostream();
   if (failed(emitter.emitTypes(functionOp.getLoc(),
@@ -1445,6 +1453,8 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
     if (!tType.hasStaticShape())
       return emitError(loc, "cannot emit tensor type with non static shape");
+    if (isa<ArrayType>(tType.getElementType()))
+      return emitError(loc, "cannot emit tensor of array type ") << type;
     os << "Tensor<";
     if (failed(emitType(loc, tType.getElementType())))
       return failure();
     auto shape = tType.getShape();
@@ -1461,7 +1471,16 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
     os << oType.getValue();
     return success();
   }
+  if (auto aType = dyn_cast<emitc::ArrayType>(type)) {
+    if (failed(emitType(loc, aType.getElementType())))
+      return failure();
+    for (auto dim : aType.getShape())
+      os << "[" << dim << "]";
+    return success();
+  }
   if (auto pType = dyn_cast<emitc::PointerType>(type)) {
+    if (isa<ArrayType>(pType.getPointee()))
+      return emitError(loc, "cannot emit pointer to array type ") << type;
     if (failed(emitType(loc, pType.getPointee())))
       return failure();
     os << "*";
@@ -1483,6 +1502,9 @@ LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
 }
 
 LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
+  if (llvm::any_of(types, [](Type type) { return isa<ArrayType>(type); })) {
+    return emitError(loc, "cannot emit tuple of array type");
+  }
   os << "std::tuple<";
   if (failed(interleaveCommaWithError(
           types, os, [&](Type type) { return emitType(loc, type); })))
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 5f64b535d684f32..58b3a11ed93e15f 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -80,6 +80,14 @@ func.func @dense_template_argument(%arg : i32) {
 
 // -----
 
+func.func @array_result() {
+    // expected-error @+1 {{'emitc.call_opaque' op cannot return array type}}
+    emitc.call_opaque "array_result"() : () -> !emitc.array<4xi32>
+    return
+}
+
+// -----
+
 func.func @empty_operator(%arg : i32) {
     // expected-error @+1 {{'emitc.apply' op applicable operator must not be empty}}
     %2 = emitc.apply ""(%arg) : (i32) -> !emitc.ptr<i32>
@@ -129,6 +137,14 @@ func.func @cast_tensor(%arg : tensor<f32>) {
 
 // -----
 
+func.func @cast_array(%arg : !emitc.array<4xf32>) {
+    // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
+    %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
+    return
+}
+
+// -----
+
 func.func @add_two_pointers(%arg0: !emitc.ptr<f32>, %arg1: !emitc.ptr<f32>) {
     // expected-error @+1 {{'emitc.add' op requires that at most one operand is a pointer}}
     %1 = "emitc.add" (%arg0, %arg1) : (!emitc.ptr<f32>, !emitc.ptr<f32>) -> !emitc.ptr<f32>
@@ -235,6 +251,15 @@ func.func @test_assign_type_mismatch(%arg1: f32) {
 
 // -----
 
+func.func @test_assign_to_array(%arg1: !emitc.array<4xi32>) {
+  %v = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4xi32>
+  // expected-error @+1 {{'emitc.assign' op cannot assign to array type}}
+  emitc.assign %arg1 : !emitc.array<4xi32> to %v : !emitc.array<4xi32>
+  return
+}
+
+// -----
+
 func.func @test_expression_no_yield() -> i32 {
   // expected-error @+1 {{'emitc.expression' op must yield a value at termination}}
   %r = emitc.expression : i32 {
@@ -313,6 +338,13 @@ emitc.func @return_type_mismatch() -> i32 {
 
 // -----
 
+// expected-error@+1 {{'emitc.func' op cannot return array type}}
+emitc.func @return_type_array(%arg : !emitc.array<4xi32>) -> !emitc.array<4xi32> {
+  emitc.return %arg : !emitc.array<4xi32>
+}
+
+// -----
+
 func.func @return_inside_func.func(%0: i32) -> (i32) {
   // expected-error@+1 {{'emitc.return' op expects parent op 'emitc.func'}}
   emitc.return %0 : i32
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index 4c526aa93dffb06..079371b39b9d1ee 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -65,3 +65,19 @@ func.func @illegal_array_unranked(
     %arg0: !emitc.array<*xi32>
 ) {
 }
+
+// -----
+
+func.func @illegal_array_with_array_element_type(
+    // expected-error @+1 {{invalid array element type}}
+    %arg0: !emitc.array<4x!emitc.array<4xi32>>
+) {
+}
+
+// -----
+
+func.func @illegal_array_with_tensor_element_type(
+    // expected-error @+1 {{invalid array element type}}
+    %arg0: !emitc.array<4xtensor<4xi32>>
+) {
+}
diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir
index 8477b0ed0597742..752f2c10c17be90 100644
--- a/mlir/test/Dialect/EmitC/types.mlir
+++ b/mlir/test/Dialect/EmitC/types.mlir
@@ -2,6 +2,20 @@
 // check parser
 // RUN: mlir-opt -verify-diagnostics %s | mlir-opt -verify-diagnostics | FileCheck %s
 
+// CHECK-LABEL: func @array_types(
+func.func @array_types(
+  // CHECK-SAME: !emitc.array<1xf32>,
+  %arg0: !emitc.array<1xf32>,
+  // CHECK-SAME: !emitc.array<10x20x30xi32>,
+  %arg1: !emitc.array<10x20x30xi32>,
+  // CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
+  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
+  // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
+  %arg3: !emitc.array<30x!emitc.opaque<"int">>
+) {
+  return
+}
+
 // CHECK-LABEL: func @opaque_types() {
 func.func @opaque_types() {
   // CHECK-NEXT: !emitc.opaque<"int">
@@ -39,17 +53,3 @@ func.func @pointer_types() {
 
   return
 }
-
-// CHECK-LABEL: func @array_types(
-func.func @array_types(
-  // CHECK-SAME: !emitc.array<1xf32>,
-  %arg0: !emitc.array<1xf32>,
-  // CHECK-SAME: !emitc.array<10x20x30xi32>,
-  %arg1: !emitc.array<10x20x30xi32>,
-  // CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
-  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
-  // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
-  %arg3: !emitc.array<30x!emitc.opaque<"int">>
-) {
-  return
-}
diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir
index 72c087a3388e205..00680d71824ae04 100644
--- a/mlir/test/Target/Cpp/declare_func.mlir
+++ b/mlir/test/Target/Cpp/declare_func.mlir
@@ -14,3 +14,11 @@ emitc.declare_func @foo
 emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} {
     emitc.return %arg0 : i32
 }
+
+
+// CHECK: void array_arg(int32_t [[V2:[^ ]*]][3]);
+emitc.declare_func @array_arg
+// CHECK: void array_arg(int32_t [[V2:[^ ]*]][3]) {
+emitc.func @array_arg(%arg0: !emitc.array<3xi32>) {
+    emitc.return
+}
diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir
index a639cae6f623c59..9c9ea55bfc4e1a1 100644
--- a/mlir/test/Target/Cpp/func.mlir
+++ b/mlir/test/Target/Cpp/func.mlir
@@ -40,3 +40,6 @@ emitc.func @emitc_call() -> i32 {
 
 emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}
 // CPP-DEFAULT: extern void extern_func(int32_t);
+
+emitc.func private @array_arg(!emitc.array<3xi32>) attributes {specifiers = ["extern"]}
+// CPP-DEFAULT: extern void array_arg(int32_t[3]);
diff --git a/mlir/test/Target/Cpp/invalid.mlir b/mlir/test/Target/Cpp/invalid.mlir
index 18dabb915586635..8672850e4a65310 100644
--- a/mlir/test/Target/Cpp/invalid.mlir
+++ b/mlir/test/Target/Cpp/invalid.mlir
@@ -57,3 +57,31 @@ func.func @non_static_shape(%arg0 : tensor<?xf32>) {
 func.func @unranked_tensor(%arg0 : tensor<*xf32>) {
   return
 }
+
+// -----
+
+// expected-error@+1 {{cannot emit tensor of array type}}
+func.func @tensor_of_array(%arg0 : tensor<4x!emitc.array<4xf32>>) {
+  return
+}
+
+// -----
+
+// expected-error@+1 {{cannot emit pointer to array type}}
+func.func @ptr_to_array_arg(%arg0 : !emitc.ptr<!emitc.array<4xf32>>) {
+  return
+}
+
+// -----
+
+// expected-error@+1 {{cannot emit array type as result type}}
+func.func @array_as_result(%arg: !emitc.array<4xi8>) -> (!emitc.array<4xi8>) {
+  return %arg : !emitc.array<4xi8>
+}
+
+// -----
+func.func @ptr_to_array() {
+  // expected-error@+1 {{cannot emit pointer to array type '!emitc.ptr<!emitc.array<9xi16>>'}}
+  %v = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<!emitc.array<9xi16>>
+  return
+}
diff --git a/mlir/test/Target/Cpp/invalid_declare_variables_at_top.mlir b/mlir/test/Target/Cpp/invalid_declare_variables_at_top.mlir
new file mode 100644
index 000000000000000..844fe03bad4aba5
--- /dev/null
+++ b/mlir/test/Target/Cpp/invalid_declare_variables_at_top.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-translate -split-input-file -declare-variables-at-top -mlir-to-cpp -verify-diagnostics %s
+
+// expected-error@+1 {{'func.func' op cannot emit block argument #0 with array type}}
+func.func @array_as_block_argument(!emitc.array<4xi8>) {
+^bb0(%arg0 : !emitc.array<4xi8>):
+  cf.br ^bb1(%arg0 : !emitc.array<4xi8>)
+^bb1(%a : !emitc.array<4xi8>):
+  return
+}
diff --git a/mlir/test/Target/Cpp/variable.mlir b/mlir/test/Target/Cpp/variable.mlir
index 5d061a6c87505f0..126dd384b47a2a0 100644
--- a/mlir/test/Target/Cpp/variable.mlir
+++ b/mlir/test/Target/Cpp/variable.mlir
@@ -10,6 +10,7 @@ func.func @emitc_variable() {
   %c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i32>
   %c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<i32>
   %c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32>
+  %c8 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<5x!emitc.ptr<i8>>
   return
 }
 // CPP-DEFAULT: void emitc_variable() {
@@ -21,6 +22,7 @@ func.func @emitc_variable() {
 // CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]];
 // CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL;
 // CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7];
+// CPP-DEFAULT-NEXT: int8_t* [[V8:[^ ]*]][5];
 
 // CPP-DECLTOP: void emitc_variable() {
 // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -31,6 +33,7 @@ func.func @emitc_variable() {
 // CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7];
+// CPP-DECLTOP-NEXT: int8_t* [[V8:[^ ]*]][5];
 // CPP-DECLTOP-NEXT: ;
 // CPP-DECLTOP-NEXT: [[V1]] = 42;
 // CPP-DECLTOP-NEXT: [[V2]] = -1;



More information about the Mlir-commits mailing list