[llvm-branch-commits] [mlir] 06eb7d7 - Revert "[MLIR] Add bufferization state class to OneShotBufferization pass (#1…"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu May 22 00:12:35 PDT 2025
Author: Michele Scuttari
Date: 2025-05-22T09:12:32+02:00
New Revision: 06eb7d7fe399ab4a5fbd289afa35f12ba008685f
URL: https://github.com/llvm/llvm-project/commit/06eb7d7fe399ab4a5fbd289afa35f12ba008685f
DIFF: https://github.com/llvm/llvm-project/commit/06eb7d7fe399ab4a5fbd289afa35f12ba008685f.diff
LOG: Revert "[MLIR] Add bufferization state class to OneShotBufferization pass (#1…"
This reverts commit 67fc1660d987d145a68a3c3dfbccfbe4b91fba59.
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..cb6ef8bc17220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -578,20 +578,6 @@ class AnalysisState {
insideMutuallyExclusiveRegionsCache;
};
-/// BufferizationState provides information about the state of the IR during the
-/// bufferization process.
-class BufferizationState {
-public:
- /// Get a reference to the collection of cached symbol tables.
- SymbolTableCollection &getSymbolTables();
-
-private:
- /// The cached symbol tables.
- /// The user is expected to update / invalidate the cached symbol tables if
- /// the bufferized operation has the Symbol or SymbolTable traits.
- SymbolTableCollection symbolTables;
-};
-
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..95022d7d665d2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -426,8 +426,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::BufferizationOptions &":$options,
- "::mlir::bufferization::BufferizationState &":$state),
+ "const ::mlir::bufferization::BufferizationOptions &":$options),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..7a1a701bea6dc 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -93,8 +93,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state);
+ const BufferizationOptions &options);
bool resultBufferizesToMemoryWrite(OpResult opResult,
const AnalysisState &state);
@@ -283,8 +282,7 @@ def Bufferization_MaterializeInDestinationOp
let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state);
+ const BufferizationOptions &options);
bool bufferizesToMemoryRead(OpOperand &opOperand,
const AnalysisState &state);
@@ -377,8 +375,7 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state);
+ const BufferizationOptions &options);
}];
}
@@ -461,8 +458,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
//===------------------------------------------------------------------===//
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// to_tensor/to_buffer pairs fold away after bufferization.
return success();
}
@@ -554,8 +550,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
}
LogicalResult bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state);
+ const BufferizationOptions &options);
}];
let assemblyFormat = [{
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
index c08bd6c436133..e5f3b6d571f43 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -29,7 +29,6 @@ class GlobalOp;
} // namespace memref
namespace bufferization {
-class BufferizationState;
/// A simple analysis that detects allocation operations.
class BufferPlacementAllocs {
@@ -123,14 +122,9 @@ class BufferPlacementTransformationBase {
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
// names. Duplicates are avoided.
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
- SymbolTableCollection &symbolTables,
uint64_t alignment,
Attribute memorySpace = {});
-void removeSymbol(Operation *op, BufferizationState &state);
-
-void insertSymbol(Operation *op, BufferizationState &state);
-
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..d5cb8d8eb673c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -45,7 +45,6 @@ struct BufferizationStatistics {
/// additional buffer copies or set "options.copyBeforeWrite = true". The
/// general bufferization entry point is `runOneShotBufferize`.
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
- BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Bufferize the signature of `block` and its callers (i.e., ops that have the
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index 15189d2c1cb87..673027f76190d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -270,7 +270,6 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
LogicalResult
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
- BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 2cf801dd1d951..4e5f5e9c730fa 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -20,7 +20,6 @@ namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
-class BufferizationState;
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
@@ -39,7 +38,6 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -52,7 +50,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
const bufferization::OneShotBufferizationOptions &options,
- BufferizationState &state, BufferizationStatistics *statistics = nullptr);
+ BufferizationStatistics *statistics = nullptr);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..4f90fc8831bc6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -30,7 +30,6 @@ namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
-class BufferizationState;
} // namespace bufferization
namespace linalg {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..5e69a98db8f1e 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,8 +24,7 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
@@ -47,8 +46,7 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, state.getSymbolTables(),
- options.bufferAlignment, memorySpace);
+ getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = *globalOp;
@@ -85,8 +83,7 @@ struct IndexCastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
@@ -134,8 +131,7 @@ struct SelectOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..1fc34051680f1 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -125,10 +125,6 @@ void AnalysisState::resetCache() {
insideMutuallyExclusiveRegionsCache.clear();
}
-SymbolTableCollection &BufferizationState::getSymbolTables() {
- return symbolTables;
-}
-
Region *bufferization::getNextEnclosingRepetitiveRegion(
Region *region, const BufferizationOptions &options) {
assert(isRepetitiveRegion(region, options) && "expected repetitive region");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..ecd2ef15546a4 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -149,8 +149,7 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();
@@ -530,8 +529,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
if (failed(buffer))
return failure();
@@ -578,8 +576,7 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
@@ -864,8 +861,7 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToBufferToTensorPair(rewriter, *this, options);
// Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
index db1eb20512033..a1d7bb995fc73 100644
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -83,8 +83,6 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
}
auto payloadOps = state.getPayloadOps(getTarget());
- BufferizationState bufferizationState;
-
for (Operation *target : payloadOps) {
if (!isa<ModuleOp, FunctionOpInterface>(target))
return emitSilenceableError() << "expected module or function target";
@@ -92,12 +90,10 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
- if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
- bufferizationState)))
+ if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
return emitSilenceableError() << "bufferization failed";
} else {
- if (failed(bufferization::runOneShotBufferize(target, options,
- bufferizationState)))
+ if (failed(bufferization::runOneShotBufferize(target, options)))
return emitSilenceableError() << "bufferization failed";
}
}
@@ -166,7 +162,6 @@ class BufferizationTransformDialectExtension
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
-
>();
}
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
index ff2c83d228dbb..c2e90764b1335 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -103,9 +103,8 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
//===----------------------------------------------------------------------===//
FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp,
- SymbolTableCollection &symbolTables,
- uint64_t alignment, Attribute memorySpace) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+ Attribute memorySpace) {
auto type = cast<RankedTensorType>(constantOp.getType());
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
@@ -128,7 +127,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(moduleOp.getContext());
- SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
+ SymbolTable symbolTable(moduleOp);
// Create a pretty name.
SmallString<64> buf;
@@ -159,19 +158,3 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp,
global->moveBefore(&moduleOp.front());
return global;
}
-
-namespace mlir::bufferization {
-void removeSymbol(Operation *op, BufferizationState &state) {
- SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
- op->getParentWithTrait<OpTrait::SymbolTable>());
-
- symbolTable.remove(op);
-}
-
-void insertSymbol(Operation *op, BufferizationState &state) {
- SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
- op->getParentWithTrait<OpTrait::SymbolTable>());
-
- symbolTable.insert(op);
-}
-} // namespace mlir::bufferization
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 38de525316f7a..824b505517119 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -161,13 +161,10 @@ struct OneShotBufferizePass
return signalPassFailure();
}
- BufferizationState state;
-
BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
- if (failed(
- runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
+ if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
signalPassFailure();
return;
}
@@ -178,7 +175,7 @@ struct OneShotBufferizePass
"'bufferize-function-boundaries'");
return signalPassFailure();
}
- if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
+ if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
signalPassFailure();
return;
}
@@ -278,7 +275,6 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options,
- BufferizationState &bufferizationState,
BufferizationStatistics *statistics) {
if (options.copyBeforeWrite) {
AnalysisState state(options);
@@ -335,8 +331,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
<< "//===-------------------------------------------===//\n"
<< "IR after bufferizing: " << nextOp->getName() << "\n");
rewriter.setInsertionPoint(nextOp);
- if (failed(
- bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
+ if (failed(bufferizableOp.bufferize(rewriter, options))) {
LLVM_DEBUG(llvm::dbgs()
<< "failed to bufferize\n"
<< "//===-------------------------------------------===//\n");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 080796208bfc1..755477713668e 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -239,8 +239,7 @@ struct CallOpInterface
/// All function arguments are writable. It is the responsibility of the
/// CallOp to insert buffer copies where necessary.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
func::CallOp callOp = cast<func::CallOp>(op);
// 1. Compute the result types of the new CallOp.
@@ -350,8 +349,7 @@ struct ReturnOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
#ifndef NDEBUG
auto returnOp = cast<func::ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -420,8 +418,7 @@ struct FuncOpInterface
/// All function bbArgs are writable unless they are explicitly marked as
/// read-only. Callers must insert copies when needed.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index de820e9c8f8af..6e93b36d2d5a2 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1365,9 +1365,10 @@ LogicalResult bufferization::analyzeOp(Operation *op,
return success(!failedAnalysis);
}
-LogicalResult bufferization::runOneShotBufferize(
- Operation *op, const OneShotBufferizationOptions &options,
- BufferizationState &state, BufferizationStatistics *statistics) {
+LogicalResult
+bufferization::runOneShotBufferize(Operation *op,
+ const OneShotBufferizationOptions &options,
+ BufferizationStatistics *statistics) {
// copy-before-write deactivates the analysis. It cannot be used together with
// test-analysis-only.
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -1390,5 +1391,5 @@ LogicalResult bufferization::runOneShotBufferize(
// Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
// a new buffer copy is allocated every time a buffer is written to.
- return bufferizeOp(op, options, state, statistics);
+ return bufferizeOp(op, options, statistics);
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 90ceea4d69680..a025da8635135 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -512,7 +512,7 @@ void mlir::bufferization::removeBufferizationAttributesInModule(
LogicalResult mlir::bufferization::bufferizeModuleOp(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationState &state, BufferizationStatistics *statistics) {
+ BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
@@ -548,10 +548,10 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Buffer copies must be inserted before every write.
OneShotBufferizationOptions updatedOptions = options;
updatedOptions.copyBeforeWrite = true;
- if (failed(bufferizeOp(funcOp, updatedOptions, state, statistics)))
+ if (failed(bufferizeOp(funcOp, updatedOptions, statistics)))
return failure();
} else {
- if (failed(bufferizeOp(funcOp, options, state, statistics)))
+ if (failed(bufferizeOp(funcOp, options, statistics)))
return failure();
}
@@ -565,7 +565,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// Functions were already bufferized.
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
+ if (failed(bufferizeOp(&op, options, statistics)))
return failure();
}
@@ -577,7 +577,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
- BufferizationState &state, BufferizationStatistics *statistics) {
+ BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
@@ -606,7 +606,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
}
if (options.testAnalysisOnly)
return success();
- if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
+ if (failed(bufferizeModuleOp(moduleOp, options, statistics)))
return failure();
return success();
}
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
index 6a1546fb48683..72f4a1a4f4c66 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -43,8 +43,7 @@ struct BranchLikeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// The operands of this op are bufferized together with the block signature.
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index b6a498a57c036..be158af09d398 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -148,8 +148,7 @@ struct LinalgOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
return bufferizeDestinationStyleOpInterface(
rewriter, cast<DestinationStyleOpInterface>(op), options);
}
@@ -175,8 +174,7 @@ struct SoftmaxOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto softmaxOp = cast<linalg::SoftmaxOp>(op);
FailureOr<Value> inputBuffer =
getBuffer(rewriter, softmaxOp.getInput(), options);
@@ -204,7 +202,6 @@ void mlir::linalg::registerBufferizableOpInterfaceExternalModels(
LinalgOpInterfaceHelper<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
-
>::registerOpInterface(ctx);
SoftmaxOp::attachInterface<SoftmaxOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 94a4b9011c16b..a62510deefc4a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -263,11 +263,7 @@ Value linalg::bufferizeToAllocation(
assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
"expected single masked op");
OpBuilder::InsertionGuard g(rewriter);
-
- // Should the bufferization options and state be function arguments?
bufferization::BufferizationOptions bufferizationOptions;
- bufferization::BufferizationState bufferizationState;
-
Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");
@@ -283,7 +279,7 @@ Value linalg::bufferizeToAllocation(
// Bufferize terminator.
rewriter.setInsertionPoint(yieldOp);
if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
- rewriter, bufferizationOptions, bufferizationState)))
+ rewriter, bufferizationOptions)))
return nullptr;
// Erase dead to_tensor ops inside of the mask op. This is necessary because
@@ -304,9 +300,8 @@ Value linalg::bufferizeToAllocation(
for (OpOperand &use : result.getUses())
resultUses.push_back(&use);
rewriter.setInsertionPoint(maskOp);
- if (failed(
- cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
- .bufferize(rewriter, bufferizationOptions, bufferizationState)))
+ if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
+ .bufferize(rewriter, bufferizationOptions)))
return nullptr;
// Set "restrict" attribute, indicating that no other tensor aliases with
@@ -489,11 +484,8 @@ Value linalg::bufferizeToAllocation(
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
if (!bufferizableOp)
return nullptr;
-
- // Should the bufferization options and states be function arguments?
BufferizationOptions bufferizationOptions;
- AnalysisState analysisState(bufferizationOptions);
- BufferizationState bufferizationState;
+ AnalysisState state(bufferizationOptions);
#ifndef NDEBUG
if (!options.bufferizeDestinationOnly) {
@@ -535,7 +527,7 @@ Value linalg::bufferizeToAllocation(
};
for (OpResult result : tensorResults) {
AliasingOpOperandList aliasingOperands =
- analysisState.getAliasingOpOperands(result);
+ state.getAliasingOpOperands(result);
for (const AliasingOpOperand &operand : aliasingOperands) {
addOutOfPlaceOperand(operand.opOperand);
for (OpOperand &resultUse : result.getUses())
@@ -543,7 +535,7 @@ Value linalg::bufferizeToAllocation(
}
}
for (OpOperand &operand : op->getOpOperands()) {
- if (!analysisState.bufferizesToMemoryWrite(operand))
+ if (!state.bufferizesToMemoryWrite(operand))
continue;
if (!isa<RankedTensorType>(operand.get().getType()))
continue;
@@ -561,7 +553,7 @@ Value linalg::bufferizeToAllocation(
Value alloc = createAllocationForTensor(
rewriter, op->getLoc(), operand->get(), options, memorySpace);
allocs.push_back(alloc);
- if (!analysisState.findDefinitions(operand).empty()) {
+ if (!state.findDefinitions(operand).empty()) {
// Initialize buffer with a copy of the operand data. Not needed if the
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
@@ -583,8 +575,7 @@ Value linalg::bufferizeToAllocation(
// Bufferize the op.
rewriter.setInsertionPoint(op);
- if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,
- bufferizationState)))
+ if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
return nullptr;
// Set "restrict" attribute, indicating that no other tensor aliases with
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index a69bc9e5088ae..926d580ac7852 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -9,7 +9,6 @@
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
-#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -53,18 +52,15 @@ struct GlobalOpInterface
bool hasTensorSemantics(Operation *) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &,
- BufferizationState &state) const {
+ const BufferizationOptions &) const {
auto globalOp = cast<GlobalOp>(op);
if (!globalOp.getValue().has_value())
return globalOp.emitError("global op must have a value");
- bufferization::removeSymbol(globalOp, state);
-
auto tensorType = cast<TensorType>(globalOp.getType());
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
- auto replacement = replaceOpWithNewBufferizedOp<memref::GlobalOp>(
+ replaceOpWithNewBufferizedOp<memref::GlobalOp>(
rewriter, globalOp, globalOp.getSymName(),
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
/*type=*/cast<MemRefType>(memrefType),
@@ -72,7 +68,6 @@ struct GlobalOpInterface
/*constant=*/!globalOp.getIsMutable(),
/*alignment=*/nullptr);
- bufferization::insertSymbol(replacement, state);
return success();
}
};
@@ -96,8 +91,7 @@ struct GlobalLoadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &,
- BufferizationState &state) const {
+ const BufferizationOptions &) const {
auto globalLoadOp = cast<GlobalLoadOp>(op);
auto tensorType = cast<TensorType>(globalLoadOp.getType());
@@ -127,8 +121,7 @@ struct GlobalStoreOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto globalStoreOp = cast<GlobalStoreOp>(op);
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 3ff1f5c49aece..d6a9d8f6401f1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -95,8 +95,7 @@ struct ConditionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto conditionOp = cast<scf::ConditionOp>(op);
auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp());
@@ -182,8 +181,7 @@ struct ExecuteRegionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
auto yieldOp = getUniqueYieldOp(executeRegionOp);
TypeRange newResultTypes(yieldOp.getResults());
@@ -239,8 +237,7 @@ struct IfOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto ifOp = cast<scf::IfOp>(op);
@@ -350,8 +347,7 @@ struct IndexSwitchOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto switchOp = cast<scf::IndexSwitchOp>(op);
@@ -726,8 +722,7 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto forOp = cast<scf::ForOp>(op);
Block *oldLoopBody = forOp.getBody();
@@ -944,8 +939,7 @@ struct WhileOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto whileOp = cast<scf::WhileOp>(op);
// Indices of all bbArgs that have tensor type. These are the ones that
@@ -1150,8 +1144,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp,
scf::WhileOp>(yieldOp->getParentOp()))
@@ -1227,8 +1220,7 @@ struct ForallOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
OpBuilder::InsertionGuard guard(rewriter);
auto forallOp = cast<ForallOp>(op);
int64_t rank = forallOp.getRank();
@@ -1335,8 +1327,7 @@ struct InParallelOpInterface
: public BufferizableOpInterface::ExternalModel<InParallelOpInterface,
InParallelOp> {
LogicalResult bufferize(Operation *op, RewriterBase &b,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
llvm_unreachable("op does not have any tensor OpOperands / OpResults");
return failure();
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index e8cab76d3c753..6c3b23937f98f 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -47,8 +47,7 @@ struct AssumingOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
assert(llvm::hasSingleElement(assumingOp.getDoRegion().getBlocks()) &&
"only 1 block supported");
@@ -113,8 +112,7 @@ struct AssumingYieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto yieldOp = cast<shape::AssumingYieldOp>(op);
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index f952b68ba7e67..7734d1d258453 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -30,8 +30,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct SparseBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
return op->emitError(
"sparse_tensor ops must be bufferized with the sparsifier");
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 7c7c64f2aef01..6e882a8d0ff30 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -114,11 +114,8 @@ class SparsificationAndBufferizationPass
return false;
});
- bufferization::BufferizationState bufferizationState;
-
if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions,
- bufferizationState)))
+ updatedOptions)))
return failure();
bufferization::removeBufferizationAttributesInModule(getOperation());
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 64443d210e163..b6843e560a899 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -83,8 +83,7 @@ struct CastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto castOp = cast<tensor::CastOp>(op);
// The result buffer still has the old (pre-cast) type.
@@ -163,8 +162,7 @@ struct CollapseShapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
FailureOr<Value> maybeBuffer =
@@ -249,8 +247,7 @@ struct DimOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto dimOp = cast<tensor::DimOp>(op);
FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
if (failed(v))
@@ -274,8 +271,7 @@ struct EmptyOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto emptyOp = cast<tensor::EmptyOp>(op);
// Optimization: Fold away the op if it has no uses.
@@ -333,8 +329,7 @@ struct ExpandShapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
FailureOr<Value> buffer =
@@ -372,8 +367,7 @@ struct ExtractSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
@@ -438,8 +432,7 @@ struct ExtractOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto extractOp = cast<tensor::ExtractOp>(op);
FailureOr<Value> srcMemref =
getBuffer(rewriter, extractOp.getTensor(), options);
@@ -481,8 +474,7 @@ struct FromElementsOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
auto tensorType = cast<RankedTensorType>(fromElementsOp.getType());
@@ -594,8 +586,7 @@ struct GenerateOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto generateOp = cast<tensor::GenerateOp>(op);
auto type = generateOp.getResult().getType();
@@ -629,8 +620,7 @@ struct InsertOpInterface
: public DstBufferizableOpInterfaceExternalModel<InsertOpInterface,
tensor::InsertOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
getBuffer(rewriter, insertOp.getDest(), options);
@@ -680,8 +670,7 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
@@ -763,8 +752,7 @@ struct PadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto padOp = cast<tensor::PadOp>(op);
Location loc = padOp.getLoc();
RankedTensorType resultType = padOp.getResultType();
@@ -843,8 +831,7 @@ struct RankOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto rankOp = cast<tensor::RankOp>(op);
FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
if (failed(v))
@@ -881,8 +868,7 @@ struct ReshapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
FailureOr<Value> srcBuffer =
getBuffer(rewriter, reshapeOp.getSource(), options);
@@ -954,8 +940,7 @@ struct ParallelInsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
ParallelCombiningOpInterface parallelCombiningParent =
@@ -1030,8 +1015,7 @@ struct SplatOpInterface
bool bufferizesToAllocation(Operation *op, Value value) const { return true; }
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(rewriter);
auto splatOp = cast<tensor::SplatOp>(op);
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 45b6e7c512947..b2272c5fda876 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,8 +48,7 @@ struct TransferReadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto readOp = cast<vector::TransferReadOp>(op);
assert(isa<TensorType>(readOp.getShapedType()) &&
"only tensor types expected");
@@ -104,8 +103,7 @@ struct TransferWriteOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
assert(isa<TensorType>(writeOp.getShapedType()) &&
"only tensor types expected");
@@ -150,8 +148,7 @@ struct GatherOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto gatherOp = cast<vector::GatherOp>(op);
assert(isa<TensorType>(gatherOp.getBaseType()) &&
"only tensor types expected");
@@ -205,8 +202,7 @@ struct MaskOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto maskOp = cast<vector::MaskOp>(op);
// Do not bufferize if the masked op is not bufferizable.
@@ -283,8 +279,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- const BufferizationOptions &options,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto yieldOp = cast<vector::YieldOp>(op);
// Only supported as a vector.mask terminator.
More information about the llvm-branch-commits mailing list