[Mlir-commits] [mlir] EmitC: Add emitc.global and emitc.get_global (#145) (PR #88701)

Matthias Gehre llvmlistbot at llvm.org
Mon Apr 22 05:58:55 PDT 2024


https://github.com/mgehre-amd updated https://github.com/llvm/llvm-project/pull/88701

>From a711c683e5e464fd6af1d7ac260bcc809d9498da Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 25 Mar 2024 16:50:17 +0100
Subject: [PATCH 1/3] EmitC: Add emitc.global and emitc.get_global (#145)

This adds
- `emitc.global` and `emitc.get_global` ops to model global variables
similar to how `memref.global` and `memref.get_global` work.
- translation of those ops to C++
- lowering of `memref.global` and `memref.get_global` into those ops
---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  69 +++++++++++
 .../MemRefToEmitC/MemRefToEmitC.cpp           |  66 +++++++++-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 115 ++++++++++++++++--
 mlir/lib/Target/Cpp/TranslateToCpp.cpp        |  55 ++++++++-
 .../MemRefToEmitC/memref-to-emitc-failed.mlir |   5 +
 .../MemRefToEmitC/memref-to-emitc.mlir        |  17 +++
 mlir/test/Dialect/EmitC/invalid_ops.mlir      |  15 +++
 mlir/test/Dialect/EmitC/ops.mlir              |  14 +++
 mlir/test/Target/Cpp/global.mlir              |  35 ++++++
 9 files changed, 376 insertions(+), 15 deletions(-)
 create mode 100644 mlir/test/Target/Cpp/global.mlir

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index d746222ff37a4b..ee5fc0b09a1611 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1016,6 +1016,75 @@ def EmitC_VariableOp : EmitC_Op<"variable", []> {
   let hasVerifier = 1;
 }
 
+def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
+  let summary = "A global variable";
+  let description = [{
+    The `emitc.global` operation declares or defines a named global variable.
+    The backing memory for the variable is allocated statically and is
+    described by the type of the variable.
+    Optionally, and `initial_value` can be provided.
+    Internal linkage can be specified using the `staticSpecifier` unit attribute
+    and external linkage can be specified using the `externSpecifier` unit attribute.
+    Note that the default linkage without those two keywords depends on whether
+    the target is C or C++ and whether the global variable is `const`.
+    The global variable can also be marked constant using the `constSpecifier`
+    unit attribute. Writing to such constant global variables is
+    undefined.
+
+    The global variable can be accessed by using the `emitc.get_global`
+    operation to retrieve the value of the global variable.
+
+    Example:
+
+    ```mlir
+    // Global variable with an initial value.
+    emitc.global @x : !emitc.array<2xf32> = dense<[0.0, 2.0]>
+    // External global variable
+    emitc.global extern @x : !emitc.array<2xf32>
+    // Constant global variable with internal linkage
+    emitc.global static const @x : i32 = 0
+    ```
+  }];
+
+  let arguments = (ins SymbolNameAttr:$sym_name,
+                       TypeAttr:$type,
+                       OptionalAttr<EmitC_OpaqueOrTypedAttr>:$initial_value,
+                       UnitAttr:$externSpecifier,
+                       UnitAttr:$staticSpecifier,
+                       UnitAttr:$constSpecifier);
+
+  let assemblyFormat = [{
+       (`extern` $externSpecifier^)?
+       (`static` $staticSpecifier^)?
+       (`const` $constSpecifier^)?
+       $sym_name
+       `:` custom<EmitCGlobalOpTypeAndInitialValue>($type, $initial_value)
+       attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
+def EmitC_GetGlobalOp : EmitC_Op<"get_global",
+    [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "Obtain access to a global variable";
+  let description = [{
+     The `emitc.get_global` operation retrieves the lvalue of a
+     named global variable. If the global variable is marked constant, assigning
+     to that lvalue is undefined.
+
+     Example:
+
+     ```mlir
+     %x = emitc.get_global @foo : !emitc.array<2xf32>
+     ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$name);
+  let results = (outs AnyType:$result);
+  let assemblyFormat = "$name `:` type($result) attr-dict";
+}
+
 def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
   let summary = "Verbatim operation";
   let description = [{
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b6469212640..d3e7f233c08412 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -50,6 +50,68 @@ struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
   }
 };
 
+struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::GlobalOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getType().hasStaticShape()) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "cannot transform global with dynamic shape");
+    }
+
+    if (op.getAlignment().value_or(1) > 1) {
+      // TODO: Extend GlobalOp to specify alignment via the `alignas` specifier.
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "global variable with alignment requirement is "
+                       "currently not supported");
+    }
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    if (!resultTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cannot convert result type");
+    }
+
+    SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
+    if (visibility != SymbolTable::Visibility::Public &&
+        visibility != SymbolTable::Visibility::Private) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(),
+          "only public and private visibility is currently supported");
+    }
+    // We are explicit in specifying the linkage because the default linkage
+    // for constants is different in C and C++.
+    bool staticSpecifier = visibility == SymbolTable::Visibility::Private;
+    bool externSpecifier = !staticSpecifier;
+
+    rewriter.replaceOpWithNewOp<emitc::GlobalOp>(
+        op, operands.getSymName(), resultTy, operands.getInitialValueAttr(),
+        externSpecifier, staticSpecifier, operands.getConstant());
+    return success();
+  }
+};
+
+struct ConvertGetGlobal final
+    : public OpConversionPattern<memref::GetGlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto resultTy = getTypeConverter()->convertType(op.getType());
+    if (!resultTy) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "cannot convert result type");
+    }
+    rewriter.replaceOpWithNewOp<emitc::GetGlobalOp>(op, resultTy,
+                                                    operands.getNameAttr());
+    return success();
+  }
+};
+
 struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -109,6 +171,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
 
 void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
                                                    TypeConverter &converter) {
-  patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
-                                                         patterns.getContext());
+  patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
+               ConvertStore>(converter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index ab5c418e844fbf..e269f20578d5f8 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -790,13 +790,6 @@ LogicalResult emitc::SubscriptOp::verify() {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// TableGen'd op method definitions
-//===----------------------------------------------------------------------===//
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
-
 //===----------------------------------------------------------------------===//
 // EmitC Enums
 //===----------------------------------------------------------------------===//
@@ -896,3 +889,111 @@ LogicalResult mlir::emitc::OpaqueType::verify(
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
+                                                  TypeAttr type,
+                                                  Attribute initialValue) {
+  p << type;
+  if (initialValue) {
+    p << " = ";
+    p.printAttributeWithoutType(initialValue);
+  }
+}
+
+static Type getInitializerTypeForGlobal(Type type) {
+  if (auto array = llvm::dyn_cast<ArrayType>(type))
+    return RankedTensorType::get(array.getShape(), array.getElementType());
+  return type;
+}
+
+static ParseResult
+parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+                                      Attribute &initialValue) {
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+
+  typeAttr = TypeAttr::get(type);
+
+  if (parser.parseOptionalEqual())
+    return success();
+
+  if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
+    return failure();
+
+  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
+    return parser.emitError(parser.getNameLoc())
+           << "initial value should be a unit, integer, float or elements "
+              "attribute";
+  return success();
+}
+
+LogicalResult GlobalOp::verify() {
+  if (getInitialValue().has_value()) {
+    Attribute initValue = getInitialValue().value();
+    // Check that the type of the initial value is compatible with the type of
+    // the global variable.
+    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
+      auto arrayType = llvm::dyn_cast<ArrayType>(getType());
+      if (!arrayType)
+        return emitOpError("expected array type, but got ") << getType();
+
+      Type initType = elementsAttr.getType();
+      Type tensorType = getInitializerTypeForGlobal(getType());
+      if (initType != tensorType) {
+        return emitOpError("initial value expected to be of type ")
+               << getType() << ", but was of type " << initType;
+      }
+    } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
+      if (intAttr.getType() != getType()) {
+        return emitOpError("initial value expected to be of type ")
+               << getType() << ", but was of type " << intAttr.getType();
+      }
+    } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
+      if (floatAttr.getType() != getType()) {
+        return emitOpError("initial value expected to be of type ")
+               << getType() << ", but was of type " << floatAttr.getType();
+      }
+    } else {
+      return emitOpError(
+                 "initial value should be a unit, integer, float or elements "
+                 "attribute, but got ")
+             << initValue;
+    }
+  }
+  if (getStaticSpecifier() && getExternSpecifier()) {
+    return emitOpError("cannot have both static and extern specifiers");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Verify that the type matches the type of the global variable.
+  auto global =
+      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
+  if (!global)
+    return emitOpError("'")
+           << getName() << "' does not reference a valid emitc.global";
+
+  Type resultType = getResult().getType();
+  if (global.getType() != resultType)
+    return emitOpError("result type ")
+           << resultType << " does not match type " << global.getType()
+           << " of the global @" << getName();
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 95c7af2f07be46..820bb65dff0ac9 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -154,6 +154,9 @@ struct CppEmitter {
   /// any result type could not be converted.
   LogicalResult emitAssignPrefix(Operation &op);
 
+  /// Emits a global variable declaration or definition.
+  LogicalResult emitGlobalVariable(GlobalOp op);
+
   /// Emits a label for the block.
   LogicalResult emitLabel(Block &block);
 
@@ -344,6 +347,12 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return printConstantOp(emitter, operation, value);
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::GlobalOp globalOp) {
+
+  return emitter.emitGlobalVariable(globalOp);
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::AssignOp assignOp) {
   OpResult result = assignOp.getVar().getDefiningOp()->getResult(0);
@@ -354,6 +363,13 @@ static LogicalResult printOperation(CppEmitter &emitter,
   return emitter.emitOperand(assignOp.getValue());
 }
 
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    emitc::GetGlobalOp op) {
+  // Add name to cache so that `hasValueInScope` works.
+  emitter.getOrCreateName(op.getResult());
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     emitc::SubscriptOp subscriptOp) {
   // Add name to cache so that `hasValueInScope` works.
@@ -1120,6 +1136,9 @@ StringRef CppEmitter::getOrCreateName(Value val) {
     if (auto subscript =
             dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
       valueMapper.insert(val, getSubscriptName(subscript));
+    } else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
+                   val.getDefiningOp())) {
+      valueMapper.insert(val, getGlobal.getName().str());
     } else {
       valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
     }
@@ -1385,6 +1404,30 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
   return success();
 }
 
+LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
+  if (op.getExternSpecifier())
+    os << "extern ";
+  else if (op.getStaticSpecifier())
+    os << "static ";
+  if (op.getConstSpecifier())
+    os << "const ";
+
+  if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
+                                     op.getSymName()))) {
+    return failure();
+  }
+
+  std::optional<Attribute> initialValue = op.getInitialValue();
+  if (initialValue && !isa<UnitAttr>(*initialValue)) {
+    os << " = ";
+    if (failed(emitAttribute(op->getLoc(), *initialValue)))
+      return failure();
+  }
+
+  os << ";";
+  return success();
+}
+
 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
   // If op is being emitted as part of an expression, bail out.
   if (getEmittedExpression())
@@ -1445,11 +1488,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
                 emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
                 emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
                 emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
-                emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
-                emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
-                emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
-                emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
-                emitc::VerbatimOp>(
+                emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
+                emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
+                emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
+                emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
+                emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
               [&](auto op) { return printOperation(*this, op); })
           // Func ops.
           .Case<func::CallOp, func::FuncOp, func::ReturnOp>(
@@ -1462,7 +1505,7 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
   if (failed(status))
     return failure();
 
-  if (isa<emitc::LiteralOp, emitc::SubscriptOp>(op))
+  if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
     return success();
 
   if (getEmittedExpression() ||
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
index 390190d341e5ae..89dafa7529ed53 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir
@@ -38,3 +38,8 @@ func.func @zero_rank() {
   %0 = memref.alloca() : memref<f32>
   return
 }
+
+// -----
+
+// expected-error@+1 {{failed to legalize operation 'memref.global'}}
+memref.global "nested" constant @nested_global : memref<3x7xf32>
diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
index 9793b2d6d7832f..54129f4f6cbc8e 100644
--- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
+++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir
@@ -11,6 +11,7 @@ func.func @memref_store(%v : f32, %i: index, %j: index) {
   memref.store %v, %0[%i, %j] : memref<4x8xf32>
   return
 }
+
 // -----
 
 // CHECK-LABEL: memref_load
@@ -26,3 +27,19 @@ func.func @memref_load(%i: index, %j: index) -> f32 {
   // CHECK: return %[[VAR]] : f32
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: globals
+module @globals {
+  memref.global "private" constant @internal_global : memref<3x7xf32> = dense<4.0>
+  // CHECK: emitc.global static const @internal_global : !emitc.array<3x7xf32> = dense<4.000000e+00>
+  memref.global @public_global : memref<3x7xf32>
+  // CHECK: emitc.global extern @public_global : !emitc.array<3x7xf32>
+
+  func.func @use_global() {
+    // CHECK: emitc.get_global @public_global : !emitc.array<3x7xf32>
+    %0 = memref.get_global @public_global : memref<3x7xf32>
+    return
+  }
+}
diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir
index 22423cf61b5556..82fa459a5c9270 100644
--- a/mlir/test/Dialect/EmitC/invalid_ops.mlir
+++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir
@@ -395,3 +395,18 @@ func.func @test_subscript_indices_mismatch(%arg0: !emitc.array<4x8xf32>, %arg2:
   %0 = emitc.subscript %arg0[%arg2] : <4x8xf32>, index
   return
 }
+
+// -----
+
+// expected-error @+1 {{'emitc.global' op cannot have both static and extern specifiers}}
+emitc.global extern static @uninit : i32
+
+// -----
+
+emitc.global @myglobal : !emitc.array<2xf32>
+
+func.func @use_global() {
+  // expected-error @+1 {{'emitc.get_global' op result type 'f32' does not match type '!emitc.array<2xf32>' of the global @myglobal}}
+  %0 = emitc.get_global @myglobal : f32
+  return
+}
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 5f00a295ed740e..3c987937f17212 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -224,3 +224,17 @@ emitc.verbatim "#endif  // __cplusplus"
 
 emitc.verbatim "typedef int32_t i32;"
 emitc.verbatim "typedef float f32;"
+
+
+emitc.global @uninit : i32
+emitc.global @myglobal_int : i32 = 4
+emitc.global extern @external_linkage : i32
+emitc.global static @internal_linkage : i32
+emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
+emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
+
+func.func @use_global(%i: index) -> f32 {
+  %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
+  %1 = emitc.subscript %0[%i] : <2xf32>, index
+  return %1 : f32
+}
diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir
new file mode 100644
index 00000000000000..730d5e0337336f
--- /dev/null
+++ b/mlir/test/Target/Cpp/global.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
+// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
+
+emitc.global extern @decl : i8
+// CHECK: extern int8_t decl;
+
+emitc.global @uninit : i32
+// CHECK: int32_t uninit;
+
+emitc.global @myglobal_int : i32 = 4
+// CHECK: int32_t myglobal_int = 4;
+
+emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
+// CHECK: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
+
+emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
+// CHECK: const int16_t myconstant[2] = {2, 2};
+
+emitc.global extern const @extern_constant : !emitc.array<2xi16>
+// CHECK: extern const int16_t extern_constant[2];
+
+emitc.global static @static_var : f32
+// CHECK: static float static_var;
+
+emitc.global static @static_const : f32 = 3.0
+// CHECK: static float static_const = 3.000000000e+00f;
+
+func.func @use_global(%i: index) -> f32 {
+  %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
+  %1 = emitc.subscript %0[%i] : <2xf32>, index
+  return %1 : f32
+  // CHECK-LABEL: use_global
+  // CHECK-SAME: (size_t [[V1:.*]])
+  // CHECK:   return myglobal[[[V1]]];
+}

>From d5baf5f6cebb471a5d7c817784e0f15a6ee02865 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 22 Apr 2024 14:57:24 +0200
Subject: [PATCH 2/3] Update to main and address comments

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td |  2 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp         | 15 +++++++++------
 mlir/test/Dialect/EmitC/ops.mlir            |  2 +-
 mlir/test/Target/Cpp/global.mlir            |  5 ++++-
 4 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 19d23932acb79f..54793bdd2a6ce9 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1079,7 +1079,7 @@ def EmitC_GetGlobalOp : EmitC_Op<"get_global",
   }];
 
   let arguments = (ins FlatSymbolRefAttr:$name);
-  let results = (outs AnyType:$result);
+  let results = (outs EmitCType:$result);
   let assemblyFormat = "$name `:` type($result) attr-dict";
 }
 
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 760e55637bfb58..ef7b7a19489d4f 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1015,14 +1015,18 @@ parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
   if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
     return failure();
 
-  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr>(initialValue))
+  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
+          initialValue))
     return parser.emitError(parser.getNameLoc())
-           << "initial value should be a unit, integer, float or elements "
+           << "initial value should be a integer, float, elements or opaque "
               "attribute";
   return success();
 }
 
 LogicalResult GlobalOp::verify() {
+  if (!isSupportedEmitCType(getType())) {
+    return emitOpError("expected valid emitc type");
+  }
   if (getInitialValue().has_value()) {
     Attribute initValue = getInitialValue().value();
     // Check that the type of the initial value is compatible with the type of
@@ -1048,10 +1052,9 @@ LogicalResult GlobalOp::verify() {
         return emitOpError("initial value expected to be of type ")
                << getType() << ", but was of type " << floatAttr.getType();
       }
-    } else {
-      return emitOpError(
-                 "initial value should be a unit, integer, float or elements "
-                 "attribute, but got ")
+    } else if (!isa<emitc::OpaqueAttr>(initValue)) {
+      return emitOpError("initial value should be a integer, float, elements "
+                         "or opaque attribute, but got ")
              << initValue;
     }
   }
diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir
index 85379ad1279928..05510e6dddbf59 100644
--- a/mlir/test/Dialect/EmitC/ops.mlir
+++ b/mlir/test/Dialect/EmitC/ops.mlir
@@ -242,6 +242,6 @@ emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
 
 func.func @use_global(%i: index) -> f32 {
   %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
-  %1 = emitc.subscript %0[%i] : <2xf32>, index
+  %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
   return %1 : f32
 }
diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir
index 730d5e0337336f..f0d92e862ae322 100644
--- a/mlir/test/Target/Cpp/global.mlir
+++ b/mlir/test/Target/Cpp/global.mlir
@@ -25,9 +25,12 @@ emitc.global static @static_var : f32
 emitc.global static @static_const : f32 = 3.0
 // CHECK: static float static_const = 3.000000000e+00f;
 
+emitc.global @opaque_init : !emitc.opaque<"char"> = #emitc.opaque<"CHAR_MIN">
+// CHECK: char opaque_init = CHAR_MIN;
+
 func.func @use_global(%i: index) -> f32 {
   %0 = emitc.get_global @myglobal : !emitc.array<2xf32>
-  %1 = emitc.subscript %0[%i] : <2xf32>, index
+  %1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
   return %1 : f32
   // CHECK-LABEL: use_global
   // CHECK-SAME: (size_t [[V1:.*]])

>From c079302d3ae1940209027b7cdaefe655f7a6760f Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre at amd.com>
Date: Mon, 22 Apr 2024 14:58:39 +0200
Subject: [PATCH 3/3] typo

---
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 54793bdd2a6ce9..0e8035f3c15b68 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1020,7 +1020,7 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> {
     The `emitc.global` operation declares or defines a named global variable.
     The backing memory for the variable is allocated statically and is
     described by the type of the variable.
-    Optionally, and `initial_value` can be provided.
+    Optionally, an `initial_value` can be provided.
     Internal linkage can be specified using the `staticSpecifier` unit attribute
     and external linkage can be specified using the `externSpecifier` unit attribute.
     Note that the default linkage without those two keywords depends on whether



More information about the Mlir-commits mailing list