[Mlir-commits] [mlir] [MLIR] Add bufferization state class to OneShotBufferization pass (PR #141019)
Michele Scuttari
llvmlistbot at llvm.org
Thu May 22 00:59:07 PDT 2025
https://github.com/mscuttari created https://github.com/llvm/llvm-project/pull/141019
Follow-up on #138143, which was reverted due to a missing update to a method signature (more specifically, the bufferization interface for `tensor::ConcatOp`) that was not caught before merging.
The original PR description is reproduced below.
This PR is a follow-up on https://github.com/llvm/llvm-project/pull/138125, and adds a bufferization state class providing information about the IR.
The information currently consists of a cached list of symbol tables, which aims to solve the quadratic scaling of the bufferization task with respect to the number of symbols.
The PR breaks API compatibility: the bufferize method of the BufferizableOpInterface now takes an additional parameter, a reference to a BufferizationState object.
The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the Symbol trait is inserted or replaced, its parent SymbolTable must be updated accordingly (see, for example, the bufferization of arith::ConstantOp, where the symbol table of the module gets the new global symbol inserted). Similarly, the invalidation of a symbol table must be performed if an operation with the SymbolTable trait is removed (this can be performed using the invalidateSymbolTable method, introduced in https://github.com/llvm/llvm-project/pull/138014).
>From b3f9243b0a09ae78ecefbb959ebbdc9cd6410f5b Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Thu, 22 May 2025 09:50:22 +0200
Subject: [PATCH] [MLIR] Add bufferization state class to OneShotBufferization
pass
---
.../IR/BufferizableOpInterface.h | 14 +++++
.../IR/BufferizableOpInterface.td | 3 +-
.../Bufferization/IR/BufferizationOps.td | 15 ++++--
.../Bufferization/Transforms/BufferUtils.h | 6 +++
.../Bufferization/Transforms/Bufferize.h | 1 +
.../Transforms/OneShotAnalysis.h | 1 +
.../Transforms/OneShotModuleBufferize.h | 4 +-
.../Dialect/Linalg/Transforms/Transforms.h | 1 +
.../BufferizableOpInterfaceImpl.cpp | 12 +++--
.../IR/BufferizableOpInterface.cpp | 4 ++
.../Bufferization/IR/BufferizationOps.cpp | 12 +++--
.../BufferizationTransformOps.cpp | 9 +++-
.../Bufferization/Transforms/BufferUtils.cpp | 23 +++++++--
.../Bufferization/Transforms/Bufferize.cpp | 10 ++--
.../FuncBufferizableOpInterfaceImpl.cpp | 9 ++--
.../Transforms/OneShotAnalysis.cpp | 9 ++--
.../Transforms/OneShotModuleBufferize.cpp | 12 ++---
.../BufferizableOpInterfaceImpl.cpp | 3 +-
.../BufferizableOpInterfaceImpl.cpp | 7 ++-
.../Transforms/ConvertToDestinationStyle.cpp | 25 ++++++---
.../BufferizableOpInterfaceImpl.cpp | 15 ++++--
.../BufferizableOpInterfaceImpl.cpp | 27 ++++++----
.../BufferizableOpInterfaceImpl.cpp | 6 ++-
.../BufferizableOpInterfaceImpl.cpp | 3 +-
.../SparsificationAndBufferizationPass.cpp | 5 +-
.../BufferizableOpInterfaceImpl.cpp | 51 ++++++++++++-------
.../BufferizableOpInterfaceImpl.cpp | 15 ++++--
27 files changed, 215 insertions(+), 87 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cb6ef8bc17220..43c97d57e1834 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,6 +578,20 @@ class AnalysisState {
insideMutuallyExclusiveRegionsCache;
};
+/// BufferizationState provides information about the state of the IR during the
+/// bufferization process.
+class BufferizationState {
+public:
+ /// Get a reference to the collection of cached symbol tables.
+ SymbolTableCollection &getSymbolTables();
+
+private:
+ /// The cached symbol tables.
+ /// The user is expected to update / invalidate the cached symbol tables if
+ /// the bufferized operation has the Symbol or SymbolTable traits.
+ SymbolTableCollection symbolTables;
+};
+
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..b599a9f053215 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::BufferizationOptions &":$options),
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 7a1a701bea6dc..dafa4b9b183f2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool resultBufferizesToMemoryWrite(OpResult opResult,
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
bool bufferizesToMemoryRead(OpOperand &opOperand,
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
}
@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
//===------------------------------------------------------------------===//
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
// to_tensor/to_buffer pairs fold away after bufferization.
return success();
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index e5f3b6d571f43..c08bd6c436133 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,6 +29,7 @@ class GlobalOp;
} // namespace memref
namespace bufferization {
+class BufferizationState;
/// A simple analysis that detects allocation operations.
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
uint64_t alignment,
Attribute memorySpace = {});
+void removeSymbol(Operation *op, BufferizationState &state);
+
+void insertSymbol(Operation *op, BufferizationState &state);
+
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index d5cb8d8eb673c..70e3defee0867 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
/// additional buffer copies or set "options.copyBeforeWrite = true". The
/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 673027f76190d..15189d2c1cb87 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 4e5f5e9c730fa..2cf801dd1d951 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,6 +20,7 @@ namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
+class BufferizationState;
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics = nullptr);
+ BufferizationState &state, BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 4f90fc8831bc6..2eef0a06d0eb4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,6 +30,7 @@ namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
+class BufferizationState;
} // namespace bufferization
namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..f646326ffc58f 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,7 +24,8 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
@@ -46,7 +47,8 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
+ getGlobalFor(constantOp, state.getSymbolTables(),
+ options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
@@ -131,7 +134,8 @@ struct SelectOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1fc34051680f1..14fa4c1ed8159 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,6 +125,10 @@ void AnalysisState::resetCache() {
insideMutuallyExclusiveRegionsCache.clear();
}
+SymbolTableCollection &BufferizationState::getSymbolTables() {
+ return symbolTables;
+}
+
Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ecd2ef15546a4..91eccb0ab7430 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();
@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
if (failed(buffer))
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToBufferToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index a1d7bb995fc73..db1eb20512033 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,6 +83,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
}
auto payloadOps = state.getPayloadOps(getTarget());
+ BufferizationState bufferizationState;
+
for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
@@ -90,10 +92,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
+ if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
- if (failed(bufferization::runOneShotBufferize(target, options)))
+ if (failed(bufferization::runOneShotBufferize(target, options,
+ bufferizationState)))
return emitSilenceableError() << "bufferization failed";
}
}
@@ -162,6 +166,7 @@ class BufferizationTransformDialectExtension
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
+
>();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index c2e90764b1335..ff2c83d228dbb 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
//===----------------------------------------------------------------------===//
FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
- Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp,
+ SymbolTableCollection &symbolTables,
+ uint64_t alignment, Attribute memorySpace) {
auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(moduleOp.getContext());
- SymbolTable symbolTable(moduleOp);
+ SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
// Create a pretty name.
SmallString<64> buf;
@@ -158,3 +159,19 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
global->moveBefore(&moduleOp.front());
return global;
}
+
+namespace mlir::bufferization {
+void removeSymbol(Operation *op, BufferizationState &state) {
+ SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
+ op->getParentWithTrait<OpTrait::SymbolTable>());
+
+ symbolTable.remove(op);
+}
+
+void insertSymbol(Operation *op, BufferizationState &state) {
+ SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
+ op->getParentWithTrait<OpTrait::SymbolTable>());
+
+ symbolTable.insert(op);
+}
+} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 824b505517119..67f373d912dd4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,10 +161,12 @@ struct OneShotBufferizePass
return signalPassFailure();
}
+ BufferizationState state;
BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
- if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
+ if (failed(
+ runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
signalPassFailure();
return;
}
@@ -175,7 +177,7 @@ struct OneShotBufferizePass
"'bufferize-function-boundaries'");
return signalPassFailure();
}
- if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
+ if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
signalPassFailure();
return;
}
@@ -275,6 +277,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics) {
if (options.copyBeforeWrite) {
AnalysisState state(options);
@@ -331,7 +334,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
<< "//===-------------------------------------------===//\n"
<< "IR after bufferizing: " << nextOp->getName() << "\n");
rewriter.setInsertionPoint(nextOp);
- if (failed(bufferizableOp.bufferize(rewriter, options))) {
+ if (failed(
+ bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
LLVM_DEBUG(llvm::dbgs()
<< "failed to bufferize\n"
<< "//===-------------------------------------------===//\n");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 755477713668e..080796208bfc1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -239,7 +239,8 @@ struct CallOpInterface
/// All function arguments are writable. It is the responsibility of the
/// CallOp to insert buffer copies where necessary.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
func::CallOp callOp = cast<func::CallOp>(op);
// 1. Compute the result types of the new CallOp.
@@ -349,7 +350,8 @@ struct ReturnOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
#ifndef NDEBUG
auto returnOp = cast<func::ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -418,7 +420,8 @@ struct FuncOpInterface
/// All function bbArgs are writable unless they are explicitly marked as
/// read-only. Callers must insert copies when needed.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index 6e93b36d2d5a2..de820e9c8f8af 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1365,10 +1365,9 @@ LogicalResult bufferization::analyzeOp(Operation *op,
return success(!failedAnalysis);
}
-LogicalResult
-bufferization::runOneShotBufferize(Operation *op,
- const OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics) {
+LogicalResult bufferization::runOneShotBufferize(
+ Operation *op, const OneShotBufferizationOptions &options,
+ BufferizationState &state, BufferizationStatistics *statistics) {
// copy-before-write deactivates the analysis. It cannot be used together with
// test-analysis-only.
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -1391,5 +1390,5 @@ bufferization::runOneShotBufferize(Operation *op,
// Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
// a new buffer copy is allocated every time a buffer is written to.
- return bufferizeOp(op, options, statistics);
+ return bufferizeOp(op, options, state, statistics);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index a025da8635135..90ceea4d69680 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -512,7 +512,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
LogicalResult mlir::bufferization::bufferizeModuleOp(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics) {
+ BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
@@ -548,10 +548,10 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
updatedOptions.copyBeforeWrite = true;
- if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
+ if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
return failure();
} else {
- if (failed(bufferizeOp(funcOp, options, statistics)))
+ if (failed(bufferizeOp(funcOp, options, state, statistics)))
return failure();
}
@@ -565,7 +565,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Functions were already bufferized.
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
- if (failed(bufferizeOp(&op, options, statistics)))
+ if (failed(bufferizeOp(&op, options, state, statistics)))
return failure();
}
@@ -577,7 +577,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationStatistics *statistics) {
+ BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -606,7 +606,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
}
if (options.testAnalysisOnly)
return success();
- if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
+ if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
index 72f4a1a4f4c66..6a1546fb48683 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -43,7 +43,8 @@ struct BranchLikeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
// The operands of this op are bufferized together with the block signature.
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index be158af09d398..b6a498a57c036 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -148,7 +148,8 @@ struct LinalgOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
return bufferizeDestinationStyleOpInterface(
rewriter, cast<DestinationStyleOpInterface>(op), options);
}
@@ -174,7 +175,8 @@ struct SoftmaxOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto softmaxOp = cast<linalg::SoftmaxOp>(op);
FailureOr<Value> inputBuffer =
getBuffer(rewriter, softmaxOp.getInput(), options);
@@ -202,6 +204,7 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+
>::registerOpInterface(ctx);
SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index a62510deefc4a..94a4b9011c16b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -263,7 +263,11 @@ Value linalg::bufferizeToAllocation(
assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
"expected single masked op");
OpBuilder::InsertionGuard g(rewriter);
+
+ // Should the bufferization options and state be function arguments?
bufferization::BufferizationOptions bufferizationOptions;
+ bufferization::BufferizationState bufferizationState;
+
Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
@@ -279,7 +283,7 @@ Value linalg::bufferizeToAllocation(
// Bufferize terminator.
rewriter.setInsertionPoint(yieldOp);
if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
- rewriter, bufferizationOptions)))
+ rewriter, bufferizationOptions, bufferizationState)))
return nullptr;
// Erase dead to_tensor ops inside of the mask op. This is necessary because
@@ -300,8 +304,9 @@ Value linalg::bufferizeToAllocation(
for (OpOperand &use : result.getUses())
resultUses.push_back(&use);
rewriter.setInsertionPoint(maskOp);
- if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
- .bufferize(rewriter, bufferizationOptions)))
+ if (failed(
+ cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
+ .bufferize(rewriter, bufferizationOptions, bufferizationState)))
return nullptr;
// Set "restrict" attribute, indicating that no other tensor aliases with
@@ -484,8 +489,11 @@ Value linalg::bufferizeToAllocation(
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
if (!bufferizableOp)
return nullptr;
+
+ // Should the bufferization options and states be function arguments?
BufferizationOptions bufferizationOptions;
- AnalysisState state(bufferizationOptions);
+ AnalysisState analysisState(bufferizationOptions);
+ BufferizationState bufferizationState;
#ifndef NDEBUG
if (!options.bufferizeDestinationOnly) {
@@ -527,7 +535,7 @@ Value linalg::bufferizeToAllocation(
};
for (OpResult result : tensorResults) {
AliasingOpOperandList aliasingOperands =
- state.getAliasingOpOperands(result);
+ analysisState.getAliasingOpOperands(result);
for (const AliasingOpOperand &operand : aliasingOperands) {
addOutOfPlaceOperand(operand.opOperand);
for (OpOperand &resultUse : result.getUses())
@@ -535,7 +543,7 @@ Value linalg::bufferizeToAllocation(
}
}
for (OpOperand &operand : op->getOpOperands()) {
- if (!state.bufferizesToMemoryWrite(operand))
+ if (!analysisState.bufferizesToMemoryWrite(operand))
continue;
if (!isa<RankedTensorType>(operand.get().getType()))
continue;
@@ -553,7 +561,7 @@ Value linalg::bufferizeToAllocation(
Value alloc = createAllocationForTensor(
rewriter, op->getLoc(), operand->get(), options, memorySpace);
allocs.push_back(alloc);
- if (!state.findDefinitions(operand).empty()) {
+ if (!analysisState.findDefinitions(operand).empty()) {
// Initialize buffer with a copy of the operand data. Not needed if the
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
@@ -575,7 +583,8 @@ Value linalg::bufferizeToAllocation(
// Bufferize the op.
rewriter.setInsertionPoint(op);
- if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
+ if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,
+ bufferizationState)))
return nullptr;
// Set "restrict" attribute, indicating that no other tensor aliases with
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index 926d580ac7852..a69bc9e5088ae 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -52,15 +53,18 @@ struct GlobalOpInterface
bool hasTensorSemantics(Operation *) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &) const {
+ const BufferizationOptions &,
+ BufferizationState &state) const {
auto globalOp = cast<GlobalOp>(op);
if (!globalOp.getValue().has_value())
return globalOp.emitError("global op must have a value");
+ bufferization::removeSymbol(globalOp, state);
+
auto tensorType = cast<TensorType>(globalOp.getType());
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
- replaceOpWithNewBufferizedOp<memref::GlobalOp>(
+ auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
rewriter, globalOp, globalOp.getSymName(),
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
/*type=*/cast<MemRefType>(memrefType),
@@ -68,6 +72,7 @@ struct GlobalOpInterface
/*constant=*/!globalOp.getIsMutable(),
/*alignment=*/nullptr);
+ bufferization::insertSymbol(replacement, state);
return success();
}
};
@@ -91,7 +96,8 @@ struct GlobalLoadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &) const {
+ const BufferizationOptions &,
+ BufferizationState &state) const {
auto globalLoadOp = cast<GlobalLoadOp>(op);
auto tensorType = cast<TensorType>(globalLoadOp.getType());
@@ -121,7 +127,8 @@ struct GlobalStoreOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto globalStoreOp = cast<GlobalStoreOp>(op);
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index d6a9d8f6401f1..3ff1f5c49aece 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -95,7 +95,8 @@ struct ConditionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto conditionOp = cast<scf::ConditionOp>(op);
auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
@@ -181,7 +182,8 @@ struct ExecuteRegionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
auto yieldOp = getUniqueYieldOp(executeRegionOp);
TypeRange newResultTypes(yieldOp.getResults());
@@ -237,7 +239,8 @@ struct IfOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto ifOp = cast<scf::IfOp>(op);
@@ -347,7 +350,8 @@ struct IndexSwitchOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto switchOp = cast<scf::IndexSwitchOp>(op);
@@ -722,7 +726,8 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto forOp = cast<scf::ForOp>(op);
Block *oldLoopBody = forOp.getBody();
@@ -939,7 +944,8 @@ struct WhileOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto whileOp = cast<scf::WhileOp>(op);
// Indices of all bbArgs that have tensor type. These are the ones that
@@ -1144,7 +1150,8 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
scf::WhileOp>(yieldOp->getParentOp()))
@@ -1220,7 +1227,8 @@ struct ForallOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
OpBuilder::InsertionGuard guard(rewriter);
auto forallOp = cast<ForallOp>(op);
int64_t rank = forallOp.getRank();
@@ -1327,7 +1335,8 @@ struct InParallelOpInterface
: public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
InParallelOp> {
LogicalResult bufferize(Operation *op, RewriterBase &b,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
llvm_unreachable("op does not have any tensor OpOperands / OpResults");
return failure();
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 6c3b23937f98f..e8cab76d3c753 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -47,7 +47,8 @@ struct AssumingOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto assumingOp = cast<shape::AssumingOp>(op);
assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"only 1 block supported");
@@ -112,7 +113,8 @@ struct AssumingYieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto yieldOp = cast<shape::AssumingYieldOp>(op);
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 7734d1d258453..f952b68ba7e67 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,7 +30,8 @@ template <typename ConcreteModel, typename ConcreteOp>
struct SparseBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
return op->emitError(
"sparse_tensor ops must be bufferized with the sparsifier");
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 6e882a8d0ff30..7c7c64f2aef01 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -114,8 +114,11 @@ class SparsificationAndBufferizationPass
return false;
});
+ bufferization::BufferizationState bufferizationState;
+
if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions)))
+ updatedOptions,
+ bufferizationState)))
return failure();
bufferization::removeBufferizationAttributesInModule(getOperation());
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index b6843e560a899..630e970cd4b19 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -83,7 +83,8 @@ struct CastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto castOp = cast<tensor::CastOp>(op);
// The result buffer still has the old (pre-cast) type.
@@ -162,7 +163,8 @@ struct CollapseShapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
FailureOr<Value> maybeBuffer =
@@ -247,7 +249,8 @@ struct DimOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
if (failed(v))
@@ -271,7 +274,8 @@ struct EmptyOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto emptyOp = cast<tensor::EmptyOp>(op);
// Optimization: Fold away the op if it has no uses.
@@ -329,7 +333,8 @@ struct ExpandShapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
FailureOr<Value> buffer =
@@ -367,7 +372,8 @@ struct ExtractSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
@@ -432,7 +438,8 @@ struct ExtractOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
FailureOr<Value> srcMemref =
getBuffer(rewriter, extractOp.getTensor(), options);
@@ -474,7 +481,8 @@ struct FromElementsOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
@@ -586,7 +594,8 @@ struct GenerateOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto generateOp = cast<tensor::GenerateOp>(op);
auto type = generateOp.getResult().getType();
@@ -620,7 +629,8 @@ struct InsertOpInterface
: public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
tensor::InsertOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
getBuffer(rewriter, insertOp.getDest(), options);
@@ -670,7 +680,8 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
@@ -752,7 +763,8 @@ struct PadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto padOp = cast<tensor::PadOp>(op);
Location loc = padOp.getLoc();
RankedTensorType resultType = padOp.getResultType();
@@ -831,7 +843,8 @@ struct RankOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto rankOp = cast<tensor::RankOp>(op);
FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
if (failed(v))
@@ -868,7 +881,8 @@ struct ReshapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, reshapeOp.getSource(), options);
@@ -940,7 +954,8 @@ struct ParallelInsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
@@ -1015,7 +1030,8 @@ struct SplatOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto splatOp = cast<tensor::SplatOp>(op);
@@ -1073,7 +1089,8 @@ struct ConcatOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
OpBuilder::InsertionGuard g(rewriter);
auto concatOp = cast<tensor::ConcatOp>(op);
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index b2272c5fda876..45b6e7c512947 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,7 +48,8 @@ struct TransferReadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto readOp = cast<vector::TransferReadOp>(op);
assert(isa<TensorType>(readOp.getShapedType()) &&
"only tensor types expected");
@@ -103,7 +104,8 @@ struct TransferWriteOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
assert(isa<TensorType>(writeOp.getShapedType()) &&
"only tensor types expected");
@@ -148,7 +150,8 @@ struct GatherOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto gatherOp = cast<vector::GatherOp>(op);
assert(isa<TensorType>(gatherOp.getBaseType()) &&
"only tensor types expected");
@@ -202,7 +205,8 @@ struct MaskOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto maskOp = cast<vector::MaskOp>(op);
// Do not bufferize if the masked op is not bufferizable.
@@ -279,7 +283,8 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options) const {
+ const BufferizationOptions &options,
+ BufferizationState &state) const {
auto yieldOp = cast<vector::YieldOp>(op);
// Only supported as a vector.mask terminator.
More information about the Mlir-commits
mailing list