[Mlir-commits] [mlir] 606f7c8 - [mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions

Matthias Springer llvmlistbot at llvm.org
Thu Jul 7 04:40:53 PDT 2022


Author: Matthias Springer
Date: 2022-07-07T13:36:28+02:00
New Revision: 606f7c8f7a770718bd7061d5a506711a9c84f482

URL: https://github.com/llvm/llvm-project/commit/606f7c8f7a770718bd7061d5a506711a9c84f482
DIFF: https://github.com/llvm/llvm-project/commit/606f7c8f7a770718bd7061d5a506711a9c84f482.diff

LOG: [mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions

The `unknownTypeConversion` bufferization option (an enum) is replaced by a type converter function option, `unknownTypeConverterFn`. Some of the logic of `getMemRefType` is now handled by that function.

This change makes type conversion more controllable. Previously, there were only two options when generating memref types for non-bufferizable ops: a static identity layout or a fully dynamic layout. With this change, users of One-Shot Bufferize can provide a function with custom logic.

Differential Revision: https://reviews.llvm.org/D129273
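
For illustration (not part of the commit), a downstream user of One-Shot Bufferize could now install custom conversion logic roughly as follows. The static-shape heuristic and the helper `makeCustomBufferizationOptions` are hypothetical; only `unknownTypeConverterFn` and the `getMemRefTypeWith*Layout` helpers come from the bufferization API touched by this patch:

  // Sketch only: a hypothetical custom unknown-type converter. Statically
  // shaped tensors get a static identity layout; everything else falls back
  // to the (more composable) fully dynamic layout map.
  #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

  using namespace mlir;

  bufferization::BufferizationOptions makeCustomBufferizationOptions() {
    bufferization::BufferizationOptions options;
    options.unknownTypeConverterFn =
        [](Value value, unsigned memorySpace,
           const bufferization::BufferizationOptions &) -> BaseMemRefType {
      auto tensorType = value.getType().cast<TensorType>();
      if (tensorType.hasStaticShape())
        return bufferization::getMemRefTypeWithStaticIdentityLayout(
            tensorType, memorySpace);
      return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
                                                                memorySpace);
    };
    return options;
  }

The default installed in the `BufferizationOptions` constructor keeps the previous behavior (fully dynamic layout maps), so existing users are unaffected unless they override the function.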

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ff8db00f7644e..2cc84c99d2040 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -179,6 +179,10 @@ struct BufferizationOptions {
   /// Initializer function for dialect-specific analysis state.
   using DialectStateInitFn =
       std::function<std::unique_ptr<DialectAnalysisState>()>;
+  /// Tensor -> MemRef type converter.
+  /// Parameters: Value, memory space, bufferization options
+  using UnknownTypeConverterFn = std::function<BaseMemRefType(
+      Value, unsigned, const BufferizationOptions &)>;
 
   enum class LayoutMapOption : int8_t {
     InferLayoutMap = 0,
@@ -266,21 +270,11 @@ struct BufferizationOptions {
   LayoutMapOption functionBoundaryTypeConversion =
       LayoutMapOption::InferLayoutMap;
 
-  /// This flag controls buffer types on unknown ops (to_memref wrappers) and in
-  /// other cases where a precise memref type cannot be inferred (e.g., the
-  /// bufferization of "tensor.cast").
-  ///
-  /// * InferLayoutMap: This option is invalid and cannot be used.
-  /// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully
-  ///   dynamic layout maps after bufferization. This option is most efficient
-  ///   because any layout map can be casted to a fully dynamic one.
-  /// * IdentityLayoutMap: Assume that unknown ops have results with static
-  ///   identity layout (i.e., no layout map) after bufferization. This option
-  ///   introduces additional buffer allocs and copies if the unknown op is
-  ///   eventually bufferized to an op that returns a buffer with non-identity
-  ///   layout.
-  LayoutMapOption unknownTypeConversion =
-      LayoutMapOption::FullyDynamicLayoutMap;
+  /// Type converter from tensors to memrefs. This type converter is used if no
+  /// memref type could be inferred during bufferization. By default, a type
+  /// converter that returns a memref type with a fully dynamic layout map is
+  /// used.
+  UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
   /// Specifies whether dealloc ops should be generated along with alloc ops. If
   /// not, new memory allocations will leak.
@@ -505,20 +499,19 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the `tensorType` can be bufferized.
+/// Return a MemRefType to which the type of the given value can be bufferized.
 ///
 /// If possible, op bufferization implementations should not use this function
 /// and instead infer precise memref types for tensor results by themselves.
 ///
-/// Unless a layout map was specified, `options.unknownTypeConverter` determines
-/// what kind of layout map will be used. For best composability (without
-/// copies), the fully dynamic layout map is used by default.
+/// Unless a layout map was specified, `options.unknownTypeConverterFn`
+/// determines what kind of layout map will be used. For best composability
+/// (without copies), the fully dynamic layout map is used by default.
 ///
 /// Note: Canonicalization patterns could clean up layout maps and infer more
 /// precise layout maps after bufferization. However, many possible
 /// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(TensorType tensorType,
-                             const BufferizationOptions &options,
+BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
                              unsigned memorySpace = 0);
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 61caa18561d34..49a5c9115e718 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -351,8 +351,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*defaultImplementation=*/[{
           assert(bbArg.getOwner()->getParentOp() == $_op &&
                  "bbArg must belong to this op");
-          auto tensorType = bbArg.getType().cast<TensorType>();
-          return bufferization::getMemRefType(tensorType, options);
+          assert(bbArg.getType().isa<TensorType>() &&
+                 "expected tensor type");
+          return bufferization::getMemRefType(bbArg, options);
         }]
       >,
       InterfaceMethod<

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 374fbd7da664d..97a84bf220536 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -222,8 +222,17 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 // BufferizationOptions
 //===----------------------------------------------------------------------===//
 
+/// Default unknown type converter: Use a fully dynamic layout map.
+static BaseMemRefType
+defaultUnknownTypeConverter(Value value, unsigned memorySpace,
+                            const BufferizationOptions &options) {
+  return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
+                                             memorySpace);
+}
+
 // Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default;
+BufferizationOptions::BufferizationOptions()
+    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
 
 bool BufferizationOptions::isOpAllowed(Operation *op) const {
   // Special case: If function boundary bufferization is deactivated, do not
@@ -528,8 +537,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
-  auto tensorType = value.getType().dyn_cast<TensorType>();
-  assert(tensorType && "unexpected non-tensor type");
+  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
   Operation *op = getOwnerOfValue(value);
 
   // ToTensorOp: Take buffer type directly from the op.
@@ -566,7 +574,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   if (!memorySpace.hasValue())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
 }
 
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
@@ -652,10 +660,11 @@ bool bufferization::isFunctionArgument(Value value) {
   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
+BaseMemRefType bufferization::getMemRefType(Value value,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
                                             unsigned memorySpace) {
+  auto tensorType = value.getType().cast<TensorType>();
   auto memorySpaceAttr = IntegerAttr::get(
       IntegerType::get(tensorType.getContext(), 64), memorySpace);
 
@@ -674,17 +683,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                            memorySpaceAttr);
   }
 
-  // Case 3: Configured with "fully dynamic layout maps".
-  if (options.unknownTypeConversion ==
-      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
-    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
-
-  // Case 4: Configured with "static identity layout maps".
-  if (options.unknownTypeConversion ==
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
-    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
-
-  llvm_unreachable("InferLayoutMap is an invalid option");
+  return options.unknownTypeConverterFn(value, memorySpace, options);
 }
 
 BaseMemRefType

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c68d1d120be6a..f1dfbd113947a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -192,8 +192,26 @@ struct OneShotBufferizePass
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
-      opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
 
+      // Configure type converter.
+      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
+          parseLayoutMapOption(unknownTypeConversion);
+      opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
+                                       const BufferizationOptions &options) {
+        auto tensorType = value.getType().cast<TensorType>();
+        if (unknownTypeConversionOption ==
+            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
+          return bufferization::getMemRefTypeWithStaticIdentityLayout(
+              tensorType, memorySpace);
+        assert(
+            unknownTypeConversionOption ==
+                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
+            "invalid layout map option");
+        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                                  memorySpace);
+      };
+
+      // Configure op filter.
       OpFilter::Entry::FilterFn filterFn =
           [&](Operation *op) {
             // Filter may be specified via options.
@@ -372,10 +390,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
                                          bool copyBeforeWrite,
                                          const OpFilter *opFilter) {
-  assert(options.unknownTypeConversion !=
-             BufferizationOptions::LayoutMapOption::InferLayoutMap &&
-         "invalid layout map option");
-
   if (copyBeforeWrite) {
     AnalysisState state(options);
     if (failed(insertTensorCopies(op, state)))
@@ -474,8 +488,11 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
   options.allowUnknownOps = true;
   options.createDeallocs = false;
   options.enforceAliasingInvariants = false;
-  options.unknownTypeConversion =
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+  options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
+                                      const BufferizationOptions &options) {
+    return getMemRefTypeWithStaticIdentityLayout(
+        value.getType().cast<TensorType>(), memorySpace);
+  };
   options.opFilter.allowDialect<BufferizationDialect>();
   return options;
 }

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 97da5969a3004..6cd9134b097ab 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -67,7 +67,7 @@ struct CastOpInterface
 
     // Compute the new memref type.
     Type resultMemRefType =
-        getMemRefType(resultTensorType, options, layout,
+        getMemRefType(castOp.getResult(), options, layout,
                       sourceMemRefType.getMemorySpaceAsInt());
 
     // Replace the op with a memref.cast.
@@ -780,9 +780,8 @@ struct ReshapeOpInterface
         getBuffer(rewriter, reshapeOp.getShape(), options);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
-    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
     auto resultMemRefType = getMemRefType(
-        resultTensorType, options, /*layout=*/{},
+        reshapeOp.getResult(), options, /*layout=*/{},
         srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);


        

