[Mlir-commits] [mlir] 995ab92 - [mlir] Add a new builtin DenseResourceElementsAttr
River Riddle
llvmlistbot at llvm.org
Mon Aug 1 12:50:14 PDT 2022
Author: River Riddle
Date: 2022-08-01T12:37:16-07:00
New Revision: 995ab92964d667123efd90d1f8016602c4a9df01
URL: https://github.com/llvm/llvm-project/commit/995ab92964d667123efd90d1f8016602c4a9df01
DIFF: https://github.com/llvm/llvm-project/commit/995ab92964d667123efd90d1f8016602c4a9df01.diff
LOG: [mlir] Add a new builtin DenseResourceElementsAttr
This attributes is intended cover the current set of use cases that abuse
DenseElementsAttr, e.g. when the data is large. Using resources for large
data is one of the major reasons why they were added; e.g. they can be
deallocated mid-compilation, they support a wide variety of data origins
(e.g, heap allocated, mmap'd, etc.), they can support mutation, etc.
I considered at length not having a builtin variant of this, and instead
having multiple versions of this attribute for dialects that are interested,
but they all boiled down to the exact same attribute definition. Given the
generality of this attribute, it feels more aligned to keep it next to DenseArrayAttr
(given that DenseArrayAttr covers the "small" case, and DenseResourcesElementsAttr
covers the "large" case). The underlying infra used to build this attribute is
general, and having a builtin attribute doesn't preclude users from defining
their own when it makes sense (they can even share a blob manager with the
builtin dialect to avoid data duplication).
Differential Revision: https://reviews.llvm.org/D130022
Added:
mlir/test/IR/dense-resource-elements-attr.mlir
Modified:
mlir/include/mlir/IR/BuiltinAttributes.h
mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/OpImplementation.h
mlir/lib/AsmParser/AsmParserImpl.h
mlir/lib/AsmParser/AttributeParser.cpp
mlir/lib/AsmParser/Parser.cpp
mlir/lib/AsmParser/Parser.h
mlir/lib/AsmParser/TokenKinds.def
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/BuiltinDialect.cpp
mlir/lib/IR/DialectResourceBlobManager.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/test/IR/invalid-builtin-attributes.mlir
mlir/test/IR/invalid-file-metadata.mlir
mlir/unittests/IR/AttributeTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
index 7adec3305a48c..eb8f0ca8334ec 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -17,8 +17,12 @@
namespace mlir {
class AffineMap;
+class AsmResourceBlob;
class BoolAttr;
+class BuiltinDialect;
class DenseIntElementsAttr;
+template <typename T>
+struct DialectResourceBlobHandle;
class FlatSymbolRefAttr;
class FunctionType;
class IntegerSet;
@@ -729,6 +733,13 @@ class SplatElementsAttr : public DenseElementsAttr {
return denseAttr && denseAttr.isSplat();
}
};
+
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+using DenseResourceElementsHandle = DialectResourceBlobHandle<BuiltinDialect>;
+
} // namespace mlir
//===----------------------------------------------------------------------===//
@@ -743,6 +754,9 @@ class SplatElementsAttr : public DenseElementsAttr {
//===----------------------------------------------------------------------===//
namespace mlir {
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+
namespace detail {
/// Base class for DenseArrayAttr that is instantiated and specialized for each
/// supported element type below.
@@ -795,6 +809,71 @@ using DenseI64ArrayAttr = detail::DenseArrayAttr<int64_t>;
using DenseF32ArrayAttr = detail::DenseArrayAttr<float>;
using DenseF64ArrayAttr = detail::DenseArrayAttr<double>;
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+
+namespace detail {
+/// Base class for DenseResourceElementsAttr that is instantiated and
+/// specialized for each supported element type below.
+template <typename T>
+class DenseResourceElementsAttrBase : public DenseResourceElementsAttr {
+public:
+ using DenseResourceElementsAttr::DenseResourceElementsAttr;
+
+ /// A builder that inserts a new resource using the provided blob. The handle
+ /// of the inserted blob is used when building the attribute. The provided
+ /// `blobName` is used as a hint for the key of the new handle for the `blob`
+ /// resource, but may be changed if necessary to ensure uniqueness during
+ /// insertion.
+ static DenseResourceElementsAttrBase<T>
+ get(ShapedType type, StringRef blobName, AsmResourceBlob blob);
+
+ /// Return the data of this attribute as an ArrayRef<T> if it is present,
+ /// returns None otherwise.
+ Optional<ArrayRef<T>> tryGetAsArrayRef() const;
+
+ /// Support for isa<>/cast<>.
+ static bool classof(Attribute attr);
+};
+
+extern template class DenseResourceElementsAttrBase<bool>;
+extern template class DenseResourceElementsAttrBase<int8_t>;
+extern template class DenseResourceElementsAttrBase<int16_t>;
+extern template class DenseResourceElementsAttrBase<int32_t>;
+extern template class DenseResourceElementsAttrBase<int64_t>;
+extern template class DenseResourceElementsAttrBase<uint8_t>;
+extern template class DenseResourceElementsAttrBase<uint16_t>;
+extern template class DenseResourceElementsAttrBase<uint32_t>;
+extern template class DenseResourceElementsAttrBase<uint64_t>;
+extern template class DenseResourceElementsAttrBase<float>;
+extern template class DenseResourceElementsAttrBase<double>;
+} // namespace detail
+
+// Public names for all the supported DenseResourceElementsAttr.
+
+using DenseBoolResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<bool>;
+using DenseI8ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<int8_t>;
+using DenseI16ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<int16_t>;
+using DenseI32ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<int32_t>;
+using DenseI64ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<int64_t>;
+using DenseUI8ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<uint8_t>;
+using DenseUI16ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<uint16_t>;
+using DenseUI32ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<uint32_t>;
+using DenseUI64ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<uint64_t>;
+using DenseF32ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<float>;
+using DenseF64ResourceElementsAttr =
+ detail::DenseResourceElementsAttrBase<double>;
+
//===----------------------------------------------------------------------===//
// BoolAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index 0b620908c2069..7a771d85d435d 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/SubElementInterfaces.td"
// TODO: Currently the attributes defined in this file are prefixed with
@@ -424,6 +425,65 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
let skipDefaultBuilders = 1;
}
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [
+ ElementsAttrInterface, TypedAttrInterface
+ ]> {
+ let summary = "An Attribute containing a dense multi-dimensional array "
+ "backed by a resource";
+ let description = [{
+ Syntax:
+
+ ```
+ dense-resource-elements-attribute ::=
+ `dense_resource` `<` resource-handle `>` `:` shaped-type
+ ```
+
+ A dense resource elements attribute is an elements attribute backed by a
+ handle to a builtin dialect resource containing a densely packed array of
+ values. This class provides the low-level attribute, which should only be
+ interacted with in very generic terms, actual access to the underlying
+ resource data is intended to be managed through one of the subclasses, such
+ as; `DenseBoolResourceElementsAttr`, `DenseUI64ResourceElementsAttr`,
+ `DenseI32ResourceElementsAttr`, `DenseF32ResourceElementsAttr`,
+ `DenseF64ResourceElementsAttr`, etc.
+
+ Examples:
+
+ ```mlir
+ // A tensor referencing a builtin dialect resource, `resource_1`, with two
+ // unsigned i32 elements.
+ dense_resource<resource_1> : tensor<2xui32>
+ ```
+ }];
+ let parameters = (ins
+ AttributeSelfTypeParameter<"", "ShapedType">:$type,
+ ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle
+ );
+ let builders = [
+ AttrBuilderWithInferredContext<(ins
+ "ShapedType":$type, "DenseResourceElementsHandle":$handle
+ )>
+ ];
+ let extraClassDeclaration = [{
+ protected:
+ /// A builder that inserts a new resource into the builtin dialect's blob
+ /// manager using the provided blob. The handle of the inserted blob is used
+ /// when building the attribute. The provided `blobName` is used as a hint
+ /// for the key of the new handle for the `blob` resource, but may be
+ /// changed if necessary to ensure uniqueness during insertion.
+ static DenseResourceElementsAttr get(
+ ShapedType type, StringRef blobName, AsmResourceBlob blob
+ );
+
+ public:
+ }];
+ let skipDefaultBuilders = 1;
+}
+
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index a69e3b4a4e074..5de78883a50b5 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -1023,8 +1023,17 @@ class AsmParser {
template <typename ResourceT>
FailureOr<ResourceT> parseResourceHandle() {
SMLoc handleLoc = getCurrentLocation();
- FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(
- getContext()->getOrLoadDialect<typename ResourceT::Dialect>());
+
+ // Try to load the dialect that owns the handle.
+ auto *dialect =
+ getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
+ if (!dialect) {
+ return emitError(handleLoc)
+ << "dialect '" << ResourceT::Dialect::getDialectNamespace()
+ << "' is unknown";
+ }
+
+ FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
if (failed(handle))
return failure();
if (auto *result = dyn_cast<ResourceT>(&*handle))
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
index c06eb689964b3..5bc6c79faf94e 100644
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -460,7 +460,7 @@ class AsmParserImpl : public BaseT {
/// Parse a handle to a resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) override {
- const auto *interface = dyn_cast_or_null<OpAsmDialectInterface>(dialect);
+ const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
if (!interface) {
return parser.emitError() << "dialect '" << dialect->getNamespace()
<< "' does not expect resource handles";
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
index dff8510c94fa1..faa60b6ffd3cf 100644
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -15,9 +15,10 @@
#include "AsmParserImpl.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Endian.h"
@@ -97,6 +98,10 @@ Attribute Parser::parseAttribute(Type type) {
case Token::kw_dense:
return parseDenseElementsAttr(type);
+ // Parse a dense resource elements attribute.
+ case Token::kw_dense_resource:
+ return parseDenseResourceElementsAttr(type);
+
// Parse a dictionary attribute.
case Token::l_brace: {
NamedAttrList elements;
@@ -241,6 +246,7 @@ OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
case Token::kw_affine_map:
case Token::kw_affine_set:
case Token::kw_dense:
+ case Token::kw_dense_resource:
case Token::kw_false:
case Token::kw_loc:
case Token::kw_opaque:
@@ -928,6 +934,39 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
return literalParser.getAttr(loc, type);
}
+Attribute Parser::parseDenseResourceElementsAttr(Type attrType) {
+ auto loc = getToken().getLoc();
+ consumeToken(Token::kw_dense_resource);
+ if (parseToken(Token::less, "expected '<' after 'dense_resource'"))
+ return nullptr;
+
+ // Parse the resource handle.
+ FailureOr<AsmDialectResourceHandle> rawHandle =
+ parseResourceHandle(getContext()->getLoadedDialect<BuiltinDialect>());
+ if (failed(rawHandle) || parseToken(Token::greater, "expected '>'"))
+ return nullptr;
+
+ auto *handle = dyn_cast<DenseResourceElementsHandle>(&*rawHandle);
+ if (!handle)
+ return emitError(loc, "invalid `dense_resource` handle type"), nullptr;
+
+ // Parse the type of the attribute if the user didn't provide one.
+ SMLoc typeLoc = loc;
+ if (!attrType) {
+ typeLoc = getToken().getLoc();
+ if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType()))
+ return nullptr;
+ }
+
+ ShapedType shapedType = attrType.dyn_cast<ShapedType>();
+ if (!shapedType) {
+ emitError(typeLoc, "`dense_resource` expected a shaped type");
+ return nullptr;
+ }
+
+ return DenseResourceElementsAttr::get(shapedType, *handle);
+}
+
/// Parse an opaque elements attribute.
Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
SMLoc loc = getToken().getLoc();
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 9934ca529f7aa..6cc96e7cfd0ac 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -340,6 +340,17 @@ Parser::parseResourceHandle(const OpAsmDialectInterface *dialect,
return entry.second;
}
+FailureOr<AsmDialectResourceHandle>
+Parser::parseResourceHandle(Dialect *dialect) {
+ const auto *interface = dyn_cast<OpAsmDialectInterface>(dialect);
+ if (!interface) {
+ return emitError() << "dialect '" << dialect->getNamespace()
+ << "' does not expect resource handles";
+ }
+ StringRef resourceName;
+ return parseResourceHandle(interface, resourceName);
+}
+
//===----------------------------------------------------------------------===//
// Code Completion
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index 615f940b12a8b..d48eeb943ea6c 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -160,6 +160,7 @@ class Parser {
/// Parse a handle to a dialect resource within the assembly format.
FailureOr<AsmDialectResourceHandle>
parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name);
+ FailureOr<AsmDialectResourceHandle> parseResourceHandle(Dialect *dialect);
//===--------------------------------------------------------------------===//
// Type Parsing
@@ -272,6 +273,9 @@ class Parser {
Attribute parseDenseElementsAttr(Type attrType);
ShapedType parseElementsLiteralType(Type type);
+ /// Parse a dense resource elements attribute.
+ Attribute parseDenseResourceElementsAttr(Type attrType);
+
/// Parse a DenseArrayAttr.
Attribute parseDenseArrayAttr();
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
index 207af3871f8d3..f56e048a7fa11 100644
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -87,6 +87,7 @@ TOK_KEYWORD(bf16)
TOK_KEYWORD(ceildiv)
TOK_KEYWORD(complex)
TOK_KEYWORD(dense)
+TOK_KEYWORD(dense_resource)
TOK_KEYWORD(f16)
TOK_KEYWORD(f32)
TOK_KEYWORD(f64)
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index e5fd5eaa2d0a7..433fe225223e6 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
@@ -1896,6 +1897,10 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
os << " ";
denseArrayAttr.printWithoutBraces(os);
os << "]";
+ } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
+ os << "dense_resource<";
+ printResourceHandle(resourceAttr.getRawHandle());
+ os << ">";
} else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
printLocation(locAttr);
} else {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index 021da17b3c334..ec1988136e38b 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
@@ -36,11 +37,10 @@ using namespace mlir::detail;
//===----------------------------------------------------------------------===//
void BuiltinDialect::registerAttributes() {
- addAttributes<AffineMapAttr, ArrayAttr, DenseArrayBaseAttr,
- DenseIntOrFPElementsAttr, DenseStringElementsAttr,
- DictionaryAttr, FloatAttr, SymbolRefAttr, IntegerAttr,
- IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
- SparseElementsAttr, StringAttr, TypeAttr, UnitAttr>();
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/IR/BuiltinAttributes.cpp.inc"
+ >();
}
//===----------------------------------------------------------------------===//
@@ -1576,6 +1576,130 @@ bool DenseIntElementsAttr::classof(Attribute attr) {
return false;
}
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+DenseResourceElementsAttr
+DenseResourceElementsAttr::get(ShapedType type,
+ DenseResourceElementsHandle handle) {
+ return Base::get(type.getContext(), type, handle);
+}
+
+DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
+ StringRef blobName,
+ AsmResourceBlob blob) {
+ // Extract the builtin dialect resource manager from context and construct a
+ // handle by inserting a new resource using the provided blob.
+ auto &manager =
+ DenseResourceElementsHandle::getManagerInterface(type.getContext());
+ return get(type, manager.insert(blobName, std::move(blob)));
+}
+
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttrBase
+
+namespace {
+/// Instantiations of this class provide utilities for interacting with native
+/// data types in the context of DenseResourceElementsAttr.
+template <typename T>
+struct DenseResourceAttrUtil;
+template <size_t width, bool isSigned>
+struct DenseResourceElementsAttrIntUtil {
+ static bool checkElementType(Type eltType) {
+ IntegerType type = eltType.dyn_cast<IntegerType>();
+ if (!type || type.getWidth() != width)
+ return false;
+ return isSigned ? !type.isUnsigned() : !type.isSigned();
+ }
+};
+template <>
+struct DenseResourceAttrUtil<bool> {
+ static bool checkElementType(Type eltType) {
+ return eltType.isSignlessInteger(1);
+ }
+};
+template <>
+struct DenseResourceAttrUtil<int8_t>
+ : public DenseResourceElementsAttrIntUtil<8, true> {};
+template <>
+struct DenseResourceAttrUtil<uint8_t>
+ : public DenseResourceElementsAttrIntUtil<8, false> {};
+template <>
+struct DenseResourceAttrUtil<int16_t>
+ : public DenseResourceElementsAttrIntUtil<16, true> {};
+template <>
+struct DenseResourceAttrUtil<uint16_t>
+ : public DenseResourceElementsAttrIntUtil<16, false> {};
+template <>
+struct DenseResourceAttrUtil<int32_t>
+ : public DenseResourceElementsAttrIntUtil<32, true> {};
+template <>
+struct DenseResourceAttrUtil<uint32_t>
+ : public DenseResourceElementsAttrIntUtil<32, false> {};
+template <>
+struct DenseResourceAttrUtil<int64_t>
+ : public DenseResourceElementsAttrIntUtil<64, true> {};
+template <>
+struct DenseResourceAttrUtil<uint64_t>
+ : public DenseResourceElementsAttrIntUtil<64, false> {};
+template <>
+struct DenseResourceAttrUtil<float> {
+ static bool checkElementType(Type eltType) { return eltType.isF32(); }
+};
+template <>
+struct DenseResourceAttrUtil<double> {
+ static bool checkElementType(Type eltType) { return eltType.isF64(); }
+};
+} // namespace
+
+template <typename T>
+DenseResourceElementsAttrBase<T>
+DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
+ AsmResourceBlob blob) {
+ // Check that the blob is in the form we were expecting.
+ assert(blob.getDataAlignment() == alignof(T) &&
+ "alignment mismatch between expected alignment and blob alignment");
+ assert(((blob.getData().size() % sizeof(T)) == 0) &&
+ "size mismatch between expected element width and blob size");
+ assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&
+ "invalid shape element type for provided type `T`");
+ return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
+ .template cast<DenseResourceElementsAttrBase<T>>();
+}
+
+template <typename T>
+Optional<ArrayRef<T>>
+DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
+ if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
+ return blob->template getDataAs<T>();
+ return llvm::None;
+}
+
+template <typename T>
+bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
+ auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
+ return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
+ resourceAttr.getElementType());
+}
+
+namespace mlir {
+namespace detail {
+// Explicit instantiation for all the supported DenseResourceElementsAttr.
+template class DenseResourceElementsAttrBase<bool>;
+template class DenseResourceElementsAttrBase<int8_t>;
+template class DenseResourceElementsAttrBase<int16_t>;
+template class DenseResourceElementsAttrBase<int32_t>;
+template class DenseResourceElementsAttrBase<int64_t>;
+template class DenseResourceElementsAttrBase<uint8_t>;
+template class DenseResourceElementsAttrBase<uint16_t>;
+template class DenseResourceElementsAttrBase<uint32_t>;
+template class DenseResourceElementsAttrBase<uint64_t>;
+template class DenseResourceElementsAttrBase<float>;
+template class DenseResourceElementsAttrBase<double>;
+} // namespace detail
+} // namespace mlir
+
//===----------------------------------------------------------------------===//
// OpaqueElementsAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp
index 662bcd811db4e..7df22a9038f71 100644
--- a/mlir/lib/IR/BuiltinDialect.cpp
+++ b/mlir/lib/IR/BuiltinDialect.cpp
@@ -16,6 +16,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
@@ -23,14 +24,27 @@
using namespace mlir;
//===----------------------------------------------------------------------===//
-// Builtin Dialect
+// TableGen'erated dialect
//===----------------------------------------------------------------------===//
#include "mlir/IR/BuiltinDialect.cpp.inc"
+//===----------------------------------------------------------------------===//
+// BuiltinBlobManagerInterface
+//===----------------------------------------------------------------------===//
+
+using BuiltinBlobManagerInterface =
+ ResourceBlobManagerDialectInterfaceBase<DenseResourceElementsHandle>;
+
+//===----------------------------------------------------------------------===//
+// BuiltinOpAsmDialectInterface
+//===----------------------------------------------------------------------===//
+
namespace {
struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
- using OpAsmDialectInterface::OpAsmDialectInterface;
+ BuiltinOpAsmDialectInterface(Dialect *dialect,
+ BuiltinBlobManagerInterface &mgr)
+ : OpAsmDialectInterface(dialect), blobManager(mgr) {}
AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
if (attr.isa<AffineMapAttr>()) {
@@ -57,6 +71,38 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface {
}
return AliasResult::NoAlias;
}
+
+ //===------------------------------------------------------------------===//
+ // Resources
+ //===------------------------------------------------------------------===//
+
+ std::string
+ getResourceKey(const AsmDialectResourceHandle &handle) const override {
+ return cast<DenseResourceElementsHandle>(handle).getKey().str();
+ }
+ FailureOr<AsmDialectResourceHandle>
+ declareResource(StringRef key) const final {
+ return blobManager.insert(key);
+ }
+ LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
+ FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
+ if (failed(blob))
+ return failure();
+
+ // Update the blob for this entry.
+ blobManager.update(entry.getKey(), std::move(*blob));
+ return success();
+ }
+ void
+ buildResources(Operation *op,
+ const SetVector<AsmDialectResourceHandle> &referencedResources,
+ AsmResourceBuilder &provider) const final {
+ blobManager.buildResources(provider, referencedResources.getArrayRef());
+ }
+
+private:
+ /// The blob manager for the dialect.
+ BuiltinBlobManagerInterface &blobManager;
};
} // namespace
@@ -68,7 +114,9 @@ void BuiltinDialect::initialize() {
#define GET_OP_LIST
#include "mlir/IR/BuiltinOps.cpp.inc"
>();
- addInterfaces<BuiltinOpAsmDialectInterface>();
+
+ auto &blobInterface = addInterface<BuiltinBlobManagerInterface>();
+ addInterface<BuiltinOpAsmDialectInterface>(blobInterface);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/DialectResourceBlobManager.cpp b/mlir/lib/IR/DialectResourceBlobManager.cpp
index dbfe9c1ef85e9..60a2fb2e3c591 100644
--- a/mlir/lib/IR/DialectResourceBlobManager.cpp
+++ b/mlir/lib/IR/DialectResourceBlobManager.cpp
@@ -57,7 +57,7 @@ auto DialectResourceBlobManager::insert(StringRef name,
Twine(nameCounter++).toVector(nameStorage);
// Try inserting with the new name.
- if (BlobEntry *entry = tryInsertion(name))
+ if (BlobEntry *entry = tryInsertion(nameStorage))
return *entry;
nameStorage.resize(name.size() + 1);
} while (true);
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index e338876ac95f9..d0de010e96f2d 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -712,8 +712,9 @@ class LSPCodeCompleteContext : public AsmParserCodeCompleteContext {
/// Signal a completion for an attribute.
void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
- appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
- "loc", "opaque", "sparse", "true", "unit"},
+ appendSimpleCompletions({"affine_set", "affine_map", "dense",
+ "dense_resource", "false", "loc", "opaque",
+ "sparse", "true", "unit"},
lsp::CompletionItemKind::Field,
/*sortText=*/"1");
diff --git a/mlir/test/IR/dense-resource-elements-attr.mlir b/mlir/test/IR/dense-resource-elements-attr.mlir
new file mode 100644
index 0000000000000..adba97994ff60
--- /dev/null
+++ b/mlir/test/IR/dense-resource-elements-attr.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file | FileCheck %s
+
+// CHECK: attr = dense_resource<blob1> : tensor<3xi64>
+"test.user_op"() {attr = dense_resource<blob1> : tensor<3xi64> } : () -> ()
+
+{-#
+ dialect_resources: {
+ builtin: {
+ // CHECK: blob1: "0x08000000010000000000000002000000000000000300000000000000"
+ blob1: "0x08000000010000000000000002000000000000000300000000000000"
+ }
+ }
+#-}
diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir
index f6df53f0248e1..d996c226d2e9c 100644
--- a/mlir/test/IR/invalid-builtin-attributes.mlir
+++ b/mlir/test/IR/invalid-builtin-attributes.mlir
@@ -519,3 +519,23 @@ func.func @duplicate_dictionary_attr_key() {
"J// -----
" // expected-error {{expected}}
+
+// -----
+
+// expected-error at +1 {{expected '<' after 'dense_resource'}}
+#attr = dense_resource>
+
+// -----
+
+// expected-error at +1 {{expected '>'}}
+#attr = dense_resource<resource
+
+// -----
+
+// expected-error at +1 {{expected ':'}}
+#attr = dense_resource<resource>
+
+// -----
+
+// expected-error at +1 {{`dense_resource` expected a shaped type}}
+#attr = dense_resource<resource> : i32
diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir
index 42f7b8ec68447..352cf19f11bef 100644
--- a/mlir/test/IR/invalid-file-metadata.mlir
+++ b/mlir/test/IR/invalid-file-metadata.mlir
@@ -59,10 +59,10 @@
// -----
-// expected-error at +4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}}
+// expected-error at +4 {{unknown 'resource' key 'unknown_entry' for dialect 'ml_program'}}
{-#
dialect_resources: {
- builtin: {
+ ml_program: {
unknown_entry: "foo"
}
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 82a1bcd5d1735..5611eacd96022 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -6,6 +6,8 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "gtest/gtest.h"
@@ -13,6 +15,10 @@
using namespace mlir;
using namespace mlir::detail;
+//===----------------------------------------------------------------------===//
+// DenseElementsAttr
+//===----------------------------------------------------------------------===//
+
template <typename EltTy>
static void testSplat(Type eltType, const EltTy &splatElt) {
RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
@@ -203,7 +209,119 @@ TEST(DenseScalarTest, ExtractZeroRankElement) {
auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue}));
EXPECT_TRUE(attr.getValues<Attribute>()[0] == value);
}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+//===----------------------------------------------------------------------===//
+
+template <typename AttrT, typename T>
+static void checkNativeAccess(MLIRContext *ctx, ArrayRef<T> data,
+ Type elementType) {
+ auto type = RankedTensorType::get(data.size(), elementType);
+ auto attr =
+ AttrT::get(type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+
+ // Check that we can access and iterate the data properly.
+ Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
+ EXPECT_TRUE(attrData.hasValue());
+ EXPECT_EQ(*attrData, data);
+
+ // Check that we cast to this attribute when possible.
+ Attribute genericAttr = attr;
+ EXPECT_TRUE(genericAttr.template isa<AttrT>());
+}
+template <typename AttrT, typename T>
+static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
+ T data[] = {0, 1, 2};
+ checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
+ builder.getIntegerType(intWidth));
+}
+
+namespace {
+TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ // Bool
+ bool boolData[] = {true, false, true};
+ checkNativeAccess<DenseBoolResourceElementsAttr>(
+ &context, llvm::makeArrayRef(boolData), builder.getI1Type());
+
+ // Unsigned integers
+ checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
+ checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
+ checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
+ checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
+
+ // Signed integers
+ checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
+ checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
+ checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
+ checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
+
+ // Float
+ float floatData[] = {0, 1, 2};
+ checkNativeAccess<DenseF32ResourceElementsAttr>(
+ &context, llvm::makeArrayRef(floatData), builder.getF32Type());
+
+ // Double
+ double doubleData[] = {0, 1, 2};
+ checkNativeAccess<DenseF64ResourceElementsAttr>(
+ &context, llvm::makeArrayRef(doubleData), builder.getF64Type());
+}
+
+TEST(DenseResourceElementsAttrTest, CheckNoCast) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ // Create a i32 attribute.
+ ArrayRef<uint32_t> data;
+ auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+ Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
+ type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+
+ EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
+ EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
+ EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
+}
+TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ // Create a bool attribute with data of the incorrect type.
+ ArrayRef<uint32_t> data;
+ auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+ ASSERT_DEATH(
+ {
+ DenseBoolResourceElementsAttr::get(
+ type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+ },
+ "alignment mismatch between expected alignment and blob alignment");
+}
+
+TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
+ MLIRContext context;
+ Builder builder(&context);
+
+ // Create a bool attribute with incorrect type.
+ ArrayRef<bool> data;
+ auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+ ASSERT_DEATH(
+ {
+ DenseBoolResourceElementsAttr::get(
+ type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+ },
+ "invalid shape element type for provided type `T`");
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+namespace {
TEST(SparseElementsAttrTest, GetZero) {
MLIRContext context;
context.allowUnregisteredDialects();
More information about the Mlir-commits
mailing list