[Mlir-commits] [mlir] [mlir][emitc] Add ArrayType (PR #83386)

Matthias Gehre llvmlistbot at llvm.org
Wed Feb 28 23:18:53 PST 2024


https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/83386

This models a one or multi-dimensional C/C++ array.

The type implements the `ShapedTypeInterface` and prints similarly to memref/tensor:
```
  %arg0: !emitc.array<1xf32>,
  %arg1: !emitc.array<10x20x30xi32>,
  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
  %arg3: !emitc.array<30x!emitc.opaque<"int">>
```

It can be translated to a C array type when used as a function parameter or as the type of an `emitc.variable`.

>From d543c271f3156006e05c1df5830359d039b373e8 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <93204396+mgehre-amd at users.noreply.github.com>
Date: Wed, 28 Feb 2024 14:42:29 +0100
Subject: [PATCH] [mlir][emitc] Add ArrayType

This models a one or multi-dimensional C/C++ array.

The type implements the ShapedTypeInterface and prints similarly to memref/tensor:
```
  %arg0: !emitc.array<1xf32>,
  %arg1: !emitc.array<10x20x30xi32>,
  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
  %arg3: !emitc.array<30x!emitc.opaque<"int">>
```

It can be translated to C++ when used as a function parameter or as the type of an emitc.variable.
---
 .../mlir/Dialect/EmitC/IR/EmitCTypes.td       | 52 ++++++++++++++-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 63 +++++++++++++++++++
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        | 32 ++++++++--
 mlir/test/Dialect/EmitC/invalid_types.mlir    | 54 ++++++++++++++++
 mlir/test/Dialect/EmitC/types.mlir            | 14 +++++
 mlir/test/Target/Cpp/common-cpp.mlir          |  5 ++
 mlir/test/Target/Cpp/variable.mlir            |  3 +
 7 files changed, 215 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index 8818c049ed7713..5ab729df67882a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -16,16 +16,64 @@
 
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/Dialect/EmitC/IR/EmitCBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
 
 //===----------------------------------------------------------------------===//
 // EmitC type definitions
 //===----------------------------------------------------------------------===//
 
-class EmitC_Type<string name, string typeMnemonic>
-    : TypeDef<EmitC_Dialect, name> {
+class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
+    : TypeDef<EmitC_Dialect, name, traits> {
   let mnemonic = typeMnemonic;
 }
 
+def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
+  let summary = "EmitC array type";
+
+  let description = [{
+    A statically shaped, one- or multi-dimensional array data type.
+
+    Example:
+
+    ```mlir
+    // Array emitted as `int32_t[10]`
+    !emitc.array<10xi32>
+    // Array emitted as `float[10][20]`
+    !emitc.array<10x20xf32>
+    ```
+  }];
+
+  let parameters = (ins
+    ArrayRefParameter<"int64_t">:$shape,
+    "Type":$elementType
+  );
+
+  let builders = [
+    TypeBuilderWithInferredContext<(ins
+      "ArrayRef<int64_t>":$shape,
+      "Type":$elementType
+    ), [{
+      return $_get(elementType.getContext(), shape, elementType);
+    }]>
+  ];
+  let extraClassDeclaration = [{
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this array type with the given shape and element type. If the
+    /// provided shape is `std::nullopt`, the current shape of the type is used.
+    ArrayType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                        Type elementType) const;
+
+    static bool isValidElementType(Type type) {
+      return type.isIntOrIndexOrFloat() ||  // integer, index or float
+         llvm::isa<PointerType, OpaqueType>(type);  // !emitc.ptr / !emitc.opaque
+    }
+  }];
+  let genVerifyDecl = 1;
+  let hasCustomAssemblyFormat = 1;
+}
+
 def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
   let summary = "EmitC opaque type";
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4df8149b94c95f..3d74737495c6b8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -762,6 +762,69 @@ LogicalResult emitc::YieldOp::verify() {
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+Type emitc::ArrayType::parse(AsmParser &parser) {  // e.g. `<10x20xi32>`
+  if (parser.parseLess())  // Opening '<'.
+    return Type();
+
+  SmallVector<int64_t, 4> dimensions;  // Static sizes only; `?` is rejected.
+  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
+                                /*withTrailingX=*/true))
+    return Type();
+  // Parse the element type that follows the last `x`.
+  auto typeLoc = parser.getCurrentLocation();  // Kept for the diagnostic below.
+  Type elementType;
+  if (parser.parseType(elementType))
+    return Type();
+
+  // Check that the array is formed from an allowed element type.
+  if (!isValidElementType(elementType))
+    return parser.emitError(typeLoc, "invalid array element type"), Type();
+  if (parser.parseGreater())  // Closing '>'.
+    return Type();
+  return parser.getChecked<ArrayType>(dimensions, elementType);
+}
+
+void emitc::ArrayType::print(AsmPrinter &printer) const {  // e.g. `<10x20xi32>`
+  printer << "<";
+  for (int64_t dim : getShape()) {
+    printer << dim << 'x';  // Every dimension is followed by an 'x'.
+  }
+  printer.printType(getElementType());
+  printer << ">";
+}
+
+LogicalResult emitc::ArrayType::verify(  // Validates shape and element type.
+    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+    ::llvm::ArrayRef<int64_t> shape, Type elementType) {
+  if (shape.empty())  // At least one dimension is required.
+    return emitError() << "shape must not be empty";
+
+  for (auto d : shape) {
+    if (d <= 0)  // Zero and negative extents are rejected.
+      return emitError() << "dimensions must have positive size";
+  }
+
+  if (!elementType)  // Null Type handle.
+    return emitError() << "element type must not be none";
+
+  if (!isValidElementType(elementType))  // Same predicate the parser uses.
+    return emitError() << "invalid array element type";
+
+  return success();
+}
+
+emitc::ArrayType
+emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                            Type elementType) const {
+  if (!shape)  // No shape supplied: reuse this type's shape.
+    return emitc::ArrayType::get(getShape(), elementType);
+  return emitc::ArrayType::get(*shape, elementType);  // Use the given shape.
+}
+
 //===----------------------------------------------------------------------===//
 // OpaqueType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 2ba3dec0a9a57f..2adb1bb877c17e 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -128,6 +128,10 @@ struct CppEmitter {
   LogicalResult emitVariableDeclaration(OpResult result,
                                         bool trailingSemicolon);
 
