[Mlir-commits] [mlir] [mlir] `ResourceAttrInterface` to abstract AsmResourceBlob from resource handle. (PR #101780)
Pavel Prokofyev
llvmlistbot at llvm.org
Fri Aug 2 17:31:16 PDT 2024
https://github.com/integralpro created https://github.com/llvm/llvm-project/pull/101780
None
>From de0f6424be51449d2e6dd1d9d99874ec8257acec Mon Sep 17 00:00:00 2001
From: Pavel Prokofyev <pprokofyev at apple.com>
Date: Fri, 2 Aug 2024 17:19:29 -0700
Subject: [PATCH] [mlir] `ResourceAttrInterface` to abstract AsmResourceBlob
from resource handle.
---
.../mlir/IR/BuiltinAttributeInterfaces.h | 2 ++
.../mlir/IR/BuiltinAttributeInterfaces.td | 31 +++++++++++++++++++
mlir/include/mlir/IR/BuiltinAttributes.td | 2 +-
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 28 ++++++++---------
mlir/unittests/IR/AttributeTest.cpp | 17 ++++++++++
5 files changed, 64 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index c4a42020d1389..bb96de2ef1f92 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -19,6 +19,8 @@
namespace mlir {
+class AsmResourceBlob; // Forward-declared for ResourceAttr::getBlob().
+
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 954429c7d8eae..768964ac578ac 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,35 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
];
}
+//===----------------------------------------------------------------------===//
+// ResourceAttrInterface
+//===----------------------------------------------------------------------===//
+
+def ResourceAttrInterface : AttrInterface<"ResourceAttr", [TypedAttrInterface]> {
+ let cppNamespace = "::mlir";
+
+ let description = [{
+ Provides access to the resource blob backing an attribute, and to its key, independently of the attribute's concrete resource handle type.
+ }];
+
+ let methods = [
+ InterfaceMethod<
+ "Returns the key of the resource blob. By default, forwards to `getRawHandle().getKey()` on the attribute.",
+ "::mlir::StringRef", "getBlobKey", (ins),
+ [{}],
+ [{
+ return $_attr.getRawHandle().getKey();
+ }]
+ >,
+ InterfaceMethod<
+ "Returns the resource blob, or null if the resource does not exist. By default, forwards to `getRawHandle().getBlob()` on the attribute.",
+ "::mlir::AsmResourceBlob *", "getBlob", (ins),
+ [{}],
+ [{
+ return $_attr.getRawHandle().getBlob();
+ }]
+ >
+ ];
+}
+
#endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97b..aa0b29ffb6fd4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -431,7 +431,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
//===----------------------------------------------------------------------===//
def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements",
- "dense_resource_elements", [ElementsAttrInterface]> {
+ "dense_resource_elements", [ElementsAttrInterface, ResourceAttrInterface]> {
let summary = "An Attribute containing a dense multi-dimensional array "
"backed by a resource";
let description = [{
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b468228ea78b7..9549d4628fac0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -458,10 +458,11 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
 /// of the innermost dimension. Constants for other dimensions are still
 /// constructed recursively. Returns nullptr on failure and emits errors at
 /// `loc`.
-static llvm::Constant *convertDenseResourceElementsAttr(
-    Location loc, DenseResourceElementsAttr denseResourceAttr,
-    llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
-  assert(denseResourceAttr && "expected non-null attribute");
+static llvm::Constant *
+convertResourceAttr(Location loc, ResourceAttr resourceAttr,
+                    llvm::Type *llvmType,
+                    const ModuleTranslation &moduleTranslation) {
+  assert(resourceAttr && "expected non-null attribute");
 
   llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
   if (!llvm::ConstantDataSequential::isElementTypeCompatible(
@@ -470,10 +471,16 @@ static llvm::Constant *convertDenseResourceElementsAttr(
     return nullptr;
   }
 
-  ShapedType type = denseResourceAttr.getType();
+  // ResourceAttr only guarantees a typed attribute, but the conversion below
+  // needs a shape: reject non-shaped resources instead of asserting in `cast`.
+  auto type = dyn_cast<ShapedType>(resourceAttr.getType());
+  if (!type) {
+    emitError(loc, "expected resource attribute with a shaped type");
+    return nullptr;
+  }
   assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
 
-  AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
+  AsmResourceBlob *blob = resourceAttr.getBlob();
   if (!blob) {
     emitError(loc, "resource does not exist");
     return nullptr;
@@ -486,7 +493,7 @@ static llvm::Constant *convertDenseResourceElementsAttr(
   // raw data.
   // TODO: we may also need to consider endianness when cross-compiling to an
   // architecture where it is different.
-  int64_t numElements = denseResourceAttr.getType().getNumElements();
+  int64_t numElements = type.getNumElements();
   int64_t elementByteSize = rawData.size() / numElements;
   if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
     emitError(loc, "raw data size does not match element type size");
@@ -497,9 +504,7 @@ static llvm::Constant *convertDenseResourceElementsAttr(
   // innermost dimension may be that of the vector element type.
   bool hasVectorElementType = isa<VectorType>(type.getElementType());
   int64_t numAggregates =
-      numElements / (hasVectorElementType
-                         ? 1
-                         : denseResourceAttr.getType().getShape().back());
+      numElements / (hasVectorElementType ? 1 : type.getShape().back());
   ArrayRef<int64_t> outerShape = type.getShape();
   if (!hasVectorElementType)
     outerShape = outerShape.drop_back();
@@ -533,8 +538,8 @@ static llvm::Constant *convertDenseResourceElementsAttr(
   // Create innermost constants and defer to the default constant creation
   // mechanism for other dimensions.
   SmallVector<llvm::Constant *> constants;
-  int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
-                          (innermostLLVMType->getScalarSizeInBits() / 8);
+  int64_t aggregateSize =
+      type.getShape().back() * (innermostLLVMType->getScalarSizeInBits() / 8);
   constants.reserve(numAggregates);
   for (unsigned i = 0; i < numAggregates; ++i) {
     StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
@@ -679,9 +684,8 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     return result;
   }
 
-  if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
-    return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
-                                            moduleTranslation);
+  if (auto resourceAttr = dyn_cast<ResourceAttr>(attr)) {
+    return convertResourceAttr(loc, resourceAttr, llvmType, moduleTranslation);
   }
 
   // Fall back to element-by-element construction otherwise.
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e72bfe9d82e7c..207373722dcbb 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -382,6 +382,27 @@ TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
       },
       "invalid shape element type for provided type `T`");
 }
+
+TEST(DenseResourceElementsAttrTest, CheckResourceInterface) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // The blob below is unmanaged, so its storage must outlive the attribute.
+  // Binding a braced list directly to an ArrayRef would leave it dangling.
+  static const double storage[] = {0, 1, 2};
+  ArrayRef<double> data(storage);
+  auto elementType = builder.getF64Type();
+  auto type = RankedTensorType::get(data.size(), elementType);
+  auto attr = DenseF64ResourceElementsAttr::get(
+      type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
+
+  auto resourceAttr = dyn_cast<ResourceAttr>(attr);
+  ASSERT_TRUE(resourceAttr);
+  EXPECT_EQ(resourceAttr.getBlobKey(), "resource");
+  AsmResourceBlob *blob = resourceAttr.getBlob();
+  ASSERT_TRUE(blob);
+  EXPECT_EQ(blob->getData().size(), data.size() * sizeof(double));
+}
 } // namespace
 
 //===----------------------------------------------------------------------===//
More information about the Mlir-commits
mailing list