[Mlir-commits] [mlir] 8bb5ca5 - [mlir] Support bufferization of arith.constant to memref.global with memory space

Maya Amrami llvmlistbot at llvm.org
Mon Mar 20 08:13:48 PDT 2023


Author: Maya Amrami
Date: 2023-03-20T17:09:51+02:00
New Revision: 8bb5ca58327ec5d0788b1546844b06b1118c5cb5

URL: https://github.com/llvm/llvm-project/commit/8bb5ca58327ec5d0788b1546844b06b1118c5cb5
DIFF: https://github.com/llvm/llvm-project/commit/8bb5ca58327ec5d0788b1546844b06b1118c5cb5.diff

LOG: [mlir] Support bufferization of arith.constant to memref.global with memory space

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D146381

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
    mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
    mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index 6c521acd0e146..85e9c47ad5302 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -125,7 +125,8 @@ class BufferPlacementTransformationBase {
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         uint64_t alignment);
+                                         uint64_t alignment,
+                                         Attribute memorySpace = {});
 
 } // namespace bufferization
 } // namespace mlir

diff  --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 8408aad6e3fc1..9602d530cf826 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 
@@ -26,10 +27,11 @@ struct ConstantOpInterface
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
 
-    // TODO: Implement memory space for this op. E.g., by adding a memory_space
-    // attribute to ConstantOp.
-    if (options.defaultMemorySpace != Attribute())
-      return op->emitError("memory space not implemented yet");
+    Attribute memorySpace;
+    if (options.defaultMemorySpace.has_value())
+      memorySpace = *options.defaultMemorySpace;
+    else
+      return constantOp->emitError("could not infer memory space");
 
     // Only ranked tensors are supported.
     if (!constantOp.getType().isa<RankedTensorType>())
@@ -43,7 +45,7 @@ struct ConstantOpInterface
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, options.bufferAlignment);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index 38d69194be1e8..b9776e2fb2095 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -147,7 +147,8 @@ bool BufferPlacementTransformationBase::isLoop(Operation *op) {
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = constantOp.getType().cast<RankedTensorType>();
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -184,10 +185,13 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
                     : IntegerAttr();
 
   BufferizeTypeConverter typeConverter;
+  auto memrefType = typeConverter.convertType(type).cast<MemRefType>();
+  if (memorySpace)
+    memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
   auto global = globalBuilder.create<memref::GlobalOp>(
       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
-      /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
+      /*type=*/memrefType,
       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
       /*constant=*/true,
       /*alignment=*/memrefAlignment);

diff  --git a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
index 315da00a00d78..deda8bb74b323 100644
--- a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
+++ b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
@@ -13,8 +13,8 @@ func.func @inconsistent_memory_space_arith_select(%c: i1) -> tensor<10xf32> {
 
 // -----
 
-func.func @constant_memory_space(%idx: index, %v: i32) -> tensor<3xi32> {
-  // expected-error @+2 {{memory space not implemented yet}}
+func.func @unknown_memory_space(%idx: index, %v: i32) -> tensor<3xi32> {
+  // expected-error @+2 {{could not infer memory space}}
   // expected-error @+1 {{failed to bufferize op}}
   %cst = arith.constant dense<[5, 1000, 20]> : tensor<3xi32>
   %0 = tensor.insert %v into %cst[%idx] : tensor<3xi32>


        


More information about the Mlir-commits mailing list