[Mlir-commits] [mlir] c254b0b - [MLIR] Introduce std.global_memref and std.get_global_memref operations.
Rahul Joshi
llvmlistbot at llvm.org
Mon Nov 2 13:43:32 PST 2020
Author: Rahul Joshi
Date: 2020-11-02T13:43:04-08:00
New Revision: c254b0bb69635c8fb896e27452351234dacac178
URL: https://github.com/llvm/llvm-project/commit/c254b0bb69635c8fb896e27452351234dacac178
DIFF: https://github.com/llvm/llvm-project/commit/c254b0bb69635c8fb896e27452351234dacac178.diff
LOG: [MLIR] Introduce std.global_memref and std.get_global_memref operations.
- Add standard dialect operations to define global variables with memref types and to
retrieve the memref for a named global variable.
- Extend unit tests to test verification for these operations.
Differential Revision: https://reviews.llvm.org/D90337
Added:
Modified:
mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/Dialect/StandardOps/IR/Ops.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Parser/Parser.h
mlir/test/Dialect/Standard/invalid.mlir
mlir/test/Dialect/Standard/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
index 3b749c232dca..999119f51880 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2005,6 +2005,97 @@ def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> {
let hasFolder = 0;
}
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
+ let summary = "declare or define a global memref variable";
+ let description = [{
+ The `global_memref` operation declares or defines a named global variable.
+ The backing memory for the variable is allocated statically and is described
+ by the type of the variable (which should be a statically shaped memref
+ type). The operation is a declaration if no `initial_value` is specified,
+ else it is a definition. The `initial_value` can either be a unit attribute
+ to represent a definition of an uninitialized global variable, or an
+ elements attribute to represent the definition of a global variable with an
+ initial value. The global variable can also be marked constant using the
+ `constant` unit attribute. Writing to such constant global variables is
+ undefined.
+
+ The global variable can be accessed by using the `get_global_memref` op to
+ retrieve the memref for the global variable. Note that the memref for such a
+ global variable itself is immutable (i.e., `get_global_memref` for a given
+ global variable will always return the same memref descriptor).
+
+ Example:
+
+ ```mlir
+ // Private variable with an initial value.
+ global_memref "private" @x : memref<2xf32> =
+ dense<[0.0, 2.0]>
+
+ // External variable.
+ global_memref "public" @y : memref<4xi32>
+
+ // Uninitialized externally visible variable.
+ global_memref "public" @z : memref<3xf16> =
+ uninitialized
+ ```
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$sym_name,
+ OptionalAttr<StrAttr>:$sym_visibility,
+ TypeAttr:$type,
+ OptionalAttr<AnyAttr>:$initial_value,
+ UnitAttr:$constant
+ );
+
+ let assemblyFormat = [{
+ ($sym_visibility^)?
+ (`constant` $constant^)?
+ $sym_name `:`
+ custom<GlobalMemrefOpTypeAndInitialValue>($type, $initial_value)
+ attr-dict
+ }];
+
+ let extraClassDeclaration = [{
+ bool isExternal() { return !initial_value(); }
+ bool isUnitialized() {
+ return !isExternal() && initial_value().getValue().isa<UnitAttr>();
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GetGlobalMemrefOp : Std_Op<"get_global_memref",
+ [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "get the memref pointing to a global variable";
+ let description = [{
+ The `get_global_memref` operation retrieves the memref pointing to a
+ named global variable. If the global variable is marked constant, writing
+ to the result memref (such as through a `std.store` operation) is
+ undefined.
+
+ Example:
+
+ ```mlir
+ %x = get_global_memref @foo : memref<2xf32>
+ ```
+ }];
+
+ let arguments = (ins FlatSymbolRefAttr:$name);
+ let results = (outs AnyStaticShapeMemRef:$result);
+ let assemblyFormat = "$name `:` type($result) attr-dict";
+
+ // `GetGlobalMemrefOp` is fully verified by its traits.
+ let verifier = ?;
+}
+
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a4e222257e94..0813cde0256d 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -395,7 +395,7 @@ class OpAsmParser {
// Parse any kind of attribute.
Attribute attr;
- if (parseAttribute(attr))
+ if (parseAttribute(attr, type))
return failure();
// Check for the right kind of attribute.
@@ -436,6 +436,10 @@ class OpAsmParser {
Type type,
StringRef attrName,
NamedAttrList &attrs) = 0;
+ virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
+ Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) = 0;
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
index 5d4d242c2a21..666955c3c4c7 100644
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -245,6 +245,18 @@ static bool areVectorCastSimpleCompatible(
return false;
}
+//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+static Type getTensorTypeFromMemRefType(Type type) {
+ if (auto memref = type.dyn_cast<MemRefType>())
+ return RankedTensorType::get(memref.getShape(), memref.getElementType());
+ if (auto memref = type.dyn_cast<UnrankedMemRefType>())
+ return UnrankedTensorType::get(memref.getElementType());
+ return NoneType::get(type.getContext());
+}
+
//===----------------------------------------------------------------------===//
// AddFOp
//===----------------------------------------------------------------------===//
@@ -2140,6 +2152,106 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
return areVectorCastSimpleCompatible(a, b, areCastCompatible);
}
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
+ GlobalMemrefOp op,
+ TypeAttr type,
+ Attribute initialValue) {
+ p << type;
+ if (!op.isExternal()) {
+ p << " = ";
+ if (op.isUnitialized())
+ p << "uninitialized";
+ else
+ p.printAttributeWithoutType(initialValue);
+ }
+}
+
+static ParseResult
+parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &initialValue) {
+ Type type;
+ if (parser.parseType(type))
+ return failure();
+
+ auto memrefType = type.dyn_cast<MemRefType>();
+ if (!memrefType || !memrefType.hasStaticShape())
+ return parser.emitError(parser.getNameLoc())
+ << "type should be static shaped memref, but got " << type;
+ typeAttr = TypeAttr::get(type);
+
+ if (parser.parseOptionalEqual())
+ return success();
+
+ if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
+ initialValue = UnitAttr::get(parser.getBuilder().getContext());
+ return success();
+ }
+
+ Type tensorType = getTensorTypeFromMemRefType(memrefType);
+ if (parser.parseAttribute(initialValue, tensorType))
+ return failure();
+ if (!initialValue.isa<ElementsAttr>())
+ return parser.emitError(parser.getNameLoc())
+ << "initial value should be a unit or elements attribute";
+ return success();
+}
+
+static LogicalResult verify(GlobalMemrefOp op) {
+ auto memrefType = op.type().dyn_cast<MemRefType>();
+ if (!memrefType || !memrefType.hasStaticShape())
+ return op.emitOpError("type should be static shaped memref, but got ")
+ << op.type();
+
+ // Verify that the initial value, if present, is either a unit attribute or
+ // an elements attribute.
+ if (op.initial_value().hasValue()) {
+ Attribute initValue = op.initial_value().getValue();
+ if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
+ return op.emitOpError("initial value should be a unit or elements "
+ "attribute, but got ")
+ << initValue;
+
+ // Check that the type of the initial value is compatible with the type of
+ // the global variable.
+ if (initValue.isa<ElementsAttr>()) {
+ Type initType = initValue.getType();
+ Type tensorType = getTensorTypeFromMemRefType(memrefType);
+ if (initType != tensorType)
+ return op.emitOpError("initial value expected to be of type ")
+ << tensorType << ", but was of type " << initType;
+ }
+ }
+
+ // TODO: verify visibility for declarations.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ // Verify that the result type is same as the type of the referenced
+ // global_memref op.
+ auto global =
+ symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
+ if (!global)
+ return emitOpError("'")
+ << name() << "' does not reference a valid global memref";
+
+ Type resultType = result().getType();
+ if (global.type() != resultType)
+ return emitOpError("result type ")
+ << resultType << " does not match type " << global.type()
+ << " of the global memref @" << name();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// IndexCastOp
//===----------------------------------------------------------------------===//
@@ -3891,18 +4003,6 @@ void TensorCastOp::getCanonicalizationPatterns(
results.insert<ChainedTensorCast>(context);
}
-//===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
-//===----------------------------------------------------------------------===//
-
-static Type getTensorTypeFromMemRefType(Type type) {
- if (auto memref = type.dyn_cast<MemRefType>())
- return RankedTensorType::get(memref.getShape(), memref.getElementType());
- if (auto memref = type.dyn_cast<UnrankedMemRefType>())
- return UnrankedTensorType::get(memref.getElementType());
- return NoneType::get(type.getContext());
-}
-
//===----------------------------------------------------------------------===//
// TensorLoadOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index ef060d481f2a..6234d3b59660 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -226,6 +226,10 @@ OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
Type type) {
return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
}
+OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
+ Type type) {
+ return parseOptionalAttributeWithToken(Token::string, attribute, type);
+}
/// Attribute dictionary.
///
@@ -807,6 +811,7 @@ ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
/// Parse a dense elements attribute.
Attribute Parser::parseDenseElementsAttr(Type attrType) {
+ auto attribLoc = getToken().getLoc();
consumeToken(Token::kw_dense);
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
@@ -819,11 +824,14 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return nullptr;
}
- auto typeLoc = getToken().getLoc();
+ // If the type is specified `parseElementsLiteralType` will not parse a type.
+ // Use the attribute location as the location for error reporting in that
+ // case.
+ auto loc = attrType ? attribLoc : getToken().getLoc();
auto type = parseElementsLiteralType(attrType);
if (!type)
return nullptr;
- return literalParser.getAttr(typeLoc, type);
+ return literalParser.getAttr(loc, type);
}
/// Parse an opaque elements attribute.
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 48651a98561c..6d04f97f4576 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1065,6 +1065,11 @@ class CustomOpAsmParser : public OpAsmParser {
NamedAttrList &attrs) override {
return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
}
+ OptionalParseResult parseOptionalAttribute(StringAttr &result, Type type,
+ StringRef attrName,
+ NamedAttrList &attrs) override {
+ return parseOptionalAttributeAndAddToList(result, type, attrName, attrs);
+ }
/// Parse a named dictionary into 'result' if it is present.
ParseResult parseOptionalAttrDict(NamedAttrList &result) override {
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
index c01a0a004072..9ffaea051f66 100644
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -188,6 +188,7 @@ class Parser {
OptionalParseResult parseOptionalAttribute(Attribute &attribute,
Type type = {});
OptionalParseResult parseOptionalAttribute(ArrayAttr &attribute, Type type);
+ OptionalParseResult parseOptionalAttribute(StringAttr &attribute, Type type);
/// Parse an optional attribute that is demarcated by a specific token.
template <typename AttributeT>
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
index ea42028781ab..9ee69124c6ef 100644
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -231,3 +231,84 @@ func @memref_reshape_result_affine_map_is_not_identity(
memref_reshape %buf(%shape)
: (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
}
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32 = 5
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<*xf32>
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<?x?xf32>
+
+// -----
+
+// expected-error @+1 {{initial value should be a unit or elements attribute}}
+global_memref @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// expected-error @+1 {{inferred shape of elements literal ([2]) does not match type ([2, 2])}}
+global_memref @foo : memref<2x2xf32> = dense<[0.0, 1.0]>
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref "private" "public" @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref constant external @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// constant qualifier must be after visibility.
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref constant "private" @foo : memref<2x2xf32> = "foo"
+
+
+// -----
+
+// expected-error @+1 {{op visibility expected to be one of ["public", "private", "nested"], but got "priate"}}
+global_memref "priate" constant @memref5 : memref<2xf32> = uninitialized
+
+// -----
+
+func @nonexistent_global_memref() {
+ // expected-error @+1 {{'gv' does not reference a valid global memref}}
+ %0 = get_global_memref @gv : memref<3xf32>
+ return
+}
+
+// -----
+
+func @foo()
+
+func @nonexistent_global_memref() {
+ // expected-error @+1 {{'foo' does not reference a valid global memref}}
+ %0 = get_global_memref @foo : memref<3xf32>
+ return
+}
+
+// -----
+
+global_memref @gv : memref<3xi32>
+
+func @mismatched_types() {
+ // expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}}
+ %0 = get_global_memref @gv : memref<3xf32>
+ return
+}
+
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
index fcf2325fd2e8..cd173670ae54 100644
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -77,3 +77,34 @@ func @memref_reshape(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return %new_unranked : memref<*xf32>
}
+
+// CHECK-LABEL: global_memref @memref0 : memref<2xf32>
+global_memref @memref0 : memref<2xf32>
+
+// CHECK-LABEL: global_memref constant @memref1 : memref<2xf32> = dense<[0.000000e+00, 1.000000e+00]>
+global_memref constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>
+
+// CHECK-LABEL: global_memref @memref2 : memref<2xf32> = uninitialized
+global_memref @memref2 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: global_memref "private" @memref3 : memref<2xf32> = uninitialized
+global_memref "private" @memref3 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
+global_memref "private" constant @memref4 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: func @write_global_memref
+func @write_global_memref() {
+ %0 = get_global_memref @memref0 : memref<2xf32>
+ %1 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+ tensor_store %1, %0 : memref<2xf32>
+ return
+}
+
+// CHECK-LABEL: func @read_global_memref
+func @read_global_memref() {
+ %0 = get_global_memref @memref0 : memref<2xf32>
+ %1 = tensor_load %0 : memref<2xf32>
+ return
+}
+
More information about the Mlir-commits
mailing list