[Mlir-commits] [mlir] [mlir] `ResourceAttrInterface` to abstract AsmResourceBlob from resource handle. (PR #101780)

Pavel Prokofyev llvmlistbot at llvm.org
Fri Aug 2 17:31:16 PDT 2024


https://github.com/integralpro created https://github.com/llvm/llvm-project/pull/101780

None

>From de0f6424be51449d2e6dd1d9d99874ec8257acec Mon Sep 17 00:00:00 2001
From: Pavel Prokofyev <pprokofyev at apple.com>
Date: Fri, 2 Aug 2024 17:19:29 -0700
Subject: [PATCH] [mlir] `ResourceAttrInterface` to abstract AsmResourceBlob
 from resource handle.

---
 .../mlir/IR/BuiltinAttributeInterfaces.h      |  2 ++
 .../mlir/IR/BuiltinAttributeInterfaces.td     | 31 +++++++++++++++++++
 mlir/include/mlir/IR/BuiltinAttributes.td     |  2 +-
 mlir/lib/Target/LLVMIR/ModuleTranslation.cpp  | 28 ++++++++---------
 mlir/unittests/IR/AttributeTest.cpp           | 17 ++++++++++
 5 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
index c4a42020d1389..bb96de2ef1f92 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h
@@ -19,6 +19,8 @@
 
 namespace mlir {
 
+class AsmResourceBlob;
+
 //===----------------------------------------------------------------------===//
 // ElementsAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
index 954429c7d8eae..768964ac578ac 100644
--- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td
@@ -495,4 +495,35 @@ def MemRefLayoutAttrInterface : AttrInterface<"MemRefLayoutAttrInterface"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ResourceAttrInterface
+//===----------------------------------------------------------------------===//
+
+def ResourceAttrInterface : AttrInterface<"ResourceAttr", [TypedAttrInterface]> {
+  let cppNamespace = "::mlir";
+
+  let description = [{
+    Exposes a resource's blob and blob key independently of its handle type.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      "Returns the key of the blob associated with the resource.",
+      "::mlir::StringRef", "getBlobKey", (ins),
+      [{}],
+      [{
+        return $_attr.getRawHandle().getKey();
+      }]
+    >,
+    InterfaceMethod<
+      "Returns the blob backing the resource, or null if it does not exist.",
+      "::mlir::AsmResourceBlob *", "getBlob", (ins),
+      [{}],
+      [{
+        return $_attr.getRawHandle().getBlob();
+      }]
+    >
+  ];
+}
+
 #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
index d9295936ee97b..aa0b29ffb6fd4 100644
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -431,7 +431,7 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements",
-    "dense_resource_elements", [ElementsAttrInterface]> {
+    "dense_resource_elements", [ElementsAttrInterface, ResourceAttrInterface]> {
   let summary = "An Attribute containing a dense multi-dimensional array "
                 "backed by a resource";
   let description = [{
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b468228ea78b7..9549d4628fac0 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -458,10 +458,11 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
 /// of the innermost dimension. Constants for other dimensions are still
 /// constructed recursively. Returns nullptr on failure and emits errors at
 /// `loc`.
-static llvm::Constant *convertDenseResourceElementsAttr(
-    Location loc, DenseResourceElementsAttr denseResourceAttr,
-    llvm::Type *llvmType, const ModuleTranslation &moduleTranslation) {
-  assert(denseResourceAttr && "expected non-null attribute");
+static llvm::Constant *
+convertResourceAttr(Location loc, ResourceAttr resourceAttr,
+                    llvm::Type *llvmType,
+                    const ModuleTranslation &moduleTranslation) {
+  assert(resourceAttr && "expected non-null attribute");
 
   llvm::Type *innermostLLVMType = getInnermostElementType(llvmType);
   if (!llvm::ConstantDataSequential::isElementTypeCompatible(
@@ -470,10 +471,16 @@ static llvm::Constant *convertDenseResourceElementsAttr(
     return nullptr;
   }
 
-  ShapedType type = denseResourceAttr.getType();
+  // ResourceAttr only guarantees a typed attribute; the conversion below needs
+  // a shaped type to compute element counts and the innermost dimension.
+  auto type = dyn_cast<ShapedType>(resourceAttr.getType());
+  if (!type) {
+    emitError(loc, "expected resource attribute of shaped type");
+    return nullptr;
+  }
   assert(type.getNumElements() > 0 && "Expected non-empty elements attribute");
 
-  AsmResourceBlob *blob = denseResourceAttr.getRawHandle().getBlob();
+  AsmResourceBlob *blob = resourceAttr.getBlob();
   if (!blob) {
     emitError(loc, "resource does not exist");
     return nullptr;
@@ -486,7 +493,7 @@ static llvm::Constant *convertDenseResourceElementsAttr(
   // raw data.
   // TODO: we may also need to consider endianness when cross-compiling to an
   // architecture where it is different.
-  int64_t numElements = denseResourceAttr.getType().getNumElements();
+  int64_t numElements = type.getNumElements();
   int64_t elementByteSize = rawData.size() / numElements;
   if (8 * elementByteSize != innermostLLVMType->getScalarSizeInBits()) {
     emitError(loc, "raw data size does not match element type size");
@@ -497,9 +504,7 @@ static llvm::Constant *convertDenseResourceElementsAttr(
   // innermost dimension may be that of the vector element type.
   bool hasVectorElementType = isa<VectorType>(type.getElementType());
   int64_t numAggregates =
-      numElements / (hasVectorElementType
-                         ? 1
-                         : denseResourceAttr.getType().getShape().back());
+      numElements / (hasVectorElementType ? 1 : type.getShape().back());
   ArrayRef<int64_t> outerShape = type.getShape();
   if (!hasVectorElementType)
     outerShape = outerShape.drop_back();
@@ -533,8 +538,8 @@ static llvm::Constant *convertDenseResourceElementsAttr(
   // Create innermost constants and defer to the default constant creation
   // mechanism for other dimensions.
   SmallVector<llvm::Constant *> constants;
-  int64_t aggregateSize = denseResourceAttr.getType().getShape().back() *
-                          (innermostLLVMType->getScalarSizeInBits() / 8);
+  int64_t aggregateSize =
+      type.getShape().back() * (innermostLLVMType->getScalarSizeInBits() / 8);
   constants.reserve(numAggregates);
   for (unsigned i = 0; i < numAggregates; ++i) {
     StringRef data(rawData.data() + i * aggregateSize, aggregateSize);
@@ -679,9 +684,8 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
     return result;
   }
 
-  if (auto denseResourceAttr = dyn_cast<DenseResourceElementsAttr>(attr)) {
-    return convertDenseResourceElementsAttr(loc, denseResourceAttr, llvmType,
-                                            moduleTranslation);
+  if (auto resourceAttr = dyn_cast<ResourceAttr>(attr)) {
+    return convertResourceAttr(loc, resourceAttr, llvmType, moduleTranslation);
   }
 
   // Fall back to element-by-element construction otherwise.
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index e72bfe9d82e7c..207373722dcbb 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -382,6 +382,23 @@ TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
       },
       "invalid shape element type for provided type `T`");
 }
+
+TEST(DenseResourceElementsAttrTest, CheckResourceInterface) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Unmanaged blobs reference (not copy) their data, so it must outlive attr.
+  static const double rawData[] = {0, 1, 2};
+  ArrayRef<double> data(rawData);
+  auto type = RankedTensorType::get(data.size(), builder.getF64Type());
+  auto attr = DenseF64ResourceElementsAttr::get(
+      type, "resource", UnmanagedAsmResourceBlob::allocateInferAlign(data));
+
+  auto resourceAttr = dyn_cast<ResourceAttr>(attr);
+  ASSERT_TRUE(resourceAttr);
+  EXPECT_EQ(resourceAttr.getBlobKey(), "resource");
+  EXPECT_TRUE(resourceAttr.getBlob());
+}
 } // namespace
 
 //===----------------------------------------------------------------------===//
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list