[Mlir-commits] [mlir] b55d55e - [mlir][bufferize][NFC] Remove BufferizationState
Matthias Springer
llvmlistbot at llvm.org
Fri Jun 17 05:10:14 PDT 2022
Author: Matthias Springer
Date: 2022-06-17T14:04:11+02:00
New Revision: b55d55ecd9b2ce99b98bbb2595a1feb957d02a28
URL: https://github.com/llvm/llvm-project/commit/b55d55ecd9b2ce99b98bbb2595a1feb957d02a28
DIFF: https://github.com/llvm/llvm-project/commit/b55d55ecd9b2ce99b98bbb2595a1feb957d02a28.diff
LOG: [mlir][bufferize][NFC] Remove BufferizationState
With the recent refactorings, this class is no longer needed. We can use BufferizationOptions in all places where BufferizationState was used.
Differential Revision: https://reviews.llvm.org/D127653
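The refactoring described in the log message can be shown with a minimal, self-contained C++ sketch. `Options`, `State`, `alignmentBefore` and `alignmentAfter` are hypothetical stand-ins, not MLIR types. `State` plays the role of the removed `BufferizationState`, whose only data member was a reference to the options. Because the wrapper carries no state of its own, call sites can take the options directly.

```cpp
#include <cassert>

// Hypothetical stand-in for BufferizationOptions. Only one field is modeled,
// and its default value is arbitrary.
struct Options {
  unsigned bufferAlignment = 64;
};

// Before: a thin wrapper whose only data is a reference to the options,
// mirroring the removed BufferizationState.
class State {
public:
  explicit State(const Options &options) : options(options) {}
  State(const State &) = delete; // meant to be passed by reference only
  const Options &getOptions() const { return options; }

private:
  const Options &options;
};

// Old-style call site: state.getOptions().bufferAlignment
unsigned alignmentBefore(const State &state) {
  return state.getOptions().bufferAlignment;
}

// New-style call site: the options are passed directly.
unsigned alignmentAfter(const Options &options) {
  return options.bufferAlignment;
}
```

Dropping the wrapper removes one level of indirection at every call site. The arith hunk below shows the same change: `state.getOptions().bufferAlignment` becomes `options.bufferAlignment`.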
Added:
Modified:
mlir/docs/Bufferization.md
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md
index 8317f3e87fbff..7d13e9d22eab8 100644
--- a/mlir/docs/Bufferization.md
+++ b/mlir/docs/Bufferization.md
@@ -30,40 +30,38 @@ and with aggressive in-place bufferization.
One-Shot Bufferize is:
-* **Monolithic**: A single MLIR pass does the entire
-work, whereas the previous bufferization in MLIR was split across multiple
-passes residing in different dialects. In One-Shot Bufferize,
-`BufferizableOpInterface` implementations are spread across different dialects.
-
-* A **whole-function at a time analysis**. In-place bufferization decisions are
-made by analyzing SSA use-def chains on tensors. Op interface implementations
-not only provide the rewrite logic from tensor ops to memref ops, but also
-helper methods for One-Shot Bufferize's analysis to query information about an
-op's bufferization/memory semantics.
-
-* **Extensible** via an op interface: All
-ops that implement `BufferizableOpInterface` can be bufferized.
-
-* **2-Pass**:
-Bufferization is internally broken down into 2 steps: First, analyze the entire
-IR and make bufferization decisions. Then, bufferize (rewrite) the IR. The
-analysis has access to exact SSA use-def information. It incrementally builds
-alias and equivalence sets and does not rely on a posteriori-alias analysis from
-preallocated memory.
-
-* **Greedy**: Operations are analyzed one-by-one and it is
-decided on the spot whether a tensor OpOperand must be copied or not. Heuristics
-determine the order of analysis.
-
-* **Modular**: The current One-Shot Analysis
-can be replaced with a different analysis. The result of the analysis are
-queried by the bufferization via `BufferizationState`, in particular
-`BufferizationState::isInPlace`. Any derived class of `BufferizationState` that
-implements a small number virtual functions can serve as a custom analysis. It
-is even possible to run One-Shot Bufferize without any analysis
-(`AlwaysCopyBufferizationState`), in which case One-Shot Bufferize behaves
-exactly like the old dialect conversion-based bufferization (i.e., copy every
-buffer before writing to it).
+* **Monolithic**: A single MLIR pass does the entire work, whereas the
+ previous bufferization in MLIR was split across multiple passes residing in
+ different dialects. In One-Shot Bufferize, `BufferizableOpInterface`
+ implementations are spread across different dialects.
+
+* A **whole-function at a time analysis**. In-place bufferization decisions
+ are made by analyzing SSA use-def chains on tensors. Op interface
+ implementations not only provide the rewrite logic from tensor ops to memref
+ ops, but also helper methods for One-Shot Bufferize's analysis to query
+ information about an op's bufferization/memory semantics.
+
+* **Extensible** via an op interface: All ops that implement
+ `BufferizableOpInterface` can be bufferized.
+
+* **2-Pass**: Bufferization is internally broken down into 2 steps: First,
+ analyze the entire IR and make bufferization decisions. Then, bufferize
+ (rewrite) the IR. The analysis has access to exact SSA use-def information.
+ It incrementally builds alias and equivalence sets and does not rely on a
+ posteriori-alias analysis from preallocated memory.
+
+* **Greedy**: Operations are analyzed one-by-one and it is decided on the spot
+ whether a tensor OpOperand must be copied or not. Heuristics determine the
+ order of analysis.
+
+* **Modular**: The current One-Shot Analysis can be replaced with a different
+ analysis. The result of the analysis are queried by the bufferization via
+ `AnalysisState`, in particular `AnalysisState::isInPlace`. Any derived class
+ of `AnalysisState` that implements a small number virtual functions can
+ serve as a custom analysis. It is even possible to run One-Shot Bufferize
+ without any analysis (`AlwaysCopyAnalysisState`), in which case One-Shot
+ Bufferize behaves exactly like the old dialect conversion-based
+ bufferization (i.e., copy every buffer before writing to it).
To reduce complexity, One-Shot Bufferize should be
[run after other transformations](https://llvm.discourse.group/t/rfc-linalg-on-tensors-update-and-comprehensive-bufferization-rfc/3373),
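The **Modular** bullet describes a design in which the rewrite side only queries an abstract analysis object, so any derived class can supply the in-place decisions. The simplified C++ sketch below illustrates that idea. All class and function names are hypothetical. The real `AnalysisState` works on `OpOperand`s and has a larger interface than the single `isInPlace` query modeled here.

```cpp
#include <cassert>
#include <set>

// Hypothetical: an OpOperand is identified by a plain integer here.
using OperandId = int;

// Hypothetical base class: the rewrite side only asks this question.
class AnalysisStateSketch {
public:
  virtual ~AnalysisStateSketch() = default;
  virtual bool isInPlace(OperandId operand) const = 0;
};

// Mirrors the idea of AlwaysCopyAnalysisState: no analysis at all, every
// operand is treated as out-of-place.
class AlwaysCopySketch : public AnalysisStateSketch {
public:
  bool isInPlace(OperandId) const override { return false; }
};

// A custom "analysis" that simply replays recorded in-place decisions.
class RecordedDecisionsSketch : public AnalysisStateSketch {
public:
  void markInPlace(OperandId operand) { inPlace.insert(operand); }
  bool isInPlace(OperandId operand) const override {
    return inPlace.count(operand) != 0;
  }

private:
  std::set<OperandId> inPlace;
};

// Rewrite side: a written operand needs a copy unless it is in place.
bool needsCopy(const AnalysisStateSketch &state, OperandId operand) {
  return !state.isInPlace(operand);
}
```

With `AlwaysCopySketch`, every written operand needs a copy. This matches the documented behavior of running One-Shot Bufferize without any analysis.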
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 3cd9d70138d11..fa44fde98b6ed 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -236,7 +236,7 @@ struct BufferizationOptions {
///
/// Note: Deactivating this flag can lead to incorrect bufferization results
/// when used incorrectly. This flag is useful with
- /// `AlwaysCopyBufferizationState` which bufferizes all writing tensor
+ /// `AlwaysCopyAnalysisState` which bufferizes all writing tensor
/// OpOperands out-of-place.
bool enforceAliasingInvariants = true;
@@ -464,33 +464,6 @@ class AnalysisState {
const BufferizationOptions &options;
};
-/// BufferizationState provides helper functions for performing bufferization
-/// rewrites and handling memref buffers.
-struct BufferizationState {
- BufferizationState(const BufferizationOptions &options) : options(options) {}
-
- /// Lookup the buffer for the given value. If the value was not bufferized
- /// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
- /// from which the memref operand is returned.
- Value getBuffer(RewriterBase &rewriter, Value value);
-
- /// Return the buffer type for a given Value (tensor) after bufferization.
- ///
- /// Note: Op implementations should preferrably call `getBuffer()->getType()`.
- /// This function should only be used if `getBuffer` cannot be used.
- BaseMemRefType getBufferType(Value value) const;
-
- /// Return a reference to the BufferizationOptions.
- const BufferizationOptions &getOptions() const { return options; }
-
-protected:
- // BufferizationState should be passed as a reference.
- BufferizationState(const BufferizationState &) = delete;
-
-private:
- const BufferizationOptions &options;
-};
-
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
/// undefined contents is allocated.
@@ -498,6 +471,18 @@ Value allocateTensorForShapedValue(OpBuilder &b, Location loc,
Value shapedValue, bool escape,
bool copy = true);
+/// Lookup the buffer for the given value. If the value was not bufferized
+/// yet, wrap it in a ToMemrefOp. Otherwise, it is the result of a ToTensorOp,
+/// from which the memref operand is returned.
+Value getBuffer(RewriterBase &rewriter, Value value,
+ const BufferizationOptions &options);
+
+/// Return the buffer type for a given Value (tensor) after bufferization.
+///
+/// Note: Op implementations should preferrably call `getBuffer()->getType()`.
+/// This function should only be used if `getBuffer` cannot be used.
+BaseMemRefType getBufferType(Value value, const BufferizationOptions &options);
+
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
/// must be replaced with memref values.
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
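The lookup rule documented for the new free function `getBuffer` can be sketched with a toy IR. Everything here (`ToyValue`, `ToyOptions`, `ToyRewriter`, `getBufferSketch`) is hypothetical. A tensor that is the result of a to_tensor op records a pointer to that op's memref operand. Any other tensor gets a new to_memref-like wrapper value, which the toy rewriter owns. In the real function, the options determine the memref type of the inserted `ToMemrefOp`. In this sketch, the options parameter only mirrors the new signature.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy IR value: either a tensor or a memref.
struct ToyValue {
  bool isTensor = false;
  std::string name;
  // For a tensor produced by a to_tensor op: that op's memref operand.
  const ToyValue *toTensorSource = nullptr;
};

// Hypothetical options; unused in this sketch.
struct ToyOptions {};

// Owns the values created during the rewrite.
struct ToyRewriter {
  std::vector<std::unique_ptr<ToyValue>> created;
};

const ToyValue *getBufferSketch(ToyRewriter &rewriter, const ToyValue &value,
                                const ToyOptions & /*options*/) {
  assert(value.isTensor && "unexpected non-tensor type");
  // Already bufferized: return the memref operand of the to_tensor op.
  if (value.toTensorSource)
    return value.toTensorSource;
  // Not bufferized yet: wrap the tensor in a new to_memref-like value.
  auto wrapped = std::make_unique<ToyValue>();
  wrapped->name = "to_memref(" + value.name + ")";
  rewriter.created.push_back(std::move(wrapped));
  return rewriter.created.back().get();
}
```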
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index e47e8478d4ad7..e550b900cb8ac 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -221,7 +221,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
InterfaceMethod<
/*desc=*/[{
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
- Buffers of tensor SSA values can be retrieved via `state.getBuffer`.
+ Buffers of tensor SSA values can be retrieved via `getBuffer`.
Uses of tensor results of the existing tensor op can be replaced with
`replaceOpWithBufferizedValues` or `replaceOpWithNewBufferizedOp`.
These two functions automatically handle the tensor-to-memref type
@@ -233,12 +233,6 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
a) A buffer that aliases one of buffers in getAliasingOpOperand(r).
b) Or: A newly allocated buffer.
- Regions of an op should be inlined into the new op instead of cloning
- them. This is not only more efficient, but also necessary so that no
- analysis results are lost. (Bufferization decisions are tracked via
- OpOperand pointers and cloned ops have new OpOperands.) If regions are
- cloned instead of inlined, additional buffer copies may be inserted.
-
This method will never be called on ops that do not have at least one
tensor operand/result.
@@ -252,7 +246,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"LogicalResult",
/*methodName=*/"bufferize",
/*args=*/(ins "RewriterBase &":$rewriter,
- "BufferizationState &":$state),
+ "const BufferizationOptions &":$options),
/*methodBody=*/"",
/*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index a0509767cfed8..93154f48c32ec 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -71,7 +71,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
let results = (outs AnyTensor:$result);
let extraClassDeclaration = [{
- LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);
+ LogicalResult bufferize(RewriterBase &rewriter,
+ const BufferizationOptions &options);
bool isMemoryWrite(OpResult opResult, const AnalysisState &state);
@@ -242,7 +243,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// results as not writable enforces a buffer copy and has the same effect.
LogicalResult bufferize(RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// to_tensor cannot be bufferized. However, other ops that are using
// to_tensor's result will eventually be bufferized. At that point, they
// will start using to_tensor's memref operand. Once all users of
@@ -334,7 +335,7 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
}
LogicalResult bufferize(RewriterBase &rewriter,
- BufferizationState &state);
+ const BufferizationOptions &options);
}];
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 92cb346b265f8..a2b7f7f5017d7 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -25,7 +25,6 @@ namespace mlir {
namespace bufferization {
class AnalysisState;
-struct BufferizationState;
struct BufferizationOptions;
class OpFilter;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
index 63a1e07d16a22..782b2c4aeeda0 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h
@@ -15,7 +15,6 @@ struct LogicalResult;
class ModuleOp;
namespace bufferization {
-struct BufferizationState;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 7bd89762e8a41..b73c9039abcf6 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -23,7 +23,7 @@ struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
// Only ranked tensors are supported.
@@ -38,7 +38,7 @@ struct ConstantOpInterface
// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
FailureOr<memref::GlobalOp> globalOp =
- getGlobalFor(constantOp, state.getOptions().bufferAlignment);
+ getGlobalFor(constantOp, options.bufferAlignment);
if (failed(globalOp))
return failure();
memref::GlobalOp globalMemref = globalOp.getValue();
@@ -80,11 +80,11 @@ struct IndexCastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = castOp.getType().cast<TensorType>();
- Value source = state.getBuffer(rewriter, castOp.getIn());
+ Value source = getBuffer(rewriter, castOp.getIn(), options);
auto sourceType = source.getType().cast<BaseMemRefType>();
// Result type should have same layout and address space as the source type.
@@ -132,7 +132,7 @@ struct SelectOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto selectOp = cast<arith::SelectOp>(op);
Location loc = selectOp.getLoc();
@@ -140,8 +140,8 @@ struct SelectOpInterface
// instead of its OpOperands. In the worst case, 2 copies are inserted at
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
- Value trueBuffer = state.getBuffer(rewriter, selectOp.getTrueValue());
- Value falseBuffer = state.getBuffer(rewriter, selectOp.getFalseValue());
+ Value trueBuffer = getBuffer(rewriter, selectOp.getTrueValue(), options);
+ Value falseBuffer = getBuffer(rewriter, selectOp.getFalseValue(), options);
// The "true" and the "false" operands must have the same type. If the
// buffers have different types, they differ only in their layout map. Cast
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 8279d7efce65d..3e97ecdcbfff2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -477,7 +477,8 @@ static void ensureToMemrefOpIsValid(Value tensor, Type memrefType) {
#endif
}
-Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
+Value bufferization::getBuffer(RewriterBase &rewriter, Value value,
+ const BufferizationOptions &options) {
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
@@ -488,21 +489,22 @@ Value BufferizationState::getBuffer(RewriterBase &rewriter, Value value) {
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- Type memrefType = getMemRefType(tensorType, getOptions());
+ Type memrefType = getMemRefType(tensorType, options);
ensureToMemrefOpIsValid(value, memrefType);
return rewriter.create<bufferization::ToMemrefOp>(value.getLoc(), memrefType,
value);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
-BaseMemRefType BufferizationState::getBufferType(Value value) const {
+BaseMemRefType
+bufferization::getBufferType(Value value, const BufferizationOptions &options) {
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
return toTensorOp.memref().getType().cast<BaseMemRefType>();
- return getMemRefType(tensorType, getOptions());
+ return getMemRefType(tensorType, options);
}
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
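`getBufferType` distinguishes the same two cases as `getBuffer` but creates no IR. If the value comes from a to_tensor op, it reuses that op's memref type. Otherwise, it computes a type from the tensor type and the options. The toy sketch below models this. The types and the `defaultLayout` option are hypothetical and stand in for whatever `getMemRefType` derives from the real `BufferizationOptions`.

```cpp
#include <cassert>
#include <string>

// Hypothetical toy memref type: shape, element type and layout as strings.
struct ToyMemRefType {
  std::string shape;
  std::string elementType;
  std::string layout;
};

// Hypothetical toy tensor value.
struct ToyTensorValue {
  std::string shape;
  std::string elementType;
  // If produced by a to_tensor op: the type of its memref operand.
  const ToyMemRefType *toTensorSourceType = nullptr;
};

// Hypothetical option standing in for what getMemRefType reads from the
// real BufferizationOptions.
struct ToyOptions {
  std::string defaultLayout = "identity";
};

// Stand-in for getMemRefType(tensorType, options).
ToyMemRefType toyGetMemRefType(const ToyTensorValue &value,
                               const ToyOptions &options) {
  return ToyMemRefType{value.shape, value.elementType, options.defaultLayout};
}

ToyMemRefType getBufferTypeSketch(const ToyTensorValue &value,
                                  const ToyOptions &options) {
  // Already bufferized: reuse the exact type of the wrapped memref.
  if (value.toTensorSourceType)
    return *value.toTensorSourceType;
  // Otherwise compute the type from the tensor type and the options.
  return toyGetMemRefType(value, options);
}
```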
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index ad73f9da70fff..1b59a09280b20 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -150,7 +150,7 @@ void mlir::bufferization::populateDynamicDimSizes(
//===----------------------------------------------------------------------===//
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
OpBuilder::InsertionGuard g(rewriter);
Location loc = getLoc();
@@ -163,7 +163,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
// Create buffer allocation.
Value copyBuffer;
if (copy())
- copyBuffer = state.getBuffer(rewriter, copy());
+ copyBuffer = getBuffer(rewriter, copy(), options);
auto allocType =
MemRefType::get(getType().getShape(), getType().getElementType());
SmallVector<Value> dynamicDims = dynamicSizes();
@@ -172,25 +172,24 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
}
FailureOr<Value> alloc =
- state.getOptions().createAlloc(rewriter, loc, allocType, dynamicDims);
+ options.createAlloc(rewriter, loc, allocType, dynamicDims);
if (failed(alloc))
return failure();
// Create memory copy (if any).
if (copy()) {
- if (failed(
- state.getOptions().createMemCpy(rewriter, loc, copyBuffer, *alloc)))
+ if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
return failure();
}
// Should the buffer be deallocated?
- AnalysisState analysisState(state.getOptions());
+ AnalysisState analysisState(options);
bool dealloc;
if (escape().hasValue()) {
dealloc = !*escape();
} else {
// No "escape" annotation found.
- if (state.getOptions().createDeallocs) {
+ if (options.createDeallocs) {
// Perform an ad-hoc analysis.
dealloc = !analysisState.isTensorYielded(getResult());
} else {
@@ -206,7 +205,7 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
return success();
rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator());
- if (failed(state.getOptions().createDealloc(rewriter, loc, *alloc)))
+ if (failed(options.createDealloc(rewriter, loc, *alloc)))
return failure();
return success();
}
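The deallocation decision in `AllocTensorOp::bufferize` can be summarized as a small pure function. This sketch uses hypothetical names. `EscapeAnnotation` models the optional `escape` attribute, and `yielded` stands for the result of `isTensorYielded`. The hunk does not show the final `else` branch, so the sketch assumes that it skips deallocation when `createDeallocs` is off.

```cpp
#include <cassert>

// Hypothetical model of the optional `escape` attribute of alloc_tensor.
enum class EscapeAnnotation { None, Escapes, DoesNotEscape };

// Hypothetical stand-in for the createDeallocs field of the options.
struct ToyOptions {
  bool createDeallocs = true;
};

// `yielded` stands for the result of the ad-hoc isTensorYielded analysis.
bool shouldDealloc(EscapeAnnotation escape, bool yielded,
                   const ToyOptions &options) {
  // An explicit annotation takes precedence: dealloc = !escape.
  if (escape != EscapeAnnotation::None)
    return escape == EscapeAnnotation::DoesNotEscape;
  // No annotation: a yielded tensor may escape, so it is not deallocated.
  if (options.createDeallocs)
    return !yielded;
  // Assumption: with createDeallocs disabled, no dealloc is created.
  return false;
}
```

When the result is `true`, the hunk above inserts the dealloc right before the terminator of the current block.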
@@ -627,7 +626,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
(void)foldToMemrefToTensorPair(rewriter, *this);
// Note: The return value of `bufferize` indicates whether there was an error
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 8f4d2066e092e..dd096d0f7f967 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -401,7 +401,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
DenseSet<Operation *> erasedOps;
// Bufferize all ops.
- BufferizationState bufferizationState(options);
BufferizationRewriter rewriter(op->getContext(), erasedOps, toMemrefOps,
worklist, options, opFilter);
for (unsigned i = 0; i < worklist.size(); ++i) {
@@ -420,7 +419,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
continue;
// Bufferize the op.
rewriter.setInsertionPoint(op);
- if (failed(bufferizableOp.bufferize(rewriter, bufferizationState)))
+ if (failed(bufferizableOp.bufferize(rewriter, options)))
return op->emitError("failed to bufferize op");
}
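The driver loop in `bufferizeOp` iterates by index rather than by iterator. This sketch assumes one plausible reason: `worklist` is handed to `BufferizationRewriter`, so ops created during bufferization may be appended to the worklist while it is being traversed. The sketch models that situation with hypothetical `ToyOp` callbacks and is not the MLIR implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for BufferizationOptions (no fields needed here).
struct ToyOptions {};

struct ToyOp {
  std::string name;
  // Empty callback: the op is not bufferizable and is skipped.
  // Returns false on failure; may append new ops to the worklist.
  std::function<bool(std::vector<ToyOp> &, const ToyOptions &)> bufferize;
};

// Bufferizes every op in the worklist, including ops appended on the way,
// and appends the names of successfully bufferized ops to `log`.
bool bufferizeAll(std::vector<ToyOp> &worklist, const ToyOptions &options,
                  std::vector<std::string> &log) {
  // Index-based loop: push_back may reallocate, which invalidates iterators
  // and references, but the index stays valid.
  for (std::size_t i = 0; i < worklist.size(); ++i) {
    // Copy what is needed before the callback runs, because the callback
    // may grow the worklist and invalidate worklist[i].
    std::function<bool(std::vector<ToyOp> &, const ToyOptions &)> fn =
        worklist[i].bufferize;
    std::string name = worklist[i].name;
    if (!fn)
      continue;
    if (!fn(worklist, options))
      return false; // the real driver emits "failed to bufferize op"
    log.push_back(name);
  }
  return true;
}
```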
@@ -433,7 +432,7 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
/// Check the result of bufferization. Return an error if an op was not
/// bufferized, unless partial bufferization is allowed.
- if (bufferizationState.getOptions().allowUnknownOps)
+ if (options.allowUnknownOps)
return success();
for (Operation *op : worklist) {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index e75338594d5bb..6805e76ca435c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -258,7 +258,7 @@ struct CallOpInterface
/// All function arguments are writable. It is the responsibility of the
/// CallOp to insert buffer copies where necessary.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
func::CallOp callOp = cast<func::CallOp>(op);
unsigned numResults = callOp.getNumResults();
unsigned numOperands = callOp->getNumOperands();
@@ -307,7 +307,7 @@ struct CallOpInterface
// Retrieve buffers for tensor operands.
Value buffer = newOperands[idx];
if (!buffer)
- buffer = state.getBuffer(rewriter, opOperand.get());
+ buffer = getBuffer(rewriter, opOperand.get(), options);
// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = funcType.getInput(idx);
@@ -364,7 +364,7 @@ struct ReturnOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
#ifndef NDEBUG
auto returnOp = cast<func::ReturnOp>(op);
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -386,11 +386,9 @@ struct FuncOpInterface
/// All function bbArgs are writable unless they are explicitly marked as
/// read-only. Callers must insert copies when needed.
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto funcOp = cast<FuncOp>(op);
FunctionType funcType = funcOp.getFunctionType();
- const OneShotBufferizationOptions &options =
- static_cast<const OneShotBufferizationOptions &>(state.getOptions());
// Construct the bufferized function type.
SmallVector<Type> argTypes;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 773fa47692449..bb6f0532af705 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -429,7 +429,6 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
- BufferizationState bufferizationState(options);
// A list of functions in the order in which they are analyzed + bufferized.
SmallVector<func::FuncOp> orderedFuncOps;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 3ecab39cce61b..cc27b4403d898 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -20,11 +20,9 @@ using namespace mlir::bufferization;
namespace {
-// TODO: Ops in the linalg dialect can directly implement this interface.
-
/// Generic conversion for any LinalgOp on tensors.
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
@@ -46,14 +44,14 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
newInputBuffers.push_back(opOperand->get());
continue;
}
- newInputBuffers.push_back(state.getBuffer(rewriter, opOperand->get()));
+ newInputBuffers.push_back(getBuffer(rewriter, opOperand->get(), options));
}
// New output operands for the cloned op.
SmallVector<Value> newOutputBuffers;
for (OpResult opResult : op->getOpResults()) {
OpOperand *opOperand = op.getOutputOperand(opResult.getResultNumber());
- Value resultBuffer = state.getBuffer(rewriter, opOperand->get());
+ Value resultBuffer = getBuffer(rewriter, opOperand->get(), options);
newOutputBuffers.push_back(resultBuffer);
}
@@ -123,8 +121,8 @@ struct LinalgOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
- return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
+ const BufferizationOptions &options) const {
+ return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), options);
}
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 55b6518028847..4bff6b56e240c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -73,7 +73,7 @@ struct ExecuteRegionOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
// Compute new result types.
@@ -81,7 +81,7 @@ struct ExecuteRegionOpInterface
for (Type type : executeRegionOp->getResultTypes()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
// TODO: Infer the result type instead of computing it.
- newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+ newResultTypes.push_back(getMemRefType(tensorType, options));
} else {
newResultTypes.push_back(type);
}
@@ -183,7 +183,7 @@ struct IfOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto ifOp = cast<scf::IfOp>(op);
// Compute new types of the bufferized scf.if op.
@@ -191,7 +191,7 @@ struct IfOpInterface
for (Type returnType : ifOp->getResultTypes()) {
if (auto tensorType = returnType.dyn_cast<TensorType>()) {
// TODO: Infer the result type instead of computing it.
- newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+ newTypes.push_back(getMemRefType(tensorType, options));
} else {
newTypes.push_back(returnType);
}
@@ -309,11 +309,11 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
/// given OpOperands. If an operand is not a tensor, return the original value.
static SmallVector<Value> getBuffers(RewriterBase &rewriter,
MutableArrayRef<OpOperand> operands,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
if (opOperand.get().getType().isa<TensorType>()) {
- Value resultBuffer = state.getBuffer(rewriter, opOperand.get());
+ Value resultBuffer = getBuffer(rewriter, opOperand.get(), options);
result.push_back(resultBuffer);
} else {
result.push_back(opOperand.get());
@@ -325,10 +325,11 @@ static SmallVector<Value> getBuffers(RewriterBase &rewriter,
/// Helper function for loop bufferization. Compute the buffer that should be
/// yielded from a loop block (loop body or loop condition).
static Value getYieldedBuffer(RewriterBase &rewriter, Value tensor,
- BaseMemRefType type, BufferizationState &state) {
+ BaseMemRefType type,
+ const BufferizationOptions &options) {
assert(tensor.getType().isa<TensorType>() && "expected tensor");
ensureToMemrefOpIsValid(tensor, type);
- Value yieldedVal = state.getBuffer(rewriter, tensor);
+ Value yieldedVal = getBuffer(rewriter, tensor, options);
return castBuffer(rewriter, yieldedVal, type);
}
@@ -352,12 +353,12 @@ convertTensorValues(ValueRange values, const DenseSet<int64_t> &tensorIndices,
SmallVector<Value> getYieldedValues(RewriterBase &rewriter, ValueRange values,
TypeRange bufferizedTypes,
const DenseSet<int64_t> &tensorIndices,
- BufferizationState &state) {
+ const BufferizationOptions &options) {
return convertTensorValues(
values, tensorIndices, [&](Value val, int64_t index) {
return getYieldedBuffer(rewriter, val,
bufferizedTypes[index].cast<BaseMemRefType>(),
- state);
+ options);
});
}
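`getYieldedValues` delegates to `convertTensorValues`, whose body is not part of this diff. The call site suggests that it applies the callback to the values whose positions appear in `tensorIndices` and leaves the other values unchanged. The sketch below encodes that assumption. It uses a hypothetical name and plain strings in place of `Value`s.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for convertTensorValues: apply `convert` to the
// values whose position is listed in tensorIndices; keep the others as-is.
std::vector<std::string> convertTensorValuesSketch(
    const std::vector<std::string> &values,
    const std::set<std::int64_t> &tensorIndices,
    const std::function<std::string(const std::string &, std::int64_t)>
        &convert) {
  std::vector<std::string> result;
  result.reserve(values.size());
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(values.size()); ++i) {
    if (tensorIndices.count(i))
      result.push_back(convert(values[i], i)); // e.g. getYieldedBuffer
    else
      result.push_back(values[i]); // non-tensor value passes through
  }
  return result;
}
```

In the diff, the callback passed by `getYieldedValues` calls `getYieldedBuffer`, which now receives `options` instead of `state`.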
@@ -472,7 +473,7 @@ struct ForOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto forOp = cast<scf::ForOp>(op);
Block *oldLoopBody = &forOp.getLoopBody().front();
@@ -482,7 +483,7 @@ struct ForOpInterface
// The new memref init_args of the loop.
SmallVector<Value> initArgs =
- getBuffers(rewriter, forOp.getIterOpOperands(), state);
+ getBuffers(rewriter, forOp.getIterOpOperands(), options);
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = rewriter.create<scf::ForOp>(
@@ -511,7 +512,7 @@ struct ForOpInterface
auto yieldOp = cast<scf::YieldOp>(loopBody->getTerminator());
rewriter.setInsertionPoint(yieldOp);
SmallVector<Value> yieldValues = getYieldedValues(
- rewriter, yieldOp.getResults(), initArgsTypes, indices, state);
+ rewriter, yieldOp.getResults(), initArgsTypes, indices, options);
yieldOp.getResultsMutable().assign(yieldValues);
// Replace loop results.
@@ -704,7 +705,7 @@ struct WhileOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(whileOp.getBefore().getBlocks().size() == 1 &&
@@ -722,12 +723,12 @@ struct WhileOpInterface
// The new memref init_args of the loop.
SmallVector<Value> initArgs =
- getBuffers(rewriter, whileOp->getOpOperands(), state);
+ getBuffers(rewriter, whileOp->getOpOperands(), options);
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::to_vector(
llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- return state.getBufferType(bbArg).cast<Type>();
+ return getBufferType(bbArg, options).cast<Type>();
}));
// Construct a new scf.while op with memref instead of tensor values.
@@ -761,7 +762,7 @@ struct WhileOpInterface
// TODO: This could be relaxed for better bufferization results.
SmallVector<Value> newConditionArgs =
getYieldedValues(rewriter, newConditionOp.getArgs(), argsTypesAfter,
- indicesAfter, state);
+ indicesAfter, options);
newConditionOp.getArgsMutable().assign(newConditionArgs);
// Set up new iter_args and move the loop body block to the new op.
@@ -780,7 +781,7 @@ struct WhileOpInterface
// TODO: This could be relaxed for better bufferization results.
SmallVector<Value> newYieldValues =
getYieldedValues(rewriter, newYieldOp.getResults(), argsTypesBefore,
- indicesBefore, state);
+ indicesBefore, options);
newYieldOp.getResultsMutable().assign(newYieldValues);
// Replace loop results.
@@ -866,7 +867,7 @@ struct YieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto yieldOp = cast<scf::YieldOp>(op);
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
yieldOp->getParentOp()))
@@ -954,7 +955,7 @@ struct ForeachThreadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &b,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
OpBuilder::InsertionGuard g(b);
auto foreachThreadOp = cast<ForeachThreadOp>(op);
@@ -966,7 +967,7 @@ struct ForeachThreadOpInterface
// Insert copies right before the PerformConcurrentlyOp terminator. They
// should not be inside terminator (which would be the default insertion
// point).
- Value buffer = state.getBuffer(b, insertDest->get());
+ Value buffer = getBuffer(b, insertDest->get(), options);
newResults.push_back(buffer);
}
@@ -991,8 +992,7 @@ struct ForeachThreadOpInterface
performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
Location loc = insertOp.getLoc();
Type srcType = getMemRefType(
- insertOp.getSource().getType().cast<RankedTensorType>(),
- state.getOptions());
+ insertOp.getSource().getType().cast<RankedTensorType>(), options);
// ParallelInsertSliceOp bufferizes to a copy.
auto srcMemref = b.create<bufferization::ToMemrefOp>(
loc, srcType, insertOp.getSource());
@@ -1001,8 +1001,8 @@ struct ForeachThreadOpInterface
loc, destMemref, insertOp.getMixedOffsets(),
insertOp.getMixedSizes(), insertOp.getMixedStrides());
// This memcpy will fold away if everything bufferizes in-place.
- if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
- srcMemref, subview)))
+ if (failed(options.createMemCpy(b, insertOp.getLoc(), srcMemref,
+ subview)))
return WalkResult::interrupt();
b.eraseOp(insertOp);
return WalkResult::advance();
@@ -1022,7 +1022,7 @@ struct PerformConcurrentlyOpInterface
: public BufferizableOpInterface::ExternalModel<
PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
LogicalResult bufferize(Operation *op, RewriterBase &b,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
llvm_unreachable("op does not have any tensor OpOperands / OpResults");
return failure();
}
@@ -1110,7 +1110,7 @@ struct ParallelInsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &b,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// Will be bufferized as part of ForeachThreadOp.
return failure();
}
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index 1240b65d1a7e4..177d820ccefcb 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -59,7 +59,7 @@ struct AssumingOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto assumingOp = cast<shape::AssumingOp>(op);
// Compute new result types.
@@ -67,7 +67,7 @@ struct AssumingOpInterface
for (Type type : assumingOp->getResultTypes()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
// TODO: Infer the result type instead of computing it.
- newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
+ newResultTypes.push_back(getMemRefType(tensorType, options));
} else {
newResultTypes.push_back(type);
}
@@ -152,7 +152,7 @@ struct AssumingYieldOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// Op is bufferized as part of AssumingOp.
return failure();
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 430b2f6df8aa5..7695db6c59b4f 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,11 +48,11 @@ struct CastOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto castOp = cast<tensor::CastOp>(op);
// The result buffer still has the old (pre-cast) type.
- Value resultBuffer = state.getBuffer(rewriter, castOp.source());
+ Value resultBuffer = getBuffer(rewriter, castOp.source(), options);
auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
Attribute memorySpace = sourceMemRefType.getMemorySpace();
TensorType resultTensorType =
@@ -64,8 +64,8 @@ struct CastOpInterface
layout = rankedMemRefType.getLayout();
// Compute the new memref type.
- Type resultMemRefType = getMemRefType(resultTensorType, state.getOptions(),
- layout, memorySpace);
+ Type resultMemRefType =
+ getMemRefType(resultTensorType, options, layout, memorySpace);
// Replace the op with a memref.cast.
assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
@@ -105,10 +105,10 @@ struct CollapseShapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
- Value buffer = state.getBuffer(rewriter, collapseShapeOp.src());
+ Value buffer = getBuffer(rewriter, collapseShapeOp.src(), options);
auto bufferType = buffer.getType().cast<MemRefType>();
if (tensorResultType.getRank() == 0) {
@@ -146,7 +146,7 @@ struct CollapseShapeOpInterface
bufferType, collapseShapeOp.getReassociationIndices());
if (!canBeCollapsed) {
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
- AnalysisState analysisState(state.getOptions());
+ AnalysisState analysisState(options);
Value tensorAlloc = allocateTensorForShapedValue(
rewriter, op->getLoc(), collapseShapeOp.src(),
analysisState.isTensorYielded(collapseShapeOp.result()));
@@ -185,9 +185,9 @@ struct DimOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto dimOp = cast<tensor::DimOp>(op);
- auto v = state.getBuffer(rewriter, dimOp.source());
+ auto v = getBuffer(rewriter, dimOp.source(), options);
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, v, dimOp.index());
return success();
}
@@ -220,10 +220,10 @@ struct ExpandShapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
- auto buffer = state.getBuffer(rewriter, expandShapeOp.src());
+ auto buffer = getBuffer(rewriter, expandShapeOp.src(), options);
// Memref result type is inferred by the builder based on reassociation
// indices and result shape.
@@ -261,13 +261,13 @@ struct ExtractSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
// Even if this op was decided to bufferize out-of-place, do not insert the
// buffer copy yet. This is done later in this function.
- auto srcMemref = state.getBuffer(rewriter, extractSliceOp.source());
+ auto srcMemref = getBuffer(rewriter, extractSliceOp.source(), options);
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
@@ -319,9 +319,9 @@ struct ExtractOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto extractOp = cast<tensor::ExtractOp>(op);
- Value srcMemref = state.getBuffer(rewriter, extractOp.tensor());
+ Value srcMemref = getBuffer(rewriter, extractOp.tensor(), options);
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, srcMemref,
extractOp.indices());
return success();
@@ -355,7 +355,7 @@ struct FromElementsOpInterface
: public BufferizableOpInterface::ExternalModel<FromElementsOpInterface,
tensor::FromElementsOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto fromElementsOp = cast<tensor::FromElementsOp>(op);
// Allocate a buffer for the result.
@@ -363,7 +363,7 @@ struct FromElementsOpInterface
auto tensorType = fromElementsOp.getType().cast<RankedTensorType>();
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
- AnalysisState analysisState(state.getOptions());
+ AnalysisState analysisState(options);
Value tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, fromElementsOp.result(),
analysisState.isTensorYielded(fromElementsOp.result()),
@@ -410,13 +410,13 @@ struct GenerateOpInterface
: public BufferizableOpInterface::ExternalModel<GenerateOpInterface,
tensor::GenerateOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto generateOp = cast<tensor::GenerateOp>(op);
auto tensorType = generateOp.getType().cast<RankedTensorType>();
// Allocate memory.
Location loc = op->getLoc();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
- AnalysisState analysisState(state.getOptions());
+ AnalysisState analysisState(options);
Value tensorAlloc = allocateTensorForShapedValue(
rewriter, loc, generateOp.result(),
analysisState.isTensorYielded(generateOp.result()),
@@ -493,9 +493,9 @@ struct InsertOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto insertOp = cast<tensor::InsertOp>(op);
- Value destMemref = state.getBuffer(rewriter, insertOp.dest());
+ Value destMemref = getBuffer(rewriter, insertOp.dest(), options);
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.scalar(),
destMemref, insertOp.indices());
replaceOpWithBufferizedValues(rewriter, op, destMemref);
@@ -645,7 +645,7 @@ struct InsertSliceOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
// insert_slice ops arise from tiling and bufferizing them out-of-place is
// generally a deal breaker. When used with loops, this ends up cloning the
// whole tensor on every single iteration and is a symptom of a
@@ -653,7 +653,7 @@ struct InsertSliceOpInterface
// TODO: be very loud about it or even consider failing the pass.
auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
Location loc = insertSliceOp.getLoc();
- Value dstMemref = state.getBuffer(rewriter, insertSliceOp.dest());
+ Value dstMemref = getBuffer(rewriter, insertSliceOp.dest(), options);
// Expand offsets, sizes and strides to the full rank to handle the
// rank-reducing case.
@@ -681,9 +681,8 @@ struct InsertSliceOpInterface
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
- auto srcMemref = state.getBuffer(rewriter, insertSliceOp.source());
- if (failed(
- state.getOptions().createMemCpy(rewriter, loc, srcMemref, subView)))
+ auto srcMemref = getBuffer(rewriter, insertSliceOp.source(), options);
+ if (failed(options.createMemCpy(rewriter, loc, srcMemref, subView)))
return failure();
replaceOpWithBufferizedValues(rewriter, op, dstMemref);
@@ -711,9 +710,9 @@ struct RankOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto rankOp = cast<tensor::RankOp>(op);
- auto v = state.getBuffer(rewriter, rankOp.tensor());
+ auto v = getBuffer(rewriter, rankOp.tensor(), options);
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
v);
return success();
@@ -747,12 +746,12 @@ struct ReshapeOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
- Value srcBuffer = state.getBuffer(rewriter, reshapeOp.source());
- Value shapeBuffer = state.getBuffer(rewriter, reshapeOp.shape());
+ Value srcBuffer = getBuffer(rewriter, reshapeOp.source(), options);
+ Value shapeBuffer = getBuffer(rewriter, reshapeOp.shape(), options);
auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
- auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+ auto resultMemRefType = getMemRefType(resultTensorType, options);
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
rewriter, op, resultMemRefType, srcBuffer, shapeBuffer);
return success();
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index b7344ee79481d..142e09bbb3da5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -46,11 +46,11 @@ struct TransferReadOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto readOp = cast<vector::TransferReadOp>(op);
assert(readOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
- Value buffer = state.getBuffer(rewriter, readOp.getSource());
+ Value buffer = getBuffer(rewriter, readOp.getSource(), options);
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
rewriter, readOp, readOp.getVectorType(), buffer, readOp.getIndices(),
readOp.getPermutationMap(), readOp.getPadding(), readOp.getMask(),
@@ -91,13 +91,13 @@ struct TransferWriteOpInterface
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
- BufferizationState &state) const {
+ const BufferizationOptions &options) const {
auto writeOp = cast<vector::TransferWriteOp>(op);
assert(writeOp.getShapedType().isa<TensorType>() &&
"only tensor types expected");
// Create a new transfer_write on buffer that doesn't have a return value.
- Value resultBuffer = state.getBuffer(rewriter, writeOp.getSource());
+ Value resultBuffer = getBuffer(rewriter, writeOp.getSource(), options);
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), writeOp.getVector(), resultBuffer,
writeOp.getIndices(), writeOp.getPermutationMapAttr(),