[Mlir-commits] [mlir] 867cd94 - [mlir][linalg][bufferize][NFC] Move BufferizationOptions to op interface
Matthias Springer
llvmlistbot at llvm.org
Fri Dec 3 02:56:13 PST 2021
Author: Matthias Springer
Date: 2021-12-03T19:51:34+09:00
New Revision: 867cd948ace18c5eba8625005a62ea07f619a936
URL: https://github.com/llvm/llvm-project/commit/867cd948ace18c5eba8625005a62ea07f619a936
DIFF: https://github.com/llvm/llvm-project/commit/867cd948ace18c5eba8625005a62ea07f619a936.diff
LOG: [mlir][linalg][bufferize][NFC] Move BufferizationOptions to op interface
Also store a reference to BufferizationOptions in BufferizationState. This is in preparation for adding support for partial bufferization.
Differential Revision: https://reviews.llvm.org/D114661
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
index 52befbd726059..173126f5c7d4d 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -26,7 +26,84 @@ class FuncOp;
namespace linalg {
namespace comprehensive_bufferize {
-class BufferizationAliasInfo;
+// TODO: from some HW description.
+static constexpr int64_t kBufferAlignments = 128;
+
+struct BufferizationState;
+
+/// Callback functions that are used to allocate/deallocate/copy memory buffers.
+/// Comprehensive Bufferize provides default implementations of these functions.
+// TODO: Could be replaced with a "bufferization strategy" object with virtual
+// functions in the future.
+struct AllocationCallbacks {
+ using AllocationFn = std::function<Optional<Value>(
+ OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
+ using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
+ using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
+
+ AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
+ MemCpyFn copyFn)
+ : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
+
+ /// A function that allocates memory.
+ AllocationFn allocationFn;
+
+ /// A function that deallocates memory. Must be allocated by `allocationFn`.
+ DeallocationFn deallocationFn;
+
+ /// A function that copies memory between two allocations.
+ MemCpyFn memCpyFn;
+};
+
+/// Return default allocation callbacks.
+std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
+
+/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// executed after the analysis, but before bufferization. They can be used to
+/// implement custom dialect-specific optimizations.
+struct PostAnalysisStep {
+ virtual ~PostAnalysisStep() {}
+
+ /// Run the post analysis step. This function may modify the IR, but must keep
+ /// `aliasInfo` (inside `state`) consistent. Newly created operations and
+ /// operations that should be re-analyzed must be stored in `newOps`.
+ virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
+ SmallVector<Operation *> &newOps) = 0;
+};
+
+/// Options for ComprehensiveBufferize.
+struct BufferizationOptions {
+ BufferizationOptions();
+
+ // BufferizationOptions cannot be copied.
+ BufferizationOptions(const BufferizationOptions &other) = delete;
+
+ /// Register a "post analysis" step. Such steps are executed after the
+ /// analysis, but before bufferization.
+ template <typename Step, typename... Args>
+ void addPostAnalysisStep(Args... args) {
+ postAnalysisSteps.emplace_back(
+ std::make_unique<Step>(std::forward<Args>(args)...));
+ }
+
+ /// Helper functions for allocation, deallocation, memory copying.
+ std::unique_ptr<AllocationCallbacks> allocationFns;
+
+ /// Specifies whether returning newly allocated memrefs should be allowed.
+ /// Otherwise, a pass failure is triggered.
+ bool allowReturnMemref = false;
+
+ /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
+ /// Should be used only with `testAnalysisOnly = true`.
+ unsigned analysisFuzzerSeed = 0;
+
+ /// If set to `true`, does not modify the IR apart from adding attributes (for
+ /// checking the results of the analysis) and post analysis steps.
+ bool testAnalysisOnly = false;
+
+ /// Registered post analysis steps.
+ std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
+};
/// Specify fine-grain relationship between buffers to enable more analysis.
enum class BufferRelation {
@@ -204,32 +281,6 @@ findValueInReverseUseDefChain(Value value,
/// is returned regardless of whether it is a memory write or not.
Value findLastPrecedingWrite(Value value);
-struct BufferizationState;
-
-/// Callback functions that are used to allocate/deallocate/copy memory buffers.
-/// Comprehensive Bufferize provides default implementations of these functions.
-// TODO: Could be replaced with a "bufferization strategy" object with virtual
-// functions in the future.
-struct AllocationCallbacks {
- using AllocationFn = std::function<Optional<Value>(
- OpBuilder &, Location, MemRefType, ArrayRef<Value>)>;
- using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
- using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
-
- AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
- MemCpyFn copyFn)
- : allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
-
- /// A function that allocates memory.
- AllocationFn allocationFn;
-
- /// A function that deallocated memory. Must be allocated by `allocationFn`.
- DeallocationFn deallocationFn;
-
- /// A function that copies memory between two allocations.
- MemCpyFn memCpyFn;
-};
-
/// Dialect-specific bufferization state. Analysis/bufferization information
/// that is specific to ops from a certain dialect can be stored in derived
/// variants of this struct.
@@ -240,8 +291,8 @@ struct DialectBufferizationState {
/// BufferizationState keeps track of bufferization state and provides access to
/// the results of the analysis.
struct BufferizationState {
- BufferizationState(ModuleOp moduleOp, AllocationCallbacks &allocationFns)
- : aliasInfo(moduleOp), allocationFns(allocationFns) {}
+ BufferizationState(ModuleOp moduleOp, const BufferizationOptions &options)
+ : aliasInfo(moduleOp), options(options) {}
// BufferizationState should be passed as a reference.
BufferizationState(const BufferizationState &) = delete;
@@ -289,10 +340,6 @@ struct BufferizationState {
/// `aliasInfo` keeps track of aliasing and equivalent values.
BufferizationAliasInfo aliasInfo;
- /// `allocationFns` contains helper functions for creating alloc ops, dealloc
- /// ops and memcpy ops.
- AllocationCallbacks &allocationFns;
-
/// The mapping of tensors to buffers. May also contain mappings of non-tensor
/// values.
BlockAndValueMapping mapping;
@@ -302,6 +349,9 @@ struct BufferizationState {
/// Dialect-specific bufferization state.
DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState;
+
+ /// A reference to current bufferization options.
+ const BufferizationOptions &options;
};
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
@@ -320,19 +370,6 @@ LogicalResult bufferize(Block *block, BufferizationState &state);
/// method of `BufferizableOpInterface`.
LogicalResult bufferize(Operation *op, BufferizationState &state);
-/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
-/// executed after the analysis, but before bufferization. They can be used
-/// implement custom dialect-specific optimizations.
-struct PostAnalysisStep {
- virtual ~PostAnalysisStep() {}
-
- /// Run the post analysis step. This function may modify the IR, but must keep
- /// `aliasInfo` (inside `state`) consistent. Newly created operations and
- /// operations that should be re-analyzed must be stored in `newOps`.
- virtual LogicalResult run(FuncOp funcOp, BufferizationState &state,
- SmallVector<Operation *> &newOps) = 0;
-};
-
/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace`.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
index a1fd04dc4bd00..cd6b5268f442f 100644
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -9,54 +9,15 @@
#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_COMPREHENSIVE_BUFFERIZE_H
-#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
+#include "mlir/IR/BuiltinOps.h"
namespace mlir {
-class ModuleOp;
-
namespace linalg {
namespace comprehensive_bufferize {
-// TODO: from some HW description.
-static constexpr int64_t kBufferAlignments = 128;
-
-/// Return default allocation callbacks.
-std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
-
-/// Options for ComprehensiveBufferize.
-struct BufferizationOptions {
- BufferizationOptions();
-
- // BufferizationOptions cannot be copied.
- BufferizationOptions(const BufferizationOptions &other) = delete;
-
- /// Register a "post analysis" step. Such steps are executed after the
- /// analysis, but before bufferization.
- template <typename Step, typename... Args>
- void addPostAnalysisStep(Args... args) {
- postAnalysisSteps.emplace_back(
- std::make_unique<Step>(std::forward<Args>(args)...));
- }
-
- /// Helper functions for allocation, deallocation, memory copying.
- std::unique_ptr<AllocationCallbacks> allocationFns;
-
- /// Specifies whether returning newly allocated memrefs should be allowed.
- /// Otherwise, a pass failure is triggered.
- bool allowReturnMemref = false;
-
- /// Seed for the analysis fuzzer. If set to `0`, the fuzzer is deactivated.
- /// Should be used only with `testAnalysisOnly = true`.
- unsigned analysisFuzzerSeed = 0;
-
- /// If set to `true`, does not modify the IR apart from adding attributes (for
- /// checking the results of the analysis) and post analysis steps.
- bool testAnalysisOnly = false;
-
- /// Registered post analysis steps.
- std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
-};
+struct BufferizationOptions;
+struct BufferizationState;
/// Bufferize the given function. Does not bufferize the function boundary.
// TODO: This function is meant to be called from ModuleBufferize and not can
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
index e0c5a1020447e..c6ba42de66237 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp
@@ -35,6 +35,45 @@ namespace comprehensive_bufferize {
using namespace mlir;
using namespace linalg::comprehensive_bufferize;
+//===----------------------------------------------------------------------===//
+// BufferizationOptions
+//===----------------------------------------------------------------------===//
+
+/// Default allocation function that is used by the comprehensive bufferization
+/// pass. The default currently creates a ranked memref using `memref.alloc`.
+static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
+ MemRefType type,
+ ArrayRef<Value> dynShape) {
+ Value allocated = b.create<memref::AllocOp>(
+ loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
+ return allocated;
+}
+
+/// Default deallocation function that is used by the comprehensive
+/// bufferization pass. It expects to receive back the value returned from the
+/// `defaultAllocationFn`.
+static void defaultDeallocationFn(OpBuilder &b, Location loc,
+ Value allocatedBuffer) {
+ b.create<memref::DeallocOp>(loc, allocatedBuffer);
+}
+
+/// Default memory copy function that is used by the comprehensive bufferization
+/// pass. Creates a `memref.copy` op.
+static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
+ b.create<memref::CopyOp>(loc, from, to);
+}
+
+std::unique_ptr<AllocationCallbacks>
+mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
+ return std::make_unique<AllocationCallbacks>(
+ defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
+}
+
+// Default constructor for BufferizationOptions that sets all allocation
+// callbacks to their default functions.
+BufferizationOptions::BufferizationOptions()
+ : allocationFns(defaultAllocationCallbacks()) {}
+
//===----------------------------------------------------------------------===//
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
@@ -384,7 +423,8 @@ Value mlir::linalg::comprehensive_bufferize::getResultBuffer(
if (!skipCopy) {
// The copy happens right before the op that is bufferized.
b.setInsertionPoint(op);
- state.allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
+ state.options.allocationFns->memCpyFn(b, loc, operandBuffer,
+ resultBuffer);
}
return resultBuffer;
}
@@ -537,7 +577,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
MemRefType allocMemRefType =
getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
Optional<Value> allocated =
- allocationFns.allocationFn(b, loc, allocMemRefType, dynShape);
+ options.allocationFns->allocationFn(b, loc, allocMemRefType, dynShape);
// TODO: For now just assert the value is returned. Eventually need to
// error-propagate.
assert(allocated && "allocation failed");
@@ -549,7 +589,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
// 2. Create memory deallocation.
b.setInsertionPoint(allocated.getValue().getParentBlock()->getTerminator());
- allocationFns.deallocationFn(b, loc, allocated.getValue());
+ options.allocationFns->deallocationFn(b, loc, allocated.getValue());
return casted;
}
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
index d4571d3ac6702..1a3ce684993c7 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -783,39 +783,3 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
return success();
}
-
-/// Default allocation function that is used by the comprehensive bufferization
-/// pass. The default currently creates a ranked memref using `memref.alloc`.
-static Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
- MemRefType type,
- ArrayRef<Value> dynShape) {
- Value allocated = b.create<memref::AllocOp>(
- loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments));
- return allocated;
-}
-
-/// Default deallocation function that is used by the comprehensive
-/// bufferization pass. It expects to recieve back the value called from the
-/// `defaultAllocationFn`.
-static void defaultDeallocationFn(OpBuilder &b, Location loc,
- Value allocatedBuffer) {
- b.create<memref::DeallocOp>(loc, allocatedBuffer);
-}
-
-/// Default memory copy function that is used by the comprehensive bufferization
-/// pass. Creates a `memref.copy` op.
-static void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to) {
- b.create<memref::CopyOp>(loc, from, to);
-}
-
-std::unique_ptr<AllocationCallbacks>
-mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
- return std::make_unique<AllocationCallbacks>(
- defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
-}
-
-// Default constructor for BufferizationOptions that sets all allocation
-// callbacks to their default functions.
-BufferizationOptions::BufferizationOptions()
- : allocationFns(defaultAllocationCallbacks()) {}
-
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
index b55395b69e46d..577065149efdb 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -648,7 +648,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
return failure();
- BufferizationState state(moduleOp, *options.allocationFns);
+ BufferizationState state(moduleOp, options);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
index 695e448d8b170..efc292f402657 100644
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -189,8 +189,8 @@ struct ExtractSliceOpInterface
if (!inplace) {
// Do not copy if the copied data is never read.
if (isValueRead(extractSliceOp.result()))
- state.allocationFns.memCpyFn(b, extractSliceOp.getLoc(), subView,
- alloc);
+ state.options.allocationFns->memCpyFn(b, extractSliceOp.getLoc(),
+ subView, alloc);
subView = alloc;
}
@@ -464,8 +464,8 @@ struct InsertSliceOpInterface
state.aliasInfo.insertNewBufferAlias(subView, dstMemref);
// Copy tensor.
Value srcMemref = state.lookupBuffer(insertSliceOp.source());
- state.allocationFns.memCpyFn(b, insertSliceOp.getLoc(), srcMemref,
- subView);
+ state.options.allocationFns->memCpyFn(b, insertSliceOp.getLoc(),
+ srcMemref, subView);
}
state.mapBuffer(insertSliceOp.result(), dstMemref);
More information about the Mlir-commits
mailing list