[Mlir-commits] [mlir] [mlir][emitc] Add ArrayType (PR #83386)
Matthias Gehre
llvmlistbot at llvm.org
Wed Feb 28 23:18:53 PST 2024
https://github.com/mgehre-amd created https://github.com/llvm/llvm-project/pull/83386
This models a one- or multi-dimensional C/C++ array.
The type implements `ShapedTypeInterface` and is printed similarly to memref/tensor types:
```
%arg0: !emitc.array<1xf32>,
%arg1: !emitc.array<10x20x30xi32>,
%arg2: !emitc.array<30x!emitc.ptr<i32>>,
%arg3: !emitc.array<30x!emitc.opaque<"int">>
```
It can be translated to a C array type when used as a function parameter type or as the type of an `emitc.variable`.
>From d543c271f3156006e05c1df5830359d039b373e8 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <93204396+mgehre-amd at users.noreply.github.com>
Date: Wed, 28 Feb 2024 14:42:29 +0100
Subject: [PATCH] [mlir][emitc] Add ArrayType
This models a one- or multi-dimensional C/C++ array.
The type implements ShapedTypeInterface and is printed similarly to memref/tensor types:
```
%arg0: !emitc.array<1xf32>,
%arg1: !emitc.array<10x20x30xi32>,
%arg2: !emitc.array<30x!emitc.ptr<i32>>,
%arg3: !emitc.array<30x!emitc.opaque<"int">>
```
It can be translated to a C array type when used as a function parameter type or as the type of an emitc.variable.
---
.../mlir/Dialect/EmitC/IR/EmitCTypes.td | 52 ++++++++++++++-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 63 +++++++++++++++++++
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 32 ++++++++--
mlir/test/Dialect/EmitC/invalid_types.mlir | 54 ++++++++++++++++
mlir/test/Dialect/EmitC/types.mlir | 14 +++++
mlir/test/Target/Cpp/common-cpp.mlir | 5 ++
mlir/test/Target/Cpp/variable.mlir | 3 +
7 files changed, 215 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
index 8818c049ed7713..5ab729df67882a 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
@@ -16,16 +16,64 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/Dialect/EmitC/IR/EmitCBase.td"
+include "mlir/IR/BuiltinTypeInterfaces.td"
//===----------------------------------------------------------------------===//
// EmitC type definitions
//===----------------------------------------------------------------------===//
-class EmitC_Type<string name, string typeMnemonic>
- : TypeDef<EmitC_Dialect, name> {
+// Base class for EmitC dialect types. `traits` is forwarded to TypeDef so that
+// derived types (e.g. EmitC_ArrayType) can attach interfaces.
+class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<EmitC_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}
+def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
+ let summary = "EmitC array type";
+
+ let description = [{
+ An array data type.
+
+ The shape must be static and non-empty, and every dimension must be
+ positive. The element type must be an integer, index or floating-point
+ type, an `!emitc.ptr`, or an `!emitc.opaque`.
+
+ Example:
+
+ ```mlir
+ // Array emitted as `int32_t[10]`
+ !emitc.array<10xi32>
+ // Array emitted as `float[10][20]`
+ !emitc.array<10x20xf32>
+ ```
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ "Type":$elementType
+ );
+
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "ArrayRef<int64_t>":$shape,
+ "Type":$elementType
+ ), [{
+ return $_get(elementType.getContext(), shape, elementType);
+ }]>
+ ];
+ let extraClassDeclaration = [{
+ /// Returns whether this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Clone this array type with the given shape and element type. If the
+ /// provided shape is `std::nullopt`, the current shape of the type is used.
+ ArrayType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const;
+
+ /// Returns true if `type` may be used as the element type of an array.
+ static bool isValidElementType(Type type) {
+ return type.isIntOrIndexOrFloat() ||
+ llvm::isa<PointerType, OpaqueType>(type);
+ }
+ }];
+ // The verifier and the custom parser/printer are defined in EmitC.cpp.
+ let genVerifyDecl = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
let summary = "EmitC opaque type";
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4df8149b94c95f..3d74737495c6b8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -762,6 +762,69 @@ LogicalResult emitc::YieldOp::verify() {
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
+//===----------------------------------------------------------------------===//
+// ArrayType
+//===----------------------------------------------------------------------===//
+
+/// Parses `<`, an `x`-separated list of static dimensions (each followed by
+/// `x`), the element type and `>`. Remaining shape constraints are enforced by
+/// `verify` through `getChecked`.
+Type emitc::ArrayType::parse(AsmParser &parser) {
+ if (parser.parseLess())
+ return Type();
+
+ SmallVector<int64_t, 4> dimensions;
+ if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
+ /*withTrailingX=*/true))
+ return Type();
+ // Parse the element type.
+ auto typeLoc = parser.getCurrentLocation();
+ Type elementType;
+ if (parser.parseType(elementType))
+ return Type();
+
+ // Check that the array is formed from allowed element types. Checking here
+ // (and not only in `verify`) lets the diagnostic point at the element type.
+ if (!isValidElementType(elementType))
+ return parser.emitError(typeLoc, "invalid array element type"), Type();
+ if (parser.parseGreater())
+ return Type();
+ return parser.getChecked<ArrayType>(dimensions, elementType);
+}
+
+/// Prints the type body as `<` followed by every dimension suffixed with `x`,
+/// then the element type and `>` (the form accepted by `parse`).
+void emitc::ArrayType::print(AsmPrinter &printer) const {
+ printer << "<";
+ for (int64_t extent : getShape())
+ printer << extent << 'x';
+ printer.printType(getElementType());
+ printer << ">";
+}
+
+/// Checks the ArrayType invariants: a non-empty shape whose extents are all
+/// strictly positive, and a non-null element type accepted by
+/// `isValidElementType`.
+LogicalResult emitc::ArrayType::verify(
+ ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
+ ::llvm::ArrayRef<int64_t> shape, Type elementType) {
+ if (shape.empty())
+ return emitError() << "shape must not be empty";
+
+ for (int64_t extent : shape)
+ if (extent <= 0)
+ return emitError() << "dimensions must have positive size";
+
+ // The null check has to come first: isValidElementType inspects the type.
+ if (!elementType)
+ return emitError() << "element type must not be none";
+
+ if (!isValidElementType(elementType))
+ return emitError() << "invalid array element type";
+ return success();
+}
+
+/// Returns an ArrayType with `elementType` and either the given shape or, when
+/// `shape` is std::nullopt, the shape of this type.
+emitc::ArrayType
+emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ ArrayRef<int64_t> newShape = shape ? *shape : getShape();
+ return emitc::ArrayType::get(newShape, elementType);
+}
+
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 2ba3dec0a9a57f..2adb1bb877c17e 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -128,6 +128,10 @@ struct CppEmitter {
LogicalResult emitVariableDeclaration(OpResult result,
bool trailingSemicolon);
+ /// Emits a declaration of a variable with the given type and name, without
+ /// a trailing semicolon. For `!emitc.array` types the element type is
+ /// emitted before the name and each dimension is appended after it, e.g.
+ /// `int32_t v[3][7]`.
+ LogicalResult emitVariableDeclaration(Location loc, Type type,
+ StringRef name);
+
/// Emits the variable declaration and assignment prefix for 'op'.
/// - emits separate variable followed by std::tie for multi-valued operation;
/// - emits single type followed by variable for single result;
@@ -783,10 +787,8 @@ static LogicalResult printFunctionArgs(CppEmitter &emitter,
return (interleaveCommaWithError(
arguments, os, [&](BlockArgument arg) -> LogicalResult {
- if (failed(emitter.emitType(functionOp->getLoc(), arg.getType())))
- return failure();
- os << " " << emitter.getOrCreateName(arg);
- return success();
+ return emitter.emitVariableDeclaration(
+ functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
}));
}
@@ -1219,9 +1221,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
}
- if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
+ if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
+ result.getType(),
+ getOrCreateName(result))))
return failure();
- os << " " << getOrCreateName(result);
if (trailingSemicolon)
os << ";\n";
return success();
@@ -1314,6 +1317,23 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
return success();
}
+LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
+ StringRef name) {
+ // In C declarator syntax array extents follow the variable name, so an
+ // array contributes only its element type before the name.
+ auto arrType = dyn_cast<emitc::ArrayType>(type);
+ Type declaredType = arrType ? arrType.getElementType() : type;
+ if (failed(emitType(loc, declaredType)))
+ return failure();
+ os << " " << name;
+ if (arrType) {
+ for (int64_t dim : arrType.getShape())
+ os << "[" << dim << "]";
+ }
+ return success();
+}
+
LogicalResult CppEmitter::emitType(Location loc, Type type) {
if (auto iType = dyn_cast<IntegerType>(type)) {
switch (iType.getWidth()) {
diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir
index 54e3775ddb8ed1..4c526aa93dffb0 100644
--- a/mlir/test/Dialect/EmitC/invalid_types.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_types.mlir
@@ -11,3 +11,57 @@ func.func @illegal_opaque_type_2() {
// expected-error @+1 {{pointer not allowed as outer type with !emitc.opaque, use !emitc.ptr instead}}
%1 = "emitc.variable"(){value = "nullptr" : !emitc.opaque<"int32_t*">} : () -> !emitc.opaque<"int32_t*">
}
+
+// -----
+
+// An empty body has neither dimensions nor an element type.
+func.func @illegal_array_missing_spec(
+ // expected-error @+1 {{expected non-function type}}
+ %arg0: !emitc.array<>) {
+}
+
+// -----
+
+// An element type without any dimension is rejected by the verifier.
+func.func @illegal_array_missing_shape(
+ // expected-error @+1 {{shape must not be empty}}
+ %arg9: !emitc.array<i32>) {
+}
+
+// -----
+
+// Every dimension must be followed by 'x' and then the element type.
+func.func @illegal_array_missing_x(
+ // expected-error @+1 {{expected 'x' in dimension list}}
+ %arg0: !emitc.array<10>
+) {
+}
+
+// -----
+
+// Zero-sized dimensions are rejected; every extent must be positive.
+func.func @illegal_array_non_positive_dimension(
+ // expected-error @+1 {{dimensions must have positive size}}
+ %arg0: !emitc.array<0xi32>
+) {
+}
+
+// -----
+
+// A trailing 'x' must be followed by an element type.
+func.func @illegal_array_missing_type(
+ // expected-error @+1 {{expected non-function type}}
+ %arg0: !emitc.array<10x>
+) {
+}
+
+// -----
+
+// Dynamic ('?') dimensions are not allowed; the shape must be static.
+func.func @illegal_array_dynamic_shape(
+ // expected-error @+1 {{expected static shape}}
+ %arg0: !emitc.array<10x?xi32>
+) {
+}
+
+// -----
+
+// Unranked ('*') arrays are not supported.
+func.func @illegal_array_unranked(
+ // expected-error @+1 {{expected non-function type}}
+ %arg0: !emitc.array<*xi32>
+) {
+}
diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir
index 26d6f43a5824e8..8477b0ed059774 100644
--- a/mlir/test/Dialect/EmitC/types.mlir
+++ b/mlir/test/Dialect/EmitC/types.mlir
@@ -39,3 +39,17 @@ func.func @pointer_types() {
return
}
+
+// Round-trip of array types with scalar, pointer and opaque element types.
+// CHECK-LABEL: func @array_types(
+func.func @array_types(
+ // CHECK-SAME: !emitc.array<1xf32>,
+ %arg0: !emitc.array<1xf32>,
+ // CHECK-SAME: !emitc.array<10x20x30xi32>,
+ %arg1: !emitc.array<10x20x30xi32>,
+ // CHECK-SAME: !emitc.array<30x!emitc.ptr<i32>>,
+ %arg2: !emitc.array<30x!emitc.ptr<i32>>,
+ // CHECK-SAME: !emitc.array<30x!emitc.opaque<"int">>
+ %arg3: !emitc.array<30x!emitc.opaque<"int">>
+) {
+ return
+}
diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir
index b537e7098deb51..a87b33a10844d3 100644
--- a/mlir/test/Target/Cpp/common-cpp.mlir
+++ b/mlir/test/Target/Cpp/common-cpp.mlir
@@ -89,3 +89,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
%1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> (i32)
return %0 : !emitc.ptr<i32>
}
+
+// Array arguments are declared with their extents following the argument name.
+// CHECK: void array_type(int32_t v1[3], float v2[10][20])
+func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
+ return
+}
diff --git a/mlir/test/Target/Cpp/variable.mlir b/mlir/test/Target/Cpp/variable.mlir
index 77a060a32f9d45..5d061a6c87505f 100644
--- a/mlir/test/Target/Cpp/variable.mlir
+++ b/mlir/test/Target/Cpp/variable.mlir
@@ -9,6 +9,7 @@ func.func @emitc_variable() {
%c4 = "emitc.variable"(){value = 255 : ui8} : () -> ui8
%c5 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.ptr<i32>
%c6 = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<i32>
+ %c7 = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32>
return
}
// CPP-DEFAULT: void emitc_variable() {
@@ -19,6 +20,7 @@ func.func @emitc_variable() {
// CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255;
// CPP-DEFAULT-NEXT: int32_t* [[V5:[^ ]*]];
// CPP-DEFAULT-NEXT: int32_t* [[V6:[^ ]*]] = NULL;
+// CPP-DEFAULT-NEXT: int32_t [[V7:[^ ]*]][3][7];
// CPP-DECLTOP: void emitc_variable() {
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -28,6 +30,7 @@ func.func @emitc_variable() {
// CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t* [[V5:[^ ]*]];
// CPP-DECLTOP-NEXT: int32_t* [[V6:[^ ]*]];
+// CPP-DECLTOP-NEXT: int32_t [[V7:[^ ]*]][3][7];
// CPP-DECLTOP-NEXT: ;
// CPP-DECLTOP-NEXT: [[V1]] = 42;
// CPP-DECLTOP-NEXT: [[V2]] = -1;
More information about the Mlir-commits
mailing list