[Mlir-commits] [mlir] b06614e - [mlir][bufferization][NFC] Change signature of getMemRefType

Matthias Springer llvmlistbot at llvm.org
Mon Jun 27 01:41:50 PDT 2022


Author: Matthias Springer
Date: 2022-06-27T10:41:40+02:00
New Revision: b06614e2e8d74339f65a46931dbf1521552df35b

URL: https://github.com/llvm/llvm-project/commit/b06614e2e8d74339f65a46931dbf1521552df35b
DIFF: https://github.com/llvm/llvm-project/commit/b06614e2e8d74339f65a46931dbf1521552df35b.diff

LOG: [mlir][bufferization][NFC] Change signature of getMemRefType

These functions now accept unsigned attributes for address spaces instead of Attributes.

Differential Revision: https://reviews.llvm.org/D128275

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index fa44fde98b6e..f28db1e26c09 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -513,18 +513,17 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
 BaseMemRefType getMemRefType(TensorType tensorType,
                              const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
-                             Attribute memorySpace = {});
+                             unsigned memorySpace = 0);
 
 /// Return a MemRef type with fully dynamic layout. If the given tensor type
 /// is unranked, return an unranked MemRef type.
 BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
-                                                   Attribute memorySpace = {});
+                                                   unsigned memorySpace = 0);
 
 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
 /// the given tensor type is unranked, return an unranked MemRef type.
-BaseMemRefType
-getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
-                                      Attribute memorySpace = {});
+BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
+                                                     unsigned memorySpace = 0);
 
 } // namespace bufferization
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 85a3f562ce99..3fe32bc293ea 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -596,12 +596,15 @@ bool bufferization::isFunctionArgument(Value value) {
 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
-                                            Attribute memorySpace) {
+                                            unsigned memorySpace) {
+  auto memorySpaceAttr = IntegerAttr::get(
+      IntegerType::get(tensorType.getContext(), 64), memorySpace);
+
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
     assert(!layout && "UnrankedTensorType cannot have a layout map");
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
-                                   memorySpace);
+                                   memorySpaceAttr);
   }
 
   // Case 2: Ranked memref type with specified layout.
@@ -609,7 +612,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
   if (layout) {
     return MemRefType::get(rankedTensorType.getShape(),
                            rankedTensorType.getElementType(), layout,
-                           memorySpace);
+                           memorySpaceAttr);
   }
 
   // Case 3: Configured with "fully dynamic layout maps".
@@ -627,7 +630,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
 
 BaseMemRefType
 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
-                                                   Attribute memorySpace) {
+                                                   unsigned memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
@@ -635,6 +638,8 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
   }
 
   // Case 2: Ranked memref type.
+  auto memorySpaceAttr = IntegerAttr::get(
+      IntegerType::get(tensorType.getContext(), 64), memorySpace);
   auto rankedTensorType = tensorType.cast<RankedTensorType>();
   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
@@ -643,14 +648,14 @@ bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
   return MemRefType::get(rankedTensorType.getShape(),
                          rankedTensorType.getElementType(), stridedLayout,
-                         memorySpace);
+                         memorySpaceAttr);
 }
 
 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
 /// the given tensor type is unranked, return an unranked MemRef type.
 BaseMemRefType
 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
-                                                     Attribute memorySpace) {
+                                                     unsigned memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
@@ -659,8 +664,10 @@ bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
 
   // Case 2: Ranked memref type.
   auto rankedTensorType = tensorType.cast<RankedTensorType>();
+  auto memorySpaceAttr = IntegerAttr::get(
+      IntegerType::get(tensorType.getContext(), 64), memorySpace);
   MemRefLayoutAttrInterface layout = {};
   return MemRefType::get(rankedTensorType.getShape(),
                          rankedTensorType.getElementType(), layout,
-                         memorySpace);
+                         memorySpaceAttr);
 }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 6bdae8ce7a05..7f8b6a35491d 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -54,7 +54,6 @@ struct CastOpInterface
     // The result buffer still has the old (pre-cast) type.
     Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
     auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
-    Attribute memorySpace = sourceMemRefType.getMemorySpace();
     TensorType resultTensorType =
         castOp.getResult().getType().cast<TensorType>();
     MemRefLayoutAttrInterface layout;
@@ -65,7 +64,8 @@ struct CastOpInterface
 
     // Compute the new memref type.
     Type resultMemRefType =
-        getMemRefType(resultTensorType, options, layout, memorySpace);
+        getMemRefType(resultTensorType, options, layout,
+                      sourceMemRefType.getMemorySpaceAsInt());
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),


        


More information about the Mlir-commits mailing list