[Mlir-commits] [mlir] EmitC: Add emitc.global and emitc.get_global (#145) (PR #88701)
Matthias Gehre
llvmlistbot at llvm.org
Mon Apr 22 05:57:46 PDT 2024
https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/88701
>From a711c683e5e464fd6af1d7ac260bcc809d9498da Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 25 Mar 2024 16:50:17 +0100
Subject: [PATCH 1/2] EmitC: Add emitc.global and emitc.get_global (#145)
This adds
- `emitc.global` and `emitc.get_global` ops to model global variables
similar to how `memref.global` and `memref.get_global` work.
- translation of those ops to C++
- lowering of `memref.global` and `memref.get_global` into those ops
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 69 +++++++++++
.../MemRefToEmitC/MemRefToEmitC.cpp | 66 +++++++++-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 115 ++++++++++++++++--
mlir/lib/Target/Cpp/TranslateToCpp.cpp | 55 ++++++++-
.../MemRefToEmitC/memref-to-emitc-failed.mlir | 5 +
.../MemRefToEmitC/memref-to-emitc.mlir | 17 +++
mlir/test/Dialect/EmitC/invalid_ops.mlir | 15 +++
mlir/test/Dialect/EmitC/ops.mlir | 14 +++
mlir/test/Target/Cpp/global.mlir | 35 ++++++
9 files changed, 376 insertions(+), 15 deletions(-)
create mode 100644 mlir/test/Target/Cpp/global.mlir
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..ee5fc0b09a1611 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1016,6 +1016,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
let hasVerifier = 1;
}
+def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
+  let summary = "A global variable";
+  let description = [{
+    The `emitc.global` operation declares or defines a named global variable.
+    The backing memory for the variable is allocated statically and is
+    described by the type of the variable.
+    Optionally, an `initial_value` can be provided.
+    Internal linkage can be specified using the `staticSpecifier` unit attribute
+    (printed as `static`) and external linkage can be specified using the
+    `externSpecifier` unit attribute (printed as `extern`); the two are
+    mutually exclusive.
+    Note that the default linkage without those two keywords depends on whether
+    the target is C or C++ and whether the global variable is `const`.
+    The global variable can also be marked constant using the `constSpecifier`
+    unit attribute (printed as `const`). Writing to such constant global
+    variables is undefined behavior.
+
+    The global variable can be accessed by using the `emitc.get_global`
+    operation to retrieve the value for the global variable.
+
+    Example:
+
+    ```mlir
+    // Global variable with an initial value.
+    emitc.global @x : !emitc.array<2xf32> = dense<[0.0, 2.0]>
+    // External global variable
+    emitc.global extern @y : !emitc.array<2xf32>
+    // Constant global variable with internal linkage
+    emitc.global static const @z : i32 = 0
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttr:$type,
+                       OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
+                       UnitAttr:$externSpecifier,
+                       UnitAttr:$staticSpecifier,
+                       UnitAttr:$constSpecifier);
+
+  let assemblyFormat = [{
+    (`extern` $externSpecifier^)?
+    (`static` $staticSpecifier^)?
+    (`const` $constSpecifier^)?
+    $sym_name
+    `:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
+    attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
+// Symbol user: the referenced global is resolved and type-checked in
+// GetGlobalOp::verifySymbolUses.
+def EmitC_GetGlobalOp : EmitC_Op<"get_global",
+ [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Obtain access to a global variable";
+ let description = [{
+ The `emitc.get_global` operation retrieves the lvalue of a
+ named global variable. If the global variable is marked constant, assigning
+ to that lvalue is undefined behavior.
+ The result type must be identical to the type of the referenced
+ `emitc.global`.
+
+ Example:
+
+ ```mlir
+ %x = emitc.get_global @foo : !emitc.array<2xf32>
+ ```
+ }];
+
+ let arguments = (ins FlatSymbolRefAttr:$name);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "$name `:` type($result) attr-dict";
+}
+
def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
let summary = "Verbatim operation";
let description = [{
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..d3e7f233c08412 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -50,6 +50,68 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
}
};
+struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  /// Lowers a statically shaped `memref.global` to an `emitc.global`.
+  /// Public symbols become `extern` globals, private symbols become `static`
+  /// globals, and `constant` memrefs become `const` globals. Globals with an
+  /// alignment requirement or other visibilities are not converted.
+  LogicalResult
+  matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform global with dynamic shape");
+    }
+
+    if (op.getAlignment().value_or(1) > 1) {
+      // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "global variable with alignment requirement is "
+                       "currently not supported");
+    }
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    if (!resultTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cannot convert result type");
+    }
+
+    SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
+    if (visibility != SymbolTable::Visibility::Public &&
+        visibility != SymbolTable::Visibility::Private) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(),
+          "only public and private visibility is currently supported");
+    }
+    // We are explicit in specifying the linkage because the default linkage
+    // for constants is different in C and C++.
+    bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
+    bool externSpecifier = !staticSpecifier;
+
+    // `memref.global @g : ... = uninitialized` stores a UnitAttr as its
+    // initial value. `emitc.global` only accepts typed/opaque initializers, so
+    // forwarding the UnitAttr would produce an op that fails verification;
+    // drop it and emit a plain declaration instead.
+    Attribute initialValue = operands.getInitialValueAttr();
+    if (initialValue && isa<UnitAttr>(initialValue))
+      initialValue = {};
+
+    rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
+        op, operands.getSymName(), resultTy, initialValue, externSpecifier,
+        staticSpecifier, operands.getConstant());
+    return success();
+  }
+};
+
+/// Rewrites `memref.get_global` into `emitc.get_global` referring to the same
+/// symbol, with the result type converted by the pattern's type converter.
+struct ConvertGetGlobal final
+    : public OpConversionPattern<memref::GetGlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type convertedType = getTypeConverter()->convertType(op.getType());
+    if (convertedType) {
+      rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, convertedType,
+                                                      operands.getNameAttr());
+      return success();
+    }
+    // The memref type has no EmitC equivalent.
+    return rewriter.notifyMatchFailure(op.getLoc(),
+                                       "cannot convert result type");
+  }
+};
+
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -109,6 +171,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
- patterns.getContext());
+ // Register the alloca, global, get_global, load and store lowerings; all of
+ // them share `converter` for memref-to-EmitC type conversion.
+ patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
+ ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..e269f20578d5f8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -790,13 +790,6 @@ LogicalResult emitc::SubscriptOp::verify() {
return success();
}
-//===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
-
//===----------------------------------------------------------------------===//
// EmitC Enums
//===----------------------------------------------------------------------===//
@@ -896,3 +889,111 @@ LogicalResult mlir::emitc::OpaqueType::verify(
}
return success();
}
+
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+/// Custom printer for `emitc.global`'s `<type> [= <initial value>]` suffix.
+/// The initial value is printed without its type, since the parser derives
+/// it from the global's type.
+static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
+                                                  TypeAttr type,
+                                                  Attribute initialValue) {
+  p << type;
+  if (!initialValue)
+    return;
+  p << " = ";
+  p.printAttributeWithoutType(initialValue);
+}
+
+/// Returns the type an initial-value attribute of a global with type `type`
+/// must have: a ranked tensor with the same shape and element type for
+/// `!emitc.array` globals, and `type` itself otherwise.
+static Type getInitializerTypeForGlobal(Type type) {
+  auto arrayType = llvm::dyn_cast<ArrayType>(type);
+  if (!arrayType)
+    return type;
+  return RankedTensorType::get(arrayType.getShape(),
+                               arrayType.getElementType());
+}
+
+/// Custom parser for `emitc.global`'s `<type> [= <initial value>]` suffix.
+/// The initial value, if present, is parsed against the type returned by
+/// `getInitializerTypeForGlobal(type)`, so array globals take a tensor-typed
+/// `dense<...>` attribute; the parsed attribute kind is then checked below.
+static ParseResult
+parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &initialValue) {
+ Type type;
+ if (parser.parseType(type))
+ return failure();
+
+ typeAttr = TypeAttr::get(type);
+
+ // Without a following `=`, the global has no initial value.
+ if (parser.parseOptionalEqual())
+ return success();
+
+ if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
+ return failure();
+
+ if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
+ return parser.emitError(parser.getNameLoc())
+ << "initial value should be a unit, integer, float or elements "
+ "attribute";
+ return success();
+}
+
+LogicalResult GlobalOp::verify() {
+  if (getInitialValue().has_value()) {
+    Attribute initValue = getInitialValue().value();
+    // Check that the type of the initial value is compatible with the type of
+    // the global variable.
+    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+      // Elements initializers are only valid for array globals and must have
+      // the tensor type derived from the array's shape and element type.
+      auto arrayType = llvm::dyn_cast<ArrayType>(getType());
+      if (!arrayType)
+        return emitOpError("expected array type, but got ") << getType();
+
+      Type initType = elementsAttr.getType();
+      Type tensorType = getInitializerTypeForGlobal(getType());
+      if (initType != tensorType) {
+        // Report the type the attribute has to have (the derived tensor
+        // type), not the array type of the global, which an elements
+        // attribute can never have.
+        return emitOpError("initial value expected to be of type ")
+               << tensorType << ", but was of type " << initType;
+      }
+    } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
+      if (intAttr.getType() != getType()) {
+        return emitOpError("initial value expected to be of type ")
+               << getType() << ", but was of type " << intAttr.getType();
+      }
+    } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
+      if (floatAttr.getType() != getType()) {
+        return emitOpError("initial value expected to be of type ")
+               << getType() << ", but was of type " << floatAttr.getType();
+      }
+    } else {
+      return emitOpError(
+                 "initial value should be a unit, integer, float or elements "
+                 "attribute, but got ")
+             << initValue;
+    }
+  }
+  // Internal (`static`) and external (`extern`) linkage are mutually exclusive.
+  if (getStaticSpecifier() && getExternSpecifier()) {
+    return emitOpError("cannot have both static and extern specifiers");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that `name` resolves to an `emitc.global` and that this op's result
+/// type is exactly the global's type.
+LogicalResult
+GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referenced =
+      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
+  if (!referenced)
+    return emitOpError("'")
+           << getName() << "' does not reference a valid emitc.global";
+
+  Type globalType = referenced.getType();
+  Type resultType = getResult().getType();
+  if (globalType == resultType)
+    return success();
+  return emitOpError("result type ")
+         << resultType << " does not match type " << globalType
+         << " of the global @" << getName();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..820bb65dff0ac9 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -154,6 +154,9 @@ struct CppEmitter {
/// any result type could not be converted.
LogicalResult emitAssignPrefix(Operation &op);
+ /// Emits a global variable declaration or definition.
+ LogicalResult emitGlobalVariable(GlobalOp op);
+
/// Emits a label for the block.
LogicalResult emitLabel(Block &block);
@@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
return printConstantOp(emitter, operation, value);
}
+/// Emits an `emitc.global` by delegating to CppEmitter::emitGlobalVariable.
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::GlobalOp globalOp) {
+  LogicalResult status = emitter.emitGlobalVariable(globalOp);
+  return status;
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::AssignOp assignOp) {
OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
@@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
return emitter.emitOperand(assignOp.getValue());
}
+/// `emitc.get_global` prints no C++ text itself; uses of its result are
+/// printed as the global's name.
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::GetGlobalOp op) {
+  // Register the result in the name cache so that `hasValueInScope` works.
+  (void)emitter.getOrCreateName(op.getResult());
+  return success();
+}
+
static LogicalResult printOperation(CppEmitter &emitter,
emitc::SubscriptOp subscriptOp) {
// Add name to cache so that `hasValueInScope` works.
@@ -1120,6 +1136,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
if (auto subscript =
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
valueMapper.insert(val, getSubscriptName(subscript));
+ } else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
+ val.getDefiningOp())) {
+ valueMapper.insert(val, getGlobal.getName().str());
} else {
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
}
@@ -1385,6 +1404,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return success();
}
+/// Emits `[extern|static] [const] <type> <name>[ = <init>];` for a global.
+/// `extern` is chosen over `static` if both flags are set.
+LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
+  // Linkage specifier.
+  if (op.getExternSpecifier()) {
+    os << "extern ";
+  } else if (op.getStaticSpecifier()) {
+    os << "static ";
+  }
+  // Qualifier.
+  if (op.getConstSpecifier())
+    os << "const ";
+
+  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
+                                     op.getSymName())))
+    return failure();
+
+  // A UnitAttr initial value is treated like having no initializer.
+  Attribute init = op.getInitialValue().value_or(Attribute());
+  if (init && !isa<UnitAttr>(init)) {
+    os << " = ";
+    if (failed(emitAttribute(op->getLoc(), init)))
+      return failure();
+  }
+
+  os << ";";
+  return success();
+}
+
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
// If op is being emitted as part of an expression, bail out.
if (getEmittedExpression())
@@ -1445,11 +1488,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
- emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
- emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
- emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
- emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
- emitc::VerbatimOp>(
+ emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
+ emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
+ emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
+ emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1462,7 +1505,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
if (failed(status))
return failure();
- if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
+ if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
return success();
if (getEmittedExpression() ||
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 390190d341e5ae..89dafa7529ed53 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -38,3 +38,8 @@ func.func @zero_rank() {
%0 = memref.alloca() : memref<f32>
return
}
+
+// -----
+
+// Only public and private globals are converted, so a nested global fails.
+// expected-error@+1 {{failed to legalize operation 'memref.global'}}
+memref.global "nested" constant @nested_global : memref<3x7xf32>
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..54129f4f6cbc8e 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -11,6 +11,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
memref.store %v, %0[%i, %j] : memref<4x8xf32>
return
}
+
// -----
// CHECK-LABEL: memref_load
@@ -26,3 +27,19 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
// CHECK: return %[[VAR]] : f32
return %1 : f32
}
+
+// -----
+
+// Private globals are lowered with `static`, public ones with `extern`, and
+// `constant` memrefs become `const` EmitC globals.
+// CHECK-LABEL: globals
+module @globals {
+ memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
+ // CHECK: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+ memref.global @public_global : memref<3x7xf32>
+ // CHECK: emitc.global extern @public_global : !emitc.array<3x7xf32>
+
+ func.func @use_global() {
+ // CHECK: emitc.get_global @public_global : !emitc.array<3x7xf32>
+ %0 = memref.get_global @public_global : memref<3x7xf32>
+ return
+ }
+}
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..82fa459a5c9270 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -395,3 +395,18 @@ func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2:
%0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
return
}
+
+// -----
+
+// `static` and `extern` are mutually exclusive linkage specifiers.
+// expected-error @+1 {{'emitc.global' op cannot have both static and extern specifiers}}
+emitc.global extern static @uninit : i32
+
+// -----
+
+// The result type of `emitc.get_global` must match the referenced global.
+emitc.global @myglobal : !emitc.array<2xf32>
+
+func.func @use_global() {
+ // expected-error @+1 {{'emitc.get_global' op result type 'f32' does not match type '!emitc.array<2xf32>' of the global @myglobal}}
+ %0 = emitc.get_global @myglobal : f32
+ return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..3c987937f17212 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -224,3 +224,17 @@ emitc.verbatim "#endif // __cplusplus"
emitc.verbatim "typedef int32_t i32;"
emitc.verbatim "typedef float f32;"
+
+
+// Globals: uninitialized, scalar- and array-initialized, extern, static and
+// const variants, plus a function reading one through emitc.get_global.
+emitc.global @uninit : i32
+emitc.global @myglobal_int : i32 = 4
+emitc.global extern @external_linkage : i32
+emitc.global static @internal_linkage : i32
+emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
+emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
+
+func.func @use_global(%i: index) -> f32 {
+ %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
+ %1 = emitc.subscript %0[%i] : <2xf32>, index
+ return %1 : f32
+}
diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir
new file mode 100644
index 00000000000000..730d5e0337336f
--- /dev/null
+++ b/mlir/test/Target/Cpp/global.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+// Each `emitc.global` is translated to a single C++ declaration or definition
+// with its `extern`/`static`/`const` specifiers; values produced by
+// `emitc.get_global` are referenced by the global's name.
+
+emitc.global extern @decl : i8
+// CHECK: extern int8_t decl;
+
+emitc.global @uninit : i32
+// CHECK: int32_t uninit;
+
+emitc.global @myglobal_int : i32 = 4
+// CHECK: int32_t myglobal_int = 4;
+
+emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
+// CHECK: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
+
+emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
+// CHECK: const int16_t myconstant[2] = {2, 2};
+
+emitc.global extern const @extern_constant : !emitc.array<2xi16>
+// CHECK: extern const int16_t extern_constant[2];
+
+emitc.global static @static_var : f32
+// CHECK: static float static_var;
+
+emitc.global static @static_const : f32 = 3.0
+// CHECK: static float static_const = 3.000000000e+00f;
+
+func.func @use_global(%i: index) -> f32 {
+ %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
+ %1 = emitc.subscript %0[%i] : <2xf32>, index
+ return %1 : f32
+ // CHECK-LABEL: use_global
+ // CHECK-SAME: (size_t [[V1:.*]])
+ // CHECK: return myglobal[[[V1]]];
+}
>From d5baf5f6cebb471a5d7c817784e0f15a6ee02865 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 22 Apr 2024 14:57:24 +0200
Subject: [PATCH 2/2] Update to main and address comments
---
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 15 +++++++++------
mlir/test/Dialect/EmitC/ops.mlir | 2 +-
mlir/test/Target/Cpp/global.mlir | 5 ++++-
4 files changed, 15 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 19d23932acb79f..54793bdd2a6ce9 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1079,7 +1079,7 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global",
}];
let arguments = (ins FlatSymbolRefAttr:$name);
- let results = (outs AnyType:$result);
+ let results = (outs EmitCType:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
}
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 760e55637bfb58..ef7b7a19489d4f 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1015,14 +1015,19 @@ parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
return failure();
- if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
+ if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
+ initialValue))
return parser.emitError(parser.getNameLoc())
- << "initial value should be a unit, integer, float or elements "
+ << "initial value should be an integer, float, elements or opaque "
"attribute";
return success();
}
LogicalResult GlobalOp::verify() {
+ // Reject global types that `isSupportedEmitCType` does not accept.
+ if (!isSupportedEmitCType(getType())) {
+ return emitOpError("expected valid emitc type");
+ }
if (getInitialValue().has_value()) {
Attribute initValue = getInitialValue().value();
// Check that the type of the initial value is compatible with the type of
@@ -1048,10 +1053,10 @@ LogicalResult GlobalOp::verify() {
return emitOpError("initial value expected to be of type ")
<< getType() << ", but was of type " << floatAttr.getType();
}
- } else {
- return emitOpError(
- "initial value should be a unit, integer, float or elements "
- "attribute, but got ")
+ } else if (!isa<emitc::OpaqueAttr>(initValue)) {
+ // Opaque initial values fall through without a type check; any other
+ // attribute kind is rejected.
+ return emitOpError("initial value should be an integer, float, elements "
+ "or opaque attribute, but got ")
<< initValue;
}
}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 85379ad1279928..05510e6dddbf59 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -242,6 +242,6 @@ emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
func.func @use_global(%i: index) -> f32 {
%0 = emitc.get_global @myglobal : !emitc.array<2xf32>
- %1 = emitc.subscript %0[%i] : <2xf32>, index
+ %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
return %1 : f32
}
diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir
index 730d5e0337336f..f0d92e862ae322 100644
--- a/mlir/test/Target/Cpp/global.mlir
+++ b/mlir/test/Target/Cpp/global.mlir
@@ -25,9 +25,12 @@ emitc.global static @static_var : f32
emitc.global static @static_const : f32 = 3.0
// CHECK: static float static_const = 3.000000000e+00f;
+emitc.global @opaque_init : !emitc.opaque<"char"> = #emitc.opaque<"CHAR_MIN">
+// CHECK: char opaque_init = CHAR_MIN;
+
func.func @use_global(%i: index) -> f32 {
%0 = emitc.get_global @myglobal : !emitc.array<2xf32>
- %1 = emitc.subscript %0[%i] : <2xf32>, index
+ %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
return %1 : f32
// CHECK-LABEL: use_global
// CHECK-SAME: (size_t [[V1:.*]])
More information about the Mlir-commits
mailing list