[Mlir-commits] [mlir] c254b0b - [MLIR] Introduce std.global_memref and std.get_global_memref operations.

Rahul Joshi llvmlistbot at llvm.org
Mon Nov 2 13:43:32 PST 2020


Author: Rahul Joshi
Date: 2020-11-02T13:43:04-08:00
New Revision: c254b0bb69635c8fb896e27452351234dacac178

URL: https://github.com/llvm/llvm-project/commit/c254b0bb69635c8fb896e27452351234dacac178
DIFF: https://github.com/llvm/llvm-project/commit/c254b0bb69635c8fb896e27452351234dacac178.diff

LOG: [MLIR] Introduce std.global_memref and std.get_global_memref operations.

- Add standard dialect operations to define global variables with memref types and to
  retrieve the memref for a named global variable.
- Extend unit tests to test verification for these operations.

Differential Revision: https://reviews.llvm.org/D90337

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
    mlir/include/mlir/IR/OpImplementation.h
    mlir/lib/Dialect/StandardOps/IR/Ops.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/lib/Parser/Parser.cpp
    mlir/lib/Parser/Parser.h
    mlir/test/Dialect/Standard/invalid.mlir
    mlir/test/Dialect/Standard/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 3b749c232dca..999119f51880 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2005,6 +2005,97 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
   let hasFolder = 0;
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
+  let summary = "declare or define a global memref variable";
+  let description = [{
+    The `global_memref` operation declares or defines a named global variable.
+    The backing memory for the variable is allocated statically and is described
+    by the type of the variable (which should be a statically shaped memref
+    type). The operation is a declaration if no `initial_value` is specified,
+    else it is a definition. The `initial_value` can either be a unit attribute
+    to represent a definition of an uninitialized global variable, or an
+    elements attribute to represent the definition of a global variable with an
+    initial value. The global variable can also be marked constant using the
+    `constant` unit attribute. Writing to such constant global variables is
+    undefined.
+
+    The global variable can be accessed by using the `get_global_memref` to
+    retrieve the memref for the global variable. Note that the memref
+    for such global variable itself is immutable (i.e., get_global_memref for a
+    given global variable will always return the same memref descriptor).
+
+    Example:
+
+    ```mlir
+    // Private variable with an initial value.
+    global_memref @x : memref<2xf32> { sym_visibility = "private",
+                         initial_value = dense<0.0,2.0> : tensor<2xf32> }
+
+    // External variable.
+    global_memref @y : memref<4xi32> { sym_visibility = "public" }
+
+    // Uninitialized externally visible variable.
+    global_memref @z : memref<3xf16> { sym_visibility = "public",
+                         initial_value }
+    ```
+  }];
+
+  let arguments = (ins
+      SymbolNameAttr:$sym_name,
+      OptionalAttr<StrAttr>:$sym_visibility,
+      TypeAttr:$type,
+      OptionalAttr<AnyAttr>:$initial_value,
+      UnitAttr:$constant
+  );
+
+  let assemblyFormat = [{
+       ($sym_visibility^)?
+       (`constant` $constant^)?
+       $sym_name `:`
+       custom<GlobalMemrefOpTypeAndInitialValue>($type, $initial_value)
+       attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+     bool isExternal() { return !initial_value(); }
+     bool isUnitialized() {
+       return !isExternal() && initial_value().getValue().isa<UnitAttr>();
+     }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GetGlobalMemrefOp : Std_Op<"get_global_memref",
+    [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "get the memref pointing to a global variable";
+  let description = [{
+     The `get_global_memref` operation retrieves the memref pointing to a
+     named global variable. If the global variable is marked constant, writing
+     to the result memref (such as through a `std.store` operation) is
+     undefined.
+
+     Example:
+
+     ```mlir
+     %x = get_global_memref @foo : memref<2xf32>
+     ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$name);
+  let results = (outs AnyStaticShapeMemRef:$result);
+  let assemblyFormat = "$name `:` type($result) attr-dict";
+
+  // `GetGlobalMemrefOp` is fully verified by its traits.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // ImOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a4e222257e94..0813cde0256d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -395,7 +395,7 @@ class OpAsmParser {
 
     // Parse any kind of attribute.
     Attribute attr;
-    if (parseAttribute(attr))
+    if (parseAttribute(attr, type))
       return failure();
 
     // Check for the right kind of attribute.
@@ -436,6 +436,10 @@ class OpAsmParser {
                                                      Type type,
                                                      StringRef attrName,
                                                      NamedAttrList &attrs) = 0;
+  virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
+                                                     Type type,
+                                                     StringRef attrName,
+                                                     NamedAttrList &attrs) = 0;
 
   /// Parse an arbitrary attribute of a given type and return it in result. This
   /// also adds the attribute to the specified attribute list with the specified

diff  --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5d4d242c2a21..666955c3c4c7 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -245,6 +245,18 @@ static bool areVectorCastSimpleCompatible(
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+static Type getTensorTypeFromMemRefType(Type type) {
+  if (auto memref = type.dyn_cast<MemRefType>())
+    return RankedTensorType::get(memref.getShape(), memref.getElementType());
+  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
+    return UnrankedTensorType::get(memref.getElementType());
+  return NoneType::get(type.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -2140,6 +2152,106 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
+                                                   GlobalMemrefOp op,
+                                                   TypeAttr type,
+                                                   Attribute initialValue) {
+  p << type;
+  if (!op.isExternal()) {
+    p << " = ";
+    if (op.isUnitialized())
+      p << "uninitialized";
+    else
+      p.printAttributeWithoutType(initialValue);
+  }
+}
+
+static ParseResult
+parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+                                       Attribute &initialValue) {
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+
+  auto memrefType = type.dyn_cast<MemRefType>();
+  if (!memrefType || !memrefType.hasStaticShape())
+    return parser.emitError(parser.getNameLoc())
+           << "type should be static shaped memref, but got " << type;
+  typeAttr = TypeAttr::get(type);
+
+  if (parser.parseOptionalEqual())
+    return success();
+
+  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
+    initialValue = UnitAttr::get(parser.getBuilder().getContext());
+    return success();
+  }
+
+  Type tensorType = getTensorTypeFromMemRefType(memrefType);
+  if (parser.parseAttribute(initialValue, tensorType))
+    return failure();
+  if (!initialValue.isa<ElementsAttr>())
+    return parser.emitError(parser.getNameLoc())
+           << "initial value should be a unit or elements attribute";
+  return success();
+}
+
+static LogicalResult verify(GlobalMemrefOp op) {
+  auto memrefType = op.type().dyn_cast<MemRefType>();
+  if (!memrefType || !memrefType.hasStaticShape())
+    return op.emitOpError("type should be static shaped memref, but got ")
+           << op.type();
+
+  // Verify that the initial value, if present, is either a unit attribute or
+  // an elements attribute.
+  if (op.initial_value().hasValue()) {
+    Attribute initValue = op.initial_value().getValue();
+    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
+      return op.emitOpError("initial value should be a unit or elements "
+                            "attribute, but got ")
+             << initValue;
+
+    // Check that the type of the initial value is compatible with the type of
+    // the global variable.
+    if (initValue.isa<ElementsAttr>()) {
+      Type initType = initValue.getType();
+      Type tensorType = getTensorTypeFromMemRefType(memrefType);
+      if (initType != tensorType)
+        return op.emitOpError("initial value expected to be of type ")
+               << tensorType << ", but was of type " << initType;
+    }
+  }
+
+  // TODO: verify visibility for declarations.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Verify that the result type is same as the type of the referenced
+  // global_memref op.
+  auto global =
+      symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
+  if (!global)
+    return emitOpError("'")
+           << name() << "' does not reference a valid global memref";
+
+  Type resultType = result().getType();
+  if (global.type() != resultType)
+    return emitOpError("result type ")
+           << resultType << " does not match type " << global.type()
+           << " of the global memref @" << name();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
@@ -3891,18 +4003,6 @@ void TensorCastOp::getCanonicalizationPatterns(
   results.insert<ChainedTensorCast>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
-//===----------------------------------------------------------------------===//
-
-static Type getTensorTypeFromMemRefType(Type type) {
-  if (auto memref = type.dyn_cast<MemRefType>())
-    return RankedTensorType::get(memref.getShape(), memref.getElementType());
-  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
-    return UnrankedTensorType::get(memref.getElementType());
-  return NoneType::get(type.getContext());
-}
-
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index ef060d481f2a..6234d3b59660 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -226,6 +226,10 @@ OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
                                                    Type type) {
   return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
 }
+OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
+                                                   Type type) {
+  return parseOptionalAttributeWithToken(Token::string, attribute, type);
+}
 
 /// Attribute dictionary.
 ///
@@ -807,6 +811,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
 
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
+  auto attribLoc = getToken().getLoc();
   consumeToken(Token::kw_dense);
   if (parseToken(Token::less, "expected '<' after 'dense'"))
     return nullptr;
@@ -819,11 +824,14 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
       return nullptr;
   }
 
-  auto typeLoc = getToken().getLoc();
+  // If the type is specified `parseElementsLiteralType` will not parse a type.
+  // Use the attribute location as the location for error reporting in that
+  // case.
+  auto loc = attrType ? attribLoc : getToken().getLoc();
   auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
-  return literalParser.getAttr(typeLoc, type);
+  return literalParser.getAttr(loc, type);
 }
 
 /// Parse an opaque elements attribute.

diff  --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 48651a98561c..6d04f97f4576 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1065,6 +1065,11 @@ class CustomOpAsmParser : public OpAsmParser {
                                              NamedAttrList &attrs) override {
     return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
   }
+  OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
+                                             StringRef attrName,
+                                             NamedAttrList &attrs) override {
+    return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+  }
 
   /// Parse a named dictionary into 'result' if it is present.
   ParseResult parseOptionalAttrDict(NamedAttrList &result) override {

diff  --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index c01a0a004072..9ffaea051f66 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -188,6 +188,7 @@ class Parser {
   OptionalParseResult parseOptionalAttribute(Attribute &attribute,
                                              Type type = {});
   OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
+  OptionalParseResult parseOptionalAttribute(StringAttr &attribute, Type type);
 
   /// Parse an optional attribute that is demarcated by a specific token.
   template <typename AttributeT>

diff  --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index ea42028781ab..9ee69124c6ef 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -231,3 +231,84 @@ func @memref_reshape_result_affine_map_is_not_identity(
   memref_reshape %buf(%shape)
     : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
 }
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32 = 5
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<*xf32>
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<?x?xf32>
+
+// -----
+
+// expected-error @+1 {{initial value should be a unit or elements attribute}}
+global_memref @foo : memref<2x2xf32>  = "foo"
+
+// -----
+
+// expected-error @+1 {{inferred shape of elements literal ([2]) does not match type ([2, 2])}}
+global_memref @foo : memref<2x2xf32> = dense<[0.0, 1.0]>
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref "private" "public" @foo : memref<2x2xf32>  = "foo"
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref constant external @foo : memref<2x2xf32>  = "foo"
+
+// -----
+
+// constant qualifier must be after visibility.
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref constant "private" @foo : memref<2x2xf32>  = "foo"
+
+
+// -----
+
+// expected-error @+1 {{op visibility expected to be one of ["public", "private", "nested"], but got "priate"}}
+global_memref "priate" constant @memref5 : memref<2xf32>  = uninitialized
+
+// -----
+
+func @nonexistent_global_memref() {
+  // expected-error @+1 {{'gv' does not reference a valid global memref}}
+  %0 = get_global_memref @gv : memref<3xf32>
+  return
+}
+
+// -----
+
+func @foo()
+
+func @nonexistent_global_memref() {
+  // expected-error @+1 {{'foo' does not reference a valid global memref}}
+  %0 = get_global_memref @foo : memref<3xf32>
+  return
+}
+
+// -----
+
+global_memref @gv : memref<3xi32>
+
+func @mismatched_types() {
+  // expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}}
+  %0 = get_global_memref @gv : memref<3xf32>
+  return
+}
+

diff  --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index fcf2325fd2e8..cd173670ae54 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -77,3 +77,34 @@ func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
                : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
   return %new_unranked : memref<*xf32>
 }
+
+// CHECK-LABEL: global_memref @memref0 : memref<2xf32>
+global_memref @memref0 : memref<2xf32>
+
+// CHECK-LABEL: global_memref constant @memref1 : memref<2xf32> = dense<[0.000000e+00, 1.000000e+00]>
+global_memref constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>
+
+// CHECK-LABEL: global_memref @memref2 : memref<2xf32> = uninitialized
+global_memref @memref2 : memref<2xf32>  = uninitialized
+
+// CHECK-LABEL: global_memref "private" @memref3 : memref<2xf32> = uninitialized
+global_memref "private" @memref3 : memref<2xf32>  = uninitialized
+
+// CHECK-LABEL: global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
+global_memref "private" constant @memref4 : memref<2xf32>  = uninitialized
+
+// CHECK-LABEL: func @write_global_memref
+func @write_global_memref() {
+  %0 = get_global_memref @memref0 : memref<2xf32>
+  %1 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+  tensor_store %1, %0 : memref<2xf32>
+  return
+}
+
+// CHECK-LABEL: func @read_global_memref
+func @read_global_memref() {
+  %0 = get_global_memref @memref0 : memref<2xf32>
+  %1 = tensor_load %0 : memref<2xf32>
+  return
+}
+


        


More information about the Mlir-commits mailing list