[Mlir-commits] [mlir] 867cd94 - [mlir][linalg][bufferize][NFC] Move BufferizationOptions to op interface

Matthias Springer llvmlistbot at llvm.org
Fri Dec 3 02:56:13 PST 2021


Author: Matthias Springer
Date: 2021-12-03T19:51:34+09:00
New Revision: 867cd948ace18c5eba8625005a62ea07f619a936

URL: https://github.com/llvm/llvm-project/commit/867cd948ace18c5eba8625005a62ea07f619a936
DIFF: https://github.com/llvm/llvm-project/commit/867cd948ace18c5eba8625005a62ea07f619a936.diff

LOG: [mlir][linalg][bufferize][NFC] Move BufferizationOptions to op interface

Also store a reference to BufferizationOptions in BufferizationState. This is in preparation for adding support for partial bufferization.

Differential Revision: https://reviews.llvm.org/D114661

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
    mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 52befbd726059..173126f5c7d4d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -26,7 +26,84 @@ class FuncOp;
 namespace linalg {
 namespace comprehensive_bufferize {
 
-class BufferizationAliasInfo;
+// TODO: from some HW description.
+static constexpr int64_t kBufferAlignments = 128;
+
+struct BufferizationState;
+
+/// Callback functions that are used to allocate/deallocate/copy memory buffers.
+/// Comprehensive Bufferize provides default implementations of these functions.
+// TODO: Could be replaced with a "bufferization strategy" object with virtual
+// functions in the future.
+struct AllocationCallbacks {
+  using AllocationFn = std::function<Optional<Value>(
+      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+                      MemCpyFn copyFn)
+      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+
+  /// A function that allocates memory.
+  AllocationFn allocationFn;
+
+  /// A function that deallocates memory allocated by `allocationFn`.
+  DeallocationFn deallocationFn;
+
+  /// A function that copies memory between two allocations.
+  MemCpyFn memCpyFn;
+};
+
+/// Return default allocation callbacks.
+std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
+
+/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// executed after the analysis, but before bufferization. They can be used
+/// to implement custom dialect-specific optimizations.
+struct PostAnalysisStep {
+  virtual ~PostAnalysisStep() {}
+
+  /// Run the post analysis step. This function may modify the IR, but must keep
+  /// `aliasInfo` (inside `state`) consistent. Newly created operations and
+  /// operations that should be re-analyzed must be stored in `newOps`.
+  virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                            SmallVector<Operation *> &newOps) = 0;
+};
+
+/// Options for ComprehensiveBufferize.
+struct BufferizationOptions {
+  BufferizationOptions();
+
+  // BufferizationOptions cannot be copied.
+  BufferizationOptions(const BufferizationOptions &other) = delete;
+
+  /// Register a "post analysis" step. Such steps are executed after the
+  /// analysis, but before bufferization.
+  template <typename Step, typename... Args>
+  void addPostAnalysisStep(Args... args) {
+    postAnalysisSteps.emplace_back(
+        std::make_unique<Step>(std::forward<Args>(args)...));
+  }
+
+  /// Helper functions for allocation, deallocation, memory copying.
+  std::unique_ptr<AllocationCallbacks> allocationFns;
+
+  /// Specifies whether returning newly allocated memrefs should be allowed.
+  /// Otherwise, a pass failure is triggered.
+  bool allowReturnMemref = false;
+
+  /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
+  /// Should be used only with `testAnalysisOnly = true`.
+  unsigned analysisFuzzerSeed = 0;
+
+  /// If set to `true`, does not modify the IR apart from adding attributes (for
+  /// checking the results of the analysis) and post analysis steps.
+  bool testAnalysisOnly = false;
+
+  /// Registered post analysis steps.
+  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
+};
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
 enum class BufferRelation {
@@ -204,32 +281,6 @@ findValueInReverseUseDefChain(Value value,
 /// is returned regardless of whether it is a memory write or not.
 Value findLastPrecedingWrite(Value value);
 
-struct BufferizationState;
-
-/// Callback functions that are used to allocate/deallocate/copy memory buffers.
-/// Comprehensive Bufferize provides default implementations of these functions.
-// TODO: Could be replaced with a "bufferization strategy" object with virtual
-// functions in the future.
-struct AllocationCallbacks {
-  using AllocationFn = std::function<Optional<Value>(
-      OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
-  using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
-  using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-
-  AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
-                      MemCpyFn copyFn)
-      : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
-
-  /// A function that allocates memory.
-  AllocationFn allocationFn;
-
-  /// A function that deallocated memory. Must be allocated by `allocationFn`.
-  DeallocationFn deallocationFn;
-
-  /// A function that copies memory between two allocations.
-  MemCpyFn memCpyFn;
-};
-
 /// Dialect-specific bufferization state. Analysis/bufferization information
 /// that is specific to ops from a certain dialect can be stored in derived
 /// variants of this struct.
@@ -240,8 +291,8 @@ struct DialectBufferizationState {
 /// BufferizationState keeps track of bufferization state and provides access to
 /// the results of the analysis.
 struct BufferizationState {
-  BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns)
-      : aliasInfo(moduleOp), allocationFns(allocationFns) {}
+  BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
+      : aliasInfo(moduleOp), options(options) {}
 
   // BufferizationState should be passed as a reference.
   BufferizationState(const BufferizationState &) = delete;
@@ -289,10 +340,6 @@ struct BufferizationState {
   /// `aliasInfo` keeps track of aliasing and equivalent values.
   BufferizationAliasInfo aliasInfo;
 
-  /// `allocationFns` contains helper functions for creating alloc ops, dealloc
-  /// ops and memcpy ops.
-  AllocationCallbacks &allocationFns;
-
   /// The mapping of tensors to buffers. May also contain mappings of non-tensor
   /// values.
   BlockAndValueMapping mapping;
@@ -302,6 +349,9 @@ struct BufferizationState {
 
   /// Dialect-specific bufferization state.
   DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
+
+  /// A reference to current bufferization options.
+  const BufferizationOptions &options;
 };
 
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
@@ -320,19 +370,6 @@ LogicalResult bufferize(Block *block, BufferizationState &state);
 /// method of `BufferizableOpInterface`.
 LogicalResult bufferize(Operation *op, BufferizationState &state);
 
-/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
-/// executed after the analysis, but before bufferization. They can be used
-/// implement custom dialect-specific optimizations.
-struct PostAnalysisStep {
-  virtual ~PostAnalysisStep() {}
-
-  /// Run the post analysis step. This function may modify the IR, but must keep
-  /// `aliasInfo` (inside `state`) consistent. Newly created operations and
-  /// operations that should be re-analyzed must be stored in `newOps`.
-  virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
-                            SmallVector<Operation *> &newOps) = 0;
-};
-
 /// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
 /// with the same shape as `shapedType` and specified `layout` and
 /// `addressSpace`.

diff  --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index a1fd04dc4bd00..cd6b5268f442f 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -9,54 +9,15 @@
 #ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
 #define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
 
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinOps.h"
 
 namespace mlir {
 
-class ModuleOp;
-
 namespace linalg {
 namespace comprehensive_bufferize {
 
-// TODO: from some HW description.
-static constexpr int64_t kBufferAlignments = 128;
-
-/// Return default allocation callbacks.
-std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
-
-/// Options for ComprehensiveBufferize.
-struct BufferizationOptions {
-  BufferizationOptions();
-
-  // BufferizationOptions cannot be copied.
-  BufferizationOptions(const BufferizationOptions &other) = delete;
-
-  /// Register a "post analysis" step. Such steps are executed after the
-  /// analysis, but before bufferization.
-  template <typename Step, typename... Args>
-  void addPostAnalysisStep(Args... args) {
-    postAnalysisSteps.emplace_back(
-        std::make_unique<Step>(std::forward<Args>(args)...));
-  }
-
-  /// Helper functions for allocation, deallocation, memory copying.
-  std::unique_ptr<AllocationCallbacks> allocationFns;
-
-  /// Specifies whether returning newly allocated memrefs should be allowed.
-  /// Otherwise, a pass failure is triggered.
-  bool allowReturnMemref = false;
-
-  /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
-  /// Should be used only with `testAnalysisOnly = true`.
-  unsigned analysisFuzzerSeed = 0;
-
-  /// If set to `true`, does not modify the IR apart from adding attributes (for
-  /// checking the results of the analysis) and post analysis steps.
-  bool testAnalysisOnly = false;
-
-  /// Registered post analysis steps.
-  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
-};
+struct BufferizationOptions;
+struct BufferizationState;
 
 /// Bufferize the given function. Does not bufferize the function boundary.
 // TODO: This function is meant to be called from ModuleBufferize and not can

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e0c5a1020447e..c6ba42de66237 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -35,6 +35,45 @@ namespace comprehensive_bufferize {
 using namespace mlir;
 using namespace linalg::comprehensive_bufferize;
 
+//===----------------------------------------------------------------------===//
+// BufferizationOptions
+//===----------------------------------------------------------------------===//
+
+/// Default allocation function that is used by the comprehensive bufferization
+/// pass. The default currently creates a ranked memref using `memref.alloc`.
+static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+                                           MemRefType type,
+                                           ArrayRef<Value> dynShape) {
+  Value allocated = b.create<memref::AllocOp>(
+      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+  return allocated;
+}
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects to receive back the value returned by
+/// `defaultAllocationFn`.
+static void defaultDeallocationFn(OpBuilder &b, Location loc,
+                                  Value allocatedBuffer) {
+  b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
+/// Default memory copy function that is used by the comprehensive bufferization
+/// pass. Creates a `memref.copy` op.
+static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
+  b.create<memref::CopyOp>(loc, from, to);
+}
+
+std::unique_ptr<AllocationCallbacks>
+mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
+  return std::make_unique<AllocationCallbacks>(
+      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
+}
+
+// Default constructor for BufferizationOptions that sets all allocation
+// callbacks to their default functions.
+BufferizationOptions::BufferizationOptions()
+    : allocationFns(defaultAllocationCallbacks()) {}
+
 //===----------------------------------------------------------------------===//
 // BufferizationAliasInfo
 //===----------------------------------------------------------------------===//
@@ -384,7 +423,8 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
     if (!skipCopy) {
       // The copy happens right before the op that is bufferized.
       b.setInsertionPoint(op);
-      state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
+      state.options.allocationFns->memCpyFn(b, loc, operandBuffer,
+                                            resultBuffer);
     }
     return resultBuffer;
   }
@@ -537,7 +577,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
   Optional<Value> allocated =
-      allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+      options.allocationFns->allocationFn(b, loc, allocMemRefType, dynShape);
   // TODO: For now just assert the value is returned. Eventually need to
   // error-propagate.
   assert(allocated && "allocation failed");
@@ -549,7 +589,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
 
   // 2. Create memory deallocation.
   b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
-  allocationFns.deallocationFn(b, loc, allocated.getValue());
+  options.allocationFns->deallocationFn(b, loc, allocated.getValue());
   return casted;
 }
 

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d4571d3ac6702..1a3ce684993c7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -783,39 +783,3 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
 
   return success();
 }
-
-/// Default allocation function that is used by the comprehensive bufferization
-/// pass. The default currently creates a ranked memref using `memref.alloc`.
-static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
-                                           MemRefType type,
-                                           ArrayRef<Value> dynShape) {
-  Value allocated = b.create<memref::AllocOp>(
-      loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
-  return allocated;
-}
-
-/// Default deallocation function that is used by the comprehensive
-/// bufferization pass. It expects to recieve back the value called from the
-/// `defaultAllocationFn`.
-static void defaultDeallocationFn(OpBuilder &b, Location loc,
-                                  Value allocatedBuffer) {
-  b.create<memref::DeallocOp>(loc, allocatedBuffer);
-}
-
-/// Default memory copy function that is used by the comprehensive bufferization
-/// pass. Creates a `memref.copy` op.
-static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
-  b.create<memref::CopyOp>(loc, from, to);
-}
-
-std::unique_ptr<AllocationCallbacks>
-mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
-  return std::make_unique<AllocationCallbacks>(
-      defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
-}
-
-// Default constructor for BufferizationOptions that sets all allocation
-// callbacks to their default functions.
-BufferizationOptions::BufferizationOptions()
-    : allocationFns(defaultAllocationCallbacks()) {}
-

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index b55395b69e46d..577065149efdb 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -648,7 +648,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return failure();
 
-  BufferizationState state(moduleOp, *options.allocationFns);
+  BufferizationState state(moduleOp, options);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // Interestingly, all function args that are not visible outside of a module

diff  --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 695e448d8b170..efc292f402657 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -189,8 +189,8 @@ struct ExtractSliceOpInterface
     if (!inplace) {
       // Do not copy if the copied data is never read.
       if (isValueRead(extractSliceOp.result()))
-        state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
-                                     alloc);
+        state.options.allocationFns->memCpyFn(b, extractSliceOp.getLoc(),
+                                              subView, alloc);
       subView = alloc;
     }
 
@@ -464,8 +464,8 @@ struct InsertSliceOpInterface
       state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
       // Copy tensor.
       Value srcMemref = state.lookupBuffer(insertSliceOp.source());
-      state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
-                                   subView);
+      state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
+                                            srcMemref, subView);
     }
 
     state.mapBuffer(insertSliceOp.result(), dstMemref);


        


More information about the Mlir-commits mailing list