[Mlir-commits] [mlir] [MLIR] Setting MemorySpace During Bufferization (PR #78484)

Ian Bearman llvmlistbot at llvm.org
Tue Feb 6 08:28:30 PST 2024


https://github.com/manbearian updated https://github.com/llvm/llvm-project/pull/78484

>From e855b2f3907dd7eb444475768c68d0d47029199e Mon Sep 17 00:00:00 2001
From: Ian Bearman <ianb at microsoft.com>
Date: Tue, 6 Feb 2024 15:57:16 +0000
Subject: [PATCH] add defaultMemorySpaceFn (includes all PR feedback)

---
 .../Bufferization/IR/BufferizableOpInterface.h  | 15 ++++++++++-----
 .../Transforms/BufferizableOpInterfaceImpl.cpp  | 13 +++++++------
 .../IR/BufferizableOpInterface.cpp              | 14 ++++++++------
 .../Bufferization/IR/BufferizationOps.cpp       |  4 ++--
 .../Bufferization/Transforms/Bufferize.cpp      |  8 ++++++--
 .../FuncBufferizableOpInterfaceImpl.cpp         |  5 +++--
 .../Transforms/BufferizableOpInterfaceImpl.cpp  | 17 ++++++++++-------
 .../Bufferization/TestTensorCopyInsertion.cpp   |  6 ++++--
 8 files changed, 50 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 226a2fbd08563..d8cfeee246636 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -257,6 +257,9 @@ struct BufferizationOptions {
   /// Parameters: Value, memory space, bufferization options
   using UnknownTypeConverterFn = std::function<BaseMemRefType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
+  // Produces a memory space attribute for a given tensor type.
+  using DefaultMemorySpaceFn =
+      std::function<std::optional<Attribute>(TensorType t)>;
 
   BufferizationOptions();
 
@@ -296,11 +299,6 @@ struct BufferizationOptions {
   /// bufferized or not.
   bool bufferizeFunctionBoundaries = false;
 
-  /// The default memory space that should be used when it cannot be inferred
-  /// from the context. If case of std::nullopt, bufferization fails when the
-  /// memory space cannot be inferred at any point.
-  std::optional<Attribute> defaultMemorySpace = Attribute();
-
   /// Certain ops have aliasing OpOperand/OpResult invariants (e.g., scf.for).
   /// If this flag is set to `false`, those invariants are no longer enforced
   /// with buffer copies.
@@ -351,6 +349,13 @@ struct BufferizationOptions {
   /// used.
   UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
+  // Used during type conversion to determine the memory space for a memref
+  // based on the original tensor type if the memory space cannot be inferred.
+  // Returning std::nullopt will cause bufferization to fail (useful to indicate
+  // failure to determine memory space for a tensor type).
+  DefaultMemorySpaceFn defaultMemorySpaceFn =
+      [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+
   /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
   /// Should be used only with `testAnalysisOnly = true`.
   unsigned analysisFuzzerSeed = 0;
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f69b2557eec92..d7492c9e25db3 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,17 +26,18 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
+    auto type = constantOp.getType().dyn_cast<RankedTensorType>();
+
+    // Only ranked tensors are supported.
+    if (!type)
+      return failure();
 
     Attribute memorySpace;
-    if (options.defaultMemorySpace.has_value())
-      memorySpace = *options.defaultMemorySpace;
+    if (auto memSpace = options.defaultMemorySpaceFn(type))
+      memorySpace = *memSpace;
     else
       return constantOp->emitError("could not infer memory space");
 
-    // Only ranked tensors are supported.
-    if (!isa<RankedTensorType>(constantOp.getType()))
-      return failure();
-
     // Only constants inside a module are supported.
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 6ca9702cbbc66..8f0f6d1fcc849 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -682,11 +682,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
     return bufferizableOp.getBufferType(value, options, invocationStack);
 
   // Op is not bufferizable.
-  if (!options.defaultMemorySpace.has_value())
+  auto memSpace =
+      options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+  if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{},
-                       *options.defaultMemorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -936,11 +937,12 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
 
   // If we do not know the memory space and there is no default memory space,
   // report a failure.
-  if (!options.defaultMemorySpace.has_value())
+  auto memSpace =
+      options.defaultMemorySpaceFn(value.getType().cast<TensorType>());
+  if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(value, options, /*layout=*/{},
-                       *options.defaultMemorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index eb4a96f354990..34a0c594a5a5a 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -234,8 +234,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     if (failed(copyBufferType))
       return failure();
     memorySpace = copyBufferType->getMemorySpace();
-  } else if (options.defaultMemorySpace.has_value()) {
-    memorySpace = *options.defaultMemorySpace;
+  } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
+    memorySpace = *ms;
   } else {
     return getOperation()->emitError("could not infer memory space");
   }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index dc94b72edcdf0..208cbda3a9eb6 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -210,8 +210,12 @@ struct OneShotBufferizePass
       opt.dumpAliasSets = dumpAliasSets;
       opt.setFunctionBoundaryTypeConversion(
           parseLayoutMapOption(functionBoundaryTypeConversion));
-      if (mustInferMemorySpace)
-        opt.defaultMemorySpace = std::nullopt;
+      if (mustInferMemorySpace) {
+        opt.defaultMemorySpaceFn =
+            [](TensorType t) -> std::optional<Attribute> {
+          return std::nullopt;
+        };
+      }
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 07cd1f90b17df..4cdbbf35dc876 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
   assert(tensorType && "expected TensorType");
 
   BaseMemRefType memrefType = options.functionArgTypeConverterFn(
-      tensorType, *options.defaultMemorySpace, funcOp, options);
+      tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
 
   auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
       index, BufferizationDialect::kBufferLayoutAttrName);
@@ -443,7 +443,8 @@ struct FuncOpInterface
       // Note: If `inferFunctionResultLayout = true`, cast are later folded
       // away.
       BaseMemRefType resultType = options.functionArgTypeConverterFn(
-          tensorType, *options.defaultMemorySpace, funcOp, options);
+          tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+          options);
       Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
           loc, resultType, returnVal);
       returnValues.push_back(toMemrefOp);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 678b7c099fa36..957f6314f3587 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -473,14 +473,14 @@ struct FromElementsOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto fromElementsOp = cast<tensor::FromElementsOp>(op);
+    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
 
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate a buffer for the result.
     Location loc = op->getLoc();
-    auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
     auto shape = tensorType.getShape();
     // TODO: Create alloc_tensor ops during TensorCopyInsertion.
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -588,8 +588,10 @@ struct GenerateOpInterface
                           const BufferizationOptions &options) const {
     auto generateOp = cast<tensor::GenerateOp>(op);
 
+    auto type = generateOp.getResult().getType();
+
     // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
+    if (options.defaultMemorySpaceFn(type) != Attribute())
       return op->emitError("memory space not implemented yet");
 
     // Allocate memory.
@@ -1007,10 +1009,6 @@ struct SplatOpInterface
     OpBuilder::InsertionGuard g(rewriter);
     auto splatOp = cast<tensor::SplatOp>(op);
 
-    // TODO: Implement memory space for this op.
-    if (options.defaultMemorySpace != Attribute())
-      return op->emitError("memory space not implemented yet");
-
     // Allocate memory.
     Location loc = op->getLoc();
     FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
@@ -1021,6 +1019,11 @@ struct SplatOpInterface
 
     // Create linalg::MapOp.
     auto tensorType = cast<RankedTensorType>(tensorAlloc->getType());
+
+    // TODO: Implement memory space for this op.
+    if (options.defaultMemorySpaceFn(tensorType) != Attribute())
+      return op->emitError("memory space not implemented yet");
+
     auto linalgOp =
         rewriter.create<linalg::MapOp>(loc, tensorType, /*inputs=*/ValueRange(),
                                        /*init=*/*tensorAlloc);
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index fedfbe350a51a..2991a3c165ee2 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -44,8 +44,10 @@ struct TestTensorCopyInsertionPass
     bufferization::OneShotBufferizationOptions options;
     options.allowReturnAllocsFromLoops = allowReturnAllocsFromLoops;
     options.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
-    if (mustInferMemorySpace)
-      options.defaultMemorySpace = std::nullopt;
+    if (mustInferMemorySpace) {
+      options.defaultMemorySpaceFn =
+          [](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
+    }
     if (failed(bufferization::insertTensorCopies(getOperation(), options)))
       signalPassFailure();
   }



More information about the Mlir-commits mailing list