+  /// Emits a declaration of a variable with the given type and name.
+  LogicalResult emitVariableDeclaration(Location loc, Type type,
+                                        StringRef name);
+
   /// Emits the variable declaration and assignment prefix for 'op'.
   /// - emits separate variable followed by std::tie for multi-valued operation;
   /// - emits single type followed by variable for single result;
@@ -783,10 +787,8 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
 
   return (interleaveCommaWithError(
       arguments, os, [&](BlockArgument arg) -> LogicalResult {
-        if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
-          return failure();
-        os << " " << emitter.getOrCreateName(arg);
-        return success();
+        return emitter.emitVariableDeclaration(
+            functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
       }));
 }
 
@@ -1219,9 +1221,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
     return result.getDefiningOp()->emitError(
         "result variable for the operation already declared");
   }
-  if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
+  if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
+                                     result.getType(),
+                                     getOrCreateName(result))))
     return failure();
-  os << " " << getOrCreateName(result);
   if (trailingSemicolon)
     os << ";\n";
   return success();
@@ -1314,6 +1317,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   return success();
 }
 
+LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
+                                                  StringRef name) {
+  if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
+    if (failed(emitType(loc, arrType.getElementType())))  // Element type only.
+      return failure();
+    os << " " << name;  // Name precedes the extents: `T name[d0][d1]`.
+    for (auto dim : arrType.getShape()) {
+      os << "[" << dim << "]";  // One bracketed extent per dimension.
+    }
+    return success();
+  }
+  if (failed(emitType(loc, type)))  // Every other type: `T name`.
+    return failure();
+  os << " " << name;
+  return success();
+}
+
 LogicalResult CppEmitter::emitType(Location loc, Type type) {
   if (auto iType = dyn_cast<IntegerType>(type)) {
     switch (iType.getWidth()) {
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index 54e3775ddb8ed1..4c526aa93dffb0 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -11,3 +11,57 @@ func.func @illegal_opaque_type_2() {
     // expected-error @+1 {{pointer not allowed as outer type with !emitc.opaque, use !emitc.ptr instead}}
     %1 = "emitc.variable"(){value = "nullptr" : !emitc.opaque<"int32_t*">} : () -> !emitc.opaque<"int32_t*">
 }
+
+// -----
+
+func.func @illegal_array_missing_spec(
+    // expected-error @+1 {{expected non-function type}}
+    %arg0: !emitc.array<>) {
+}
+
+// -----
+
+func.func @illegal_array_missing_shape(
+    // expected-error @+1 {{shape must not be empty}}
+    %arg9: !emitc.array<i32>) {
+}
+
+// -----
+
+func.func @illegal_array_missing_x(
+    // expected-error @+1 {{expected 'x' in dimension list}}
+    %arg0: !emitc.array<10>
+) {
+}
+
+// -----
+
+func.func @illegal_array_non_positive_dimenson(
+    // expected-error @+1 {{dimensions must have positive size}}
+    %arg0: !emitc.array<0xi32>
+) {
+}
+
+// -----
+
+func.func @illegal_array_missing_type(
+    // expected-error @+1 {{expected non-function type}}
+    %arg0: !emitc.array<10x>
+) {
+}
+
+// -----
+
+func.func @illegal_array_dynamic_shape(
+    // expected-error @+1 {{expected static shape}}
+    %arg0: !emitc.array<10x?xi32>
+) {
+}
+
+// -----
+
+func.func @illegal_array_unranked(
+    // expected-error @+1 {{expected non-function type}}
+    %arg0: !emitc.array<*xi32>
+) {
+}
diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir
index 26d6f43a5824e8..8477b0ed059774 100644
--- a/mlir/test/Dialect/EmitC/types.mlir
+++ b/mlir/test/Dialect/EmitC/types.mlir
@@ -39,3 +39,17 @@ func.func @pointer_types() {
 
   return
 }
+
+// CHECK-LABEL: func @array_types(
+func.func @array_types(
+  // CHECK-SAME: !emitc.array<1xf32>,
+  %arg0: !emitc.array<1xf32>,
+  // CHECK-SAME: !emitc.array<10x20x30xi32>,
+  %arg1: !emitc.array<10x20x30xi32>,
+  // CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
+  %arg2: !emitc.array<30x!emitc.ptr<i32>>,
+  // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
+  %arg3: !emitc.array<30x!emitc.opaque<"int">>
+) {
+  return
+}
diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir
index b537e7098deb51..a87b33a10844d3 100644
--- a/mlir/test/Target/Cpp/common-cpp.mlir
+++ b/mlir/test/Target/Cpp/common-cpp.mlir
@@ -89,3 +89,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
   %1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
   return %0 : !emitc.ptr<i32>
 }
+
+// CHECK: void array_type(int32_t v1[3], float v2[10][20])
+func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
+  return
+}
diff --git a/mlir/test/Target/Cpp/variable.mlir b/mlir/test/Target/Cpp/variable.mlir
index 77a060a32f9d45..5d061a6c87505f 100644
--- a/mlir/test/Target/Cpp/variable.mlir
+++ b/mlir/test/Target/Cpp/variable.mlir
@@ -9,6 +9,7 @@ func.func @emitc_variable() {
   %c4 = "emitc.variable"(){value = 255 : ui8} : () -> ui8
   %c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i32>
   %c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<i32>
+  %c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32>
   return
 }
 // CPP-DEFAULT: void emitc_variable() {
@@ -19,6 +20,7 @@ func.func @emitc_variable() {
 // CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255;
 // CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]];
 // CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL;
+// CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7];
 
 // CPP-DECLTOP: void emitc_variable() {
 // CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -28,6 +30,7 @@ func.func @emitc_variable() {
 // CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]];
 // CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]];
+// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7];
 // CPP-DECLTOP-NEXT: ;
 // CPP-DECLTOP-NEXT: [[V1]] = 42;
 // CPP-DECLTOP-NEXT: [[V2]] = -1;



More information about the Mlir-commits mailing list