[Mlir-commits] [mlir] [MLIR] Add bufferization state to `getBufferType` and `resolveConflicts` interface methods (PR #141466)
Michele Scuttari
llvmlistbot at llvm.org
Tue May 27 02:27:49 PDT 2025
https://github.com/mscuttari updated https://github.com/llvm/llvm-project/pull/141466
>From f3a97b3dfe3fd7572b860a57949e31cd0172a762 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Mon, 26 May 2025 11:22:30 +0200
Subject: [PATCH 1/3] [MLIR] Add bufferization state to `getBufferType` and
`resolveConflicts` interface methods
---
.../IR/BufferizableOpInterface.h | 10 +-
.../IR/BufferizableOpInterface.td | 11 +-
.../Bufferization/IR/BufferizationOps.td | 3 +-
.../IR/UnstructuredControlFlow.h | 5 +-
.../Bufferization/Transforms/Bufferize.h | 3 +-
.../Bufferization/Transforms/Transforms.h | 5 +-
.../BufferizableOpInterfaceImpl.cpp | 18 +--
.../IR/BufferizableOpInterface.cpp | 51 +++++----
.../Bufferization/IR/BufferizationOps.cpp | 17 +--
.../Bufferization/Transforms/Bufferize.cpp | 11 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 13 ++-
.../Transforms/OneShotAnalysis.cpp | 2 +-
.../Transforms/OneShotModuleBufferize.cpp | 5 +-
.../Transforms/TensorCopyInsertion.cpp | 21 ++--
.../BufferizableOpInterfaceImpl.cpp | 26 ++---
.../BufferizableOpInterfaceImpl.cpp | 3 +-
.../BufferizableOpInterfaceImpl.cpp | 105 ++++++++++--------
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../SparsificationAndBufferizationPass.cpp | 6 +-
.../BufferizableOpInterfaceImpl.cpp | 83 ++++++++------
.../BufferizableOpInterfaceImpl.cpp | 17 ++-
.../Bufferization/TestTensorCopyInsertion.cpp | 6 +-
22 files changed, 248 insertions(+), 175 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
- bool copy = true);
+ BufferizationState &state, bool copy = true);
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::AnalysisState &":$state),
+ "const ::mlir::bufferization::AnalysisState &":$analysisState,
+ "::mlir::bufferization::BufferizationState &":$bufferizationState),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
- rewriter, state);
+ rewriter, analysisState, bufferizationState);
}]
>,
InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
- value, options, invocationStack);
+ value, options, state, invocationStack);
}]
>,
InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// form of `bufferization.alloc_tensor` ops.
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
- const ::mlir::bufferization::AnalysisState &state);
+ const ::mlir::bufferization::AnalysisState &analysisState,
+ ::mlir::bufferization::BufferizationState &bufferizationState);
/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- SmallVector<Value> &invocationStack) {
+ BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
// the aliasing OpOperand (if any).
if (isa<OpResult>(value))
- return bufferization::detail::defaultGetBufferType(value, options,
+ return bufferization::detail::defaultGetBufferType(value, options, state,
invocationStack);
// Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
callerType = memrefType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
- bufferization::getBufferType(opOperand->get(), options,
+ bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
/// computed with `BufferizableOpIntercace::getBufferType`.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState);
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
- FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+ FailureOr<Value> source =
+ getBuffer(rewriter, castOp.getIn(), options, state);
if (failed(source))
return failure();
auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
FailureOr<Value> maybeTrueBuffer =
- getBuffer(rewriter, selectOp.getTrueValue(), options);
+ getBuffer(rewriter, selectOp.getTrueValue(), options, state);
FailureOr<Value> maybeFalseBuffer =
- getBuffer(rewriter, selectOp.getFalseValue(), options);
+ getBuffer(rewriter, selectOp.getFalseValue(), options, state);
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
return failure();
Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options);
+ bufferization::getBufferType(selectOp.getResult(), options, state);
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
- options, invocationStack);
- auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
- options, invocationStack);
+ auto trueType = bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack);
+ auto falseType = bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
- const BufferizationOptions &options, bool copy) {
+ const BufferizationOptions &options, BufferizationState &state, bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BaseMemRefType> copyBufferType =
+ getBufferType(tensor, options, state);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
- RewriterBase &rewriter, const AnalysisState &state) {
+ RewriterBase &rewriter, const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))
continue;
- if (state.isInPlace(opOperand))
+ if (analysisState.isInPlace(opOperand))
continue;
if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");
- AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+ AliasingValueList aliasingValues =
+ analysisState.getAliasingValues(opOperand);
if (aliasingValues.getNumAliases() == 1 &&
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
- !state.bufferizesToMemoryWrite(opOperand) &&
- state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+ !analysisState.bufferizesToMemoryWrite(opOperand) &&
+ analysisState
+ .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
!isa<UnrankedTensorType>(
aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// cannot be copied at the moment).
Value value = aliasingValues.getAliases()[0].value;
outOfPlaceValues.push_back(value);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpValues.insert(value);
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpOperands.insert(&opOperand);
}
}
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
- copiedOpOperands.contains(opOperand));
+ rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+ bufferizationState, copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPointAfter(op);
for (Value value : outOfPlaceValues) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), value, state.getOptions(),
- copiedOpValues.count(value));
+ rewriter, op->getLoc(), value, analysisState.getOptions(),
+ bufferizationState, copiedOpValues.count(value));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
}
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+ FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
if (failed(memrefType))
return failure();
ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state) {
SmallVector<Value> invocationStack;
- return getBufferType(value, options, invocationStack);
+ return getBufferType(value, options, state, invocationStack);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
- return bufferizableOp.getBufferType(value, options, invocationStack);
+ return bufferizableOp.getBufferType(value, options, state, invocationStack);
// Op is not bufferizable.
auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
+ BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
@@ -954,14 +962,15 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
auto opResult = llvm::cast<OpResult>(value);
- AnalysisState state(options);
- AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
+ AnalysisState analysisState(options);
+ AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() > 0 &&
aliases.getAliases()[0].relation == BufferRelation::Equivalent) {
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return getBufferType(equivalentOperand, options, invocationStack);
+ return getBufferType(equivalentOperand, options, bufferizationState,
+ invocationStack);
}
// If we do not know the memory space and there is no default memory space,
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 91eccb0ab7430..41b86437e11cf 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -163,14 +163,15 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
// Get "copy" buffer.
Value copyBuffer;
if (getCopy()) {
- FailureOr<Value> maybeCopyBuffer = getBuffer(rewriter, getCopy(), options);
+ FailureOr<Value> maybeCopyBuffer =
+ getBuffer(rewriter, getCopy(), options, state);
if (failed(maybeCopyBuffer))
return failure();
copyBuffer = *maybeCopyBuffer;
}
// Create memory allocation.
- auto allocType = bufferization::getBufferType(getResult(), options);
+ auto allocType = bufferization::getBufferType(getResult(), options, state);
if (failed(allocType))
return failure();
SmallVector<Value> dynamicDims = getDynamicSizes();
@@ -223,6 +224,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(value == getResult() && "invalid value");
@@ -231,8 +233,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
if (getMemorySpace().has_value()) {
memorySpace = *getMemorySpace();
} else if (getCopy()) {
- auto copyBufferType =
- bufferization::getBufferType(getCopy(), options, invocationStack);
+ auto copyBufferType = bufferization::getBufferType(getCopy(), options,
+ state, invocationStack);
if (failed(copyBufferType))
return failure();
memorySpace = copyBufferType->getMemorySpace();
@@ -532,7 +534,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options,
BufferizationState &state) {
- FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
+ FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
if (failed(buffer))
return failure();
rewriter.create<memref::DeallocOp>(getLoc(), *buffer);
@@ -583,7 +585,8 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
- FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, getDest(), options, state);
if (failed(maybeBuffer))
return failure();
buffer = *maybeBuffer;
@@ -591,7 +594,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
buffer = getDest();
}
- auto srcBuffer = getBuffer(rewriter, getSource(), options);
+ auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
if (failed(srcBuffer))
return failure();
if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 67f373d912dd4..c7681d309a4af 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -280,8 +280,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
BufferizationState &bufferizationState,
BufferizationStatistics *statistics) {
if (options.copyBeforeWrite) {
- AnalysisState state(options);
- if (failed(insertTensorCopies(op, state)))
+ AnalysisState analysisState(options);
+ if (failed(insertTensorCopies(op, analysisState, bufferizationState)))
return failure();
}
@@ -396,7 +396,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
LogicalResult
bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
OpBuilder::InsertionGuard g(rewriter);
auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp());
if (!bufferizableOp)
@@ -412,7 +413,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
}
FailureOr<BaseMemRefType> memrefType =
- bufferization::getBufferType(bbArg, options);
+ bufferization::getBufferType(bbArg, options, state);
if (failed(memrefType))
return failure();
newTypes.push_back(*memrefType);
@@ -463,7 +464,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
continue;
}
FailureOr<BaseMemRefType> operandBufferType =
- bufferization::getBufferType(operand, options);
+ bufferization::getBufferType(operand, options, state);
if (failed(operandBufferType))
return failure();
rewriter.setInsertionPointAfterValue(operand);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 6210f1d787bf4..9a0a85a71debd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -213,6 +213,7 @@ struct CallOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
@@ -255,7 +256,7 @@ struct CallOpInterface
// Returning a memref.
FailureOr<BaseMemRefType> resultType =
- bufferization::getBufferType(result, options);
+ bufferization::getBufferType(result, options, state);
if (failed(resultType))
return failure();
resultTypes.push_back(*resultType);
@@ -278,7 +279,7 @@ struct CallOpInterface
// Retrieve buffers for tensor operands.
FailureOr<Value> maybeBuffer =
- getBuffer(rewriter, opOperand.get(), options);
+ getBuffer(rewriter, opOperand.get(), options, state);
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
@@ -291,7 +292,8 @@ struct CallOpInterface
// result type.
FailureOr<BaseMemRefType> maybeMemRefType =
bufferization::getBufferType(
- funcOp.getArgument(opOperand.getOperandNumber()), options);
+ funcOp.getArgument(opOperand.getOperandNumber()), options,
+ state);
if (failed(maybeMemRefType))
return failure();
memRefType = *maybeMemRefType;
@@ -396,6 +398,7 @@ struct FuncOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto funcOp = cast<FuncOp>(op);
auto bbArg = cast<BlockArgument>(value);
@@ -406,7 +409,7 @@ struct FuncOpInterface
options);
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
- getBufferType(op, value, options, invocationStack);
+ getBufferType(op, value, options, state, invocationStack);
}
/// Rewrite function bbArgs and return values into buffer form. This function
@@ -459,7 +462,7 @@ struct FuncOpInterface
// 1. Bufferize every block.
for (Block &block : funcOp.getBody())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
- options)))
+ options, state)))
return failure();
// 2. Bufferize the operands of the all return op.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index de820e9c8f8af..33a922d59224b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1379,7 +1379,7 @@ LogicalResult bufferization::runOneShotBufferize(
// Run One-Shot Analysis and insert buffer copies (on the tensor level)
// only where needed. This is the default and much more efficient than
// copy-before-write.
- if (failed(insertTensorCopies(op, options, statistics)))
+ if (failed(insertTensorCopies(op, options, state, statistics)))
return failure();
// If test-analysis-only is set, the IR was annotated with RaW conflict
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 90ceea4d69680..dee2af8271ce8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -584,7 +584,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
"invalid combination of bufferization flags");
if (!options.copyBeforeWrite) {
if (options.noAnalysisFuncFilter.empty()) {
- if (failed(insertTensorCopies(moduleOp, options, statistics)))
+ if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
return failure();
} else {
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -600,7 +600,8 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
};
OneShotBufferizationOptions updatedOptions(options);
updatedOptions.opFilter.denyOperation(analysisFilterFn);
- if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics)))
+ if (failed(
+ insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
return failure();
}
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index 4326b19f3104d..d971ed5b0f71c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -28,28 +28,29 @@ using namespace mlir::bufferization;
LogicalResult mlir::bufferization::insertTensorCopies(
Operation *op, const OneShotBufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics) {
- OneShotAnalysisState state(op, options);
+ OneShotAnalysisState analysisState(op, options);
// Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
- if (failed(analyzeModuleOp(cast<ModuleOp>(op), state, statistics)))
+ if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
return failure();
} else {
- if (failed(analyzeOp(op, state, statistics)))
+ if (failed(analyzeOp(op, analysisState, statistics)))
return failure();
}
if (options.testAnalysisOnly)
return success();
- return insertTensorCopies(op, state);
+ return insertTensorCopies(op, analysisState, bufferizationState);
}
-LogicalResult
-mlir::bufferization::insertTensorCopies(Operation *op,
- const AnalysisState &state) {
+LogicalResult mlir::bufferization::insertTensorCopies(
+ Operation *op, const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) {
IRRewriter rewriter(op->getContext());
// It may be more efficient to walk in pre-order here, but the current
@@ -62,14 +63,16 @@ mlir::bufferization::insertTensorCopies(Operation *op,
nestedOp->getParentWithTrait<OpTrait::SymbolTable>() != op)
return WalkResult::skip();
- auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp);
+ auto bufferizableOp =
+ analysisState.getOptions().dynCastBufferizableOp(nestedOp);
if (!bufferizableOp)
return WalkResult::skip();
// Find inplacability conflicts and resolve them. (Typically with explicit
// tensor copies in the form of AllocTensorOps.)
rewriter.setInsertionPoint(nestedOp);
- if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
+ if (failed(bufferizableOp.resolveConflicts(rewriter, analysisState,
+ bufferizationState)))
return WalkResult::interrupt();
return WalkResult::advance();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index b6a498a57c036..ce355e96ee694 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -24,10 +24,9 @@ using namespace mlir::bufferization;
namespace {
/// Generic conversion for any DestinationStyleOpInterface on tensors.
-static LogicalResult
-bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
- DestinationStyleOpInterface op,
- const BufferizationOptions &options) {
+static LogicalResult bufferizeDestinationStyleOpInterface(
+ RewriterBase &rewriter, DestinationStyleOpInterface op,
+ const BufferizationOptions &options, BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
@@ -49,7 +48,8 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
newInputBuffers.push_back(opOperand->get());
continue;
}
- FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, opOperand->get(), options, state);
if (failed(buffer))
return failure();
newInputBuffers.push_back(*buffer);
@@ -60,7 +60,7 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
for (OpResult opResult : op->getOpResults()) {
OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
FailureOr<Value> resultBuffer =
- getBuffer(rewriter, opOperand->get(), options);
+ getBuffer(rewriter, opOperand->get(), options, state);
if (failed(resultBuffer))
return failure();
newOutputBuffers.push_back(*resultBuffer);
@@ -76,10 +76,10 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
// new op. Since the new op does not have any tensor results, it does not
// return anything.
assert(op->getNumRegions() == 1 && "expected that op has 1 region");
- OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
- op->getAttrs());
- state.addRegion();
- Operation *newOp = Operation::create(state);
+ OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{},
+ op->getAttrs());
+ opState.addRegion();
+ Operation *newOp = Operation::create(opState);
newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
op->getRegion(0).getBlocks());
@@ -151,7 +151,7 @@ struct LinalgOpInterface
const BufferizationOptions &options,
BufferizationState &state) const {
return bufferizeDestinationStyleOpInterface(
- rewriter, cast<DestinationStyleOpInterface>(op), options);
+ rewriter, cast<DestinationStyleOpInterface>(op), options, state);
}
};
@@ -179,11 +179,11 @@ struct SoftmaxOpInterface
BufferizationState &state) const {
auto softmaxOp = cast<linalg::SoftmaxOp>(op);
FailureOr<Value> inputBuffer =
- getBuffer(rewriter, softmaxOp.getInput(), options);
+ getBuffer(rewriter, softmaxOp.getInput(), options, state);
if (failed(inputBuffer))
return failure();
FailureOr<Value> outputBuffer =
- getBuffer(rewriter, softmaxOp.getOutput(), options);
+ getBuffer(rewriter, softmaxOp.getOutput(), options, state);
if (failed(outputBuffer))
return failure();
rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(),
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
index a69bc9e5088ae..ff6af63eee531 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -138,7 +138,8 @@ struct GlobalStoreOpInterface
auto targetMemref = rewriter.create<memref::GetGlobalOp>(
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
- auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
+ auto sourceMemref =
+ getBuffer(rewriter, globalStoreOp.getValue(), options, state);
if (failed(sourceMemref)) {
return failure();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 3ff1f5c49aece..59c240c62f934 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -104,11 +104,12 @@ struct ConditionOpInterface
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Value value = it.value();
if (isa<TensorType>(value.getType())) {
- FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, value, options, state);
if (failed(maybeBuffer))
return failure();
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
- whileOp.getAfterArguments()[it.index()], options);
+ whileOp.getAfterArguments()[it.index()], options, state);
if (failed(resultType))
return failure();
Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType);
@@ -196,7 +197,7 @@ struct ExecuteRegionOpInterface
// Bufferize every block.
for (Block &block : newOp.getRegion())
if (failed(bufferization::bufferizeBlockSignature(&block, rewriter,
- options)))
+ options, state)))
return failure();
// Update all uses of the old op.
@@ -251,7 +252,7 @@ struct IfOpInterface
newTypes.push_back(result.getType());
continue;
}
- auto bufferType = bufferization::getBufferType(result, options);
+ auto bufferType = bufferization::getBufferType(result, options, state);
if (failed(bufferType))
return failure();
newTypes.push_back(*bufferType);
@@ -275,6 +276,7 @@ struct IfOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto ifOp = cast<scf::IfOp>(op);
auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
@@ -290,8 +292,8 @@ struct IfOpInterface
// True branch was already bufferized.
thenBufferType = cast<BaseMemRefType>(thenValue.getType());
} else {
- auto maybeBufferType =
- bufferization::getBufferType(thenValue, options, invocationStack);
+ auto maybeBufferType = bufferization::getBufferType(
+ thenValue, options, state, invocationStack);
if (failed(maybeBufferType))
return failure();
thenBufferType = *maybeBufferType;
@@ -300,8 +302,8 @@ struct IfOpInterface
// False branch was already bufferized.
elseBufferType = cast<BaseMemRefType>(elseValue.getType());
} else {
- auto maybeBufferType =
- bufferization::getBufferType(elseValue, options, invocationStack);
+ auto maybeBufferType = bufferization::getBufferType(
+ elseValue, options, state, invocationStack);
if (failed(maybeBufferType))
return failure();
elseBufferType = *maybeBufferType;
@@ -362,7 +364,7 @@ struct IndexSwitchOpInterface
newTypes.push_back(result.getType());
continue;
}
- auto bufferType = bufferization::getBufferType(result, options);
+ auto bufferType = bufferization::getBufferType(result, options, state);
if (failed(bufferType))
return failure();
newTypes.push_back(*bufferType);
@@ -390,6 +392,7 @@ struct IndexSwitchOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto switchOp = cast<scf::IndexSwitchOp>(op);
assert(value.getDefiningOp() == op && "invalid value");
@@ -401,8 +404,8 @@ struct IndexSwitchOpInterface
Value yieldedValue = yieldOp->getOperand(resultNum);
if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType()))
return bufferType;
- auto maybeBufferType =
- bufferization::getBufferType(yieldedValue, options, invocationStack);
+ auto maybeBufferType = bufferization::getBufferType(
+ yieldedValue, options, state, invocationStack);
if (failed(maybeBufferType))
return failure();
return maybeBufferType;
@@ -468,12 +471,12 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
/// given OpOperands. If an operand is not a tensor, return the original value.
static FailureOr<SmallVector<Value>>
getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options, BufferizationState &state) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
if (isa<TensorType>(opOperand.get().getType())) {
FailureOr<Value> resultBuffer =
- getBuffer(rewriter, opOperand.get(), options);
+ getBuffer(rewriter, opOperand.get(), options, state);
if (failed(resultBuffer))
return failure();
result.push_back(*resultBuffer);
@@ -521,10 +524,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
- const BufferizationOptions &options, SmallVector<Value> &invocationStack) {
+ const BufferizationOptions &options, BufferizationState &state,
+ SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
auto initArgBufferType =
- bufferization::getBufferType(initArg, options, invocationStack);
+ bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();
@@ -550,8 +554,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType =
- bufferization::getBufferType(yieldedValue, options, invocationStack);
+ auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+ state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -650,12 +654,14 @@ struct ForOpInterface
}
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &state) const {
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
- if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+ if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
+ rewriter, analysisState, bufferizationState)))
return failure();
- if (state.getOptions().copyBeforeWrite)
+ if (analysisState.getOptions().copyBeforeWrite)
return success();
// According to the `getAliasing...` implementations, a bufferized OpResult
@@ -683,12 +689,13 @@ struct ForOpInterface
doesNotAliasExternalValue(
it.value(), &forOp.getRegion(),
/*exceptions=*/forOp.getRegionIterArg(it.index()),
- static_cast<const OneShotAnalysisState &>(state))) {
+ static_cast<const OneShotAnalysisState &>(analysisState))) {
yieldValues.push_back(it.value());
continue;
}
FailureOr<Value> alloc = allocateTensorForShapedValue(
- rewriter, yieldOp.getLoc(), it.value(), state.getOptions());
+ rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(),
+ bufferizationState);
if (failed(alloc))
return failure();
yieldValues.push_back(*alloc);
@@ -701,6 +708,7 @@ struct ForOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
@@ -709,7 +717,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- return bufferization::getBufferType(bbArg, options, invocationStack);
+ return bufferization::getBufferType(bbArg, options, state,
+ invocationStack);
}
// Compute result/argument number.
@@ -722,7 +731,7 @@ struct ForOpInterface
BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
Value initArg = forOp.getInitArgs()[resultNum];
return computeLoopRegionIterArgBufferType(
- op, iterArg, initArg, yieldedValue, options, invocationStack);
+ op, iterArg, initArg, yieldedValue, options, state, invocationStack);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -737,7 +746,7 @@ struct ForOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
- getBuffers(rewriter, forOp.getInitArgsMutable(), options);
+ getBuffers(rewriter, forOp.getInitArgsMutable(), options, state);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
@@ -752,7 +761,7 @@ struct ForOpInterface
castedInitArgs.push_back(initArg);
continue;
}
- auto targetType = bufferization::getBufferType(result, options);
+ auto targetType = bufferization::getBufferType(result, options, state);
if (failed(targetType))
return failure();
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -892,12 +901,14 @@ struct WhileOpInterface
}
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &state) const {
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
- if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+ if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
+ rewriter, analysisState, bufferizationState)))
return failure();
- if (state.getOptions().copyBeforeWrite)
+ if (analysisState.getOptions().copyBeforeWrite)
return success();
// According to the `getAliasing...` implementations, a bufferized OpResult
@@ -914,9 +925,10 @@ struct WhileOpInterface
// For every yielded value, is the value equivalent to its corresponding
// bbArg?
DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers(
- whileOp.getBeforeArguments(), conditionOp.getArgs(), state);
- DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers(
- whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state);
+ whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState);
+ DenseSet<int64_t> equivalentYieldsAfter =
+ getEquivalentBuffers(whileOp.getAfterArguments(),
+ whileOp.getYieldOp().getResults(), analysisState);
// Update "before" region.
rewriter.setInsertionPoint(conditionOp);
@@ -931,7 +943,8 @@ struct WhileOpInterface
continue;
}
FailureOr<Value> alloc = allocateTensorForShapedValue(
- rewriter, conditionOp.getLoc(), value, state.getOptions());
+ rewriter, conditionOp.getLoc(), value, analysisState.getOptions(),
+ bufferizationState);
if (failed(alloc))
return failure();
beforeYieldValues.push_back(*alloc);
@@ -956,7 +969,7 @@ struct WhileOpInterface
// The new memref init_args of the loop.
FailureOr<SmallVector<Value>> maybeInitArgs =
- getBuffers(rewriter, whileOp.getInitsMutable(), options);
+ getBuffers(rewriter, whileOp.getInitsMutable(), options, state);
if (failed(maybeInitArgs))
return failure();
SmallVector<Value> initArgs = *maybeInitArgs;
@@ -971,7 +984,7 @@ struct WhileOpInterface
castedInitArgs.push_back(initArg);
continue;
}
- auto targetType = bufferization::getBufferType(beforeArg, options);
+ auto targetType = bufferization::getBufferType(beforeArg, options, state);
if (failed(targetType))
return failure();
castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType));
@@ -984,7 +997,7 @@ struct WhileOpInterface
return bbArg.getType();
// TODO: error handling
return llvm::cast<Type>(
- *bufferization::getBufferType(bbArg, options));
+ *bufferization::getBufferType(bbArg, options, state));
}));
// Construct a new scf.while op with memref instead of tensor values.
@@ -1029,6 +1042,7 @@ struct WhileOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
@@ -1041,7 +1055,7 @@ struct WhileOpInterface
auto yieldOp = whileOp.getYieldOp();
Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber());
return computeLoopRegionIterArgBufferType(
- op, bbArg, initArg, yieldedValue, options, invocationStack);
+ op, bbArg, initArg, yieldedValue, options, state, invocationStack);
}
}
@@ -1062,7 +1076,7 @@ struct WhileOpInterface
// scf.condition was already bufferized.
return cast<BaseMemRefType>(conditionYieldedVal.getType());
}
- return bufferization::getBufferType(conditionYieldedVal, options,
+ return bufferization::getBufferType(conditionYieldedVal, options, state,
invocationStack);
}
@@ -1161,7 +1175,8 @@ struct YieldOpInterface
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
if (isa<TensorType>(value.getType())) {
- FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, value, options, state);
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
@@ -1169,14 +1184,14 @@ struct YieldOpInterface
if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(
yieldOp->getParentOp())) {
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
- yieldOp->getParentOp()->getResult(it.index()), options);
+ yieldOp->getParentOp()->getResult(it.index()), options, state);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
} else if (auto whileOp =
dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) {
FailureOr<BaseMemRefType> resultType = bufferization::getBufferType(
- whileOp.getBeforeArguments()[it.index()], options);
+ whileOp.getBeforeArguments()[it.index()], options, state);
if (failed(resultType))
return failure();
buffer = castBuffer(rewriter, buffer, *resultType);
@@ -1236,7 +1251,7 @@ struct ForallOpInterface
// Get buffers for all output operands.
SmallVector<Value> buffers;
for (Value out : forallOp.getOutputs()) {
- FailureOr<Value> buffer = getBuffer(rewriter, out, options);
+ FailureOr<Value> buffer = getBuffer(rewriter, out, options, state);
if (failed(buffer))
return failure();
buffers.push_back(*buffer);
@@ -1283,6 +1298,7 @@ struct ForallOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto forallOp = cast<ForallOp>(op);
@@ -1290,13 +1306,14 @@ struct ForallOpInterface
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
return bufferization::getBufferType(
- forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack);
+ forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+ invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- invocationStack);
+ state, invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index e8cab76d3c753..dc91117a51936 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -119,7 +119,7 @@ struct AssumingYieldOpInterface
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
if (isa<TensorType>(value.getType())) {
- FailureOr<Value> buffer = getBuffer(rewriter, value, options);
+ FailureOr<Value> buffer = getBuffer(rewriter, value, options, state);
if (failed(buffer))
return failure();
newResults.push_back(*buffer);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 7c7c64f2aef01..a3ab53d818115 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -152,8 +152,10 @@ class SparsificationAndBufferizationPass
// invalidate the results of the analysis. From now on, only small and
// localized rewrites are allowed, such as replacing a tensor op with its
// memref equivalent.
- if (failed(bufferization::insertTensorCopies(getOperation(),
- bufferizationOptions)))
+ bufferization::BufferizationState bufferizationState;
+
+ if (failed(bufferization::insertTensorCopies(
+ getOperation(), bufferizationOptions, bufferizationState)))
return signalPassFailure();
// Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 630e970cd4b19..154f12b31fc70 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,10 +51,11 @@ struct CastOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto castOp = cast<tensor::CastOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
- castOp.getSource(), options, invocationStack);
+ castOp.getSource(), options, state, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
Attribute memorySpace = maybeSrcBufferType->getMemorySpace();
@@ -89,13 +90,13 @@ struct CastOpInterface
// The result buffer still has the old (pre-cast) type.
FailureOr<Value> resultBuffer =
- getBuffer(rewriter, castOp.getSource(), options);
+ getBuffer(rewriter, castOp.getSource(), options, state);
if (failed(resultBuffer))
return failure();
// Compute the new type.
auto resultMemRefType =
- bufferization::getBufferType(castOp.getResult(), options);
+ bufferization::getBufferType(castOp.getResult(), options, state);
if (failed(resultMemRefType))
return failure();
if (resultBuffer->getType() == *resultMemRefType) {
@@ -141,10 +142,11 @@ struct CollapseShapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
- collapseShapeOp.getSrc(), options, invocationStack);
+ collapseShapeOp.getSrc(), options, state, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -168,7 +170,7 @@ struct CollapseShapeOpInterface
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
FailureOr<Value> maybeBuffer =
- getBuffer(rewriter, collapseShapeOp.getSrc(), options);
+ getBuffer(rewriter, collapseShapeOp.getSrc(), options, state);
if (failed(maybeBuffer))
return failure();
Value buffer = *maybeBuffer;
@@ -210,7 +212,7 @@ struct CollapseShapeOpInterface
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
AnalysisState analysisState(options);
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
- rewriter, op->getLoc(), collapseShapeOp.getSrc(), options);
+ rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state);
if (failed(tensorAlloc))
return failure();
auto memrefType =
@@ -252,7 +254,7 @@ struct DimOpInterface
const BufferizationOptions &options,
BufferizationState &state) const {
auto dimOp = cast<tensor::DimOp>(op);
- FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options);
+ FailureOr<Value> v = getBuffer(rewriter, dimOp.getSource(), options, state);
if (failed(v))
return failure();
replaceOpWithNewBufferizedOp<memref::DimOp>(rewriter, op, *v,
@@ -286,7 +288,8 @@ struct EmptyOpInterface
// Allocate a tensor. This emits a "bufferization.alloc_tensor" op.
FailureOr<Value> allocTensor = allocateTensorForShapedValue(
- rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false);
+ rewriter, op->getLoc(), emptyOp.getResult(), options, state,
+ /*copy=*/false);
if (failed(allocTensor))
return failure();
rewriter.replaceOp(op, *allocTensor);
@@ -317,10 +320,11 @@ struct ExpandShapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
- expandShapeOp.getSrc(), options, invocationStack);
+ expandShapeOp.getSrc(), options, state, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
auto srcBufferType = llvm::cast<MemRefType>(*maybeSrcBufferType);
@@ -338,7 +342,7 @@ struct ExpandShapeOpInterface
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto tensorResultType = expandShapeOp.getResultType();
FailureOr<Value> buffer =
- getBuffer(rewriter, expandShapeOp.getSrc(), options);
+ getBuffer(rewriter, expandShapeOp.getSrc(), options, state);
if (failed(buffer))
return failure();
@@ -382,13 +386,13 @@ struct ExtractSliceOpInterface
// Get source buffer.
FailureOr<Value> srcMemref =
- getBuffer(rewriter, extractSliceOp.getSource(), options);
+ getBuffer(rewriter, extractSliceOp.getSource(), options, state);
if (failed(srcMemref))
return failure();
// Take a subview of the source buffer.
- auto resultMemrefType =
- bufferization::getBufferType(extractSliceOp.getResult(), options);
+ auto resultMemrefType = bufferization::getBufferType(
+ extractSliceOp.getResult(), options, state);
if (failed(resultMemrefType))
return failure();
Value subView = rewriter.create<memref::SubViewOp>(
@@ -401,11 +405,12 @@ struct ExtractSliceOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
assert(value == extractSliceOp.getResult() && "invalid value");
auto srcMemrefType = bufferization::getBufferType(
- extractSliceOp.getSource(), options, invocationStack);
+ extractSliceOp.getSource(), options, state, invocationStack);
if (failed(srcMemrefType))
return failure();
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
@@ -442,7 +447,7 @@ struct ExtractOpInterface
BufferizationState &state) const {
auto extractOp = cast<tensor::ExtractOp>(op);
FailureOr<Value> srcMemref =
- getBuffer(rewriter, extractOp.getTensor(), options);
+ getBuffer(rewriter, extractOp.getTensor(), options, state);
if (failed(srcMemref))
return failure();
replaceOpWithNewBufferizedOp<memref::LoadOp>(rewriter, op, *srcMemref,
@@ -491,12 +496,12 @@ struct FromElementsOpInterface
auto shape = tensorType.getShape();
// TODO: Create alloc_tensor ops during TensorCopyInsertion.
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
- rewriter, loc, fromElementsOp.getResult(), options,
+ rewriter, loc, fromElementsOp.getResult(), options, state,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
FailureOr<BaseMemRefType> memrefType =
- bufferization::getBufferType(*tensorAlloc, options);
+ bufferization::getBufferType(*tensorAlloc, options, state);
if (failed(memrefType))
return failure();
Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -607,7 +612,7 @@ struct GenerateOpInterface
// Allocate memory.
Location loc = op->getLoc();
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
- rewriter, loc, generateOp.getResult(), options,
+ rewriter, loc, generateOp.getResult(), options, state,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
@@ -633,7 +638,7 @@ struct InsertOpInterface
BufferizationState &state) const {
auto insertOp = cast<tensor::InsertOp>(op);
FailureOr<Value> destMemref =
- getBuffer(rewriter, insertOp.getDest(), options);
+ getBuffer(rewriter, insertOp.getDest(), options, state);
if (failed(destMemref))
return failure();
rewriter.create<memref::StoreOp>(insertOp.getLoc(), insertOp.getScalar(),
@@ -695,7 +700,7 @@ struct InsertSliceOpInterface
// Get destination buffer.
FailureOr<Value> dstMemref =
- getBuffer(rewriter, insertSliceOp.getDest(), options);
+ getBuffer(rewriter, insertSliceOp.getDest(), options, state);
if (failed(dstMemref))
return failure();
@@ -712,7 +717,7 @@ struct InsertSliceOpInterface
// Copy tensor. If this tensor.insert_slice has a matching
// tensor.extract_slice, the copy operation will eventually fold away.
FailureOr<Value> srcMemref =
- getBuffer(rewriter, insertSliceOp.getSource(), options);
+ getBuffer(rewriter, insertSliceOp.getSource(), options, state);
if (failed(srcMemref))
return failure();
if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView)))
@@ -749,11 +754,12 @@ struct PadOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Infer memory space from the source tensor.
auto padOp = cast<tensor::PadOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
- padOp.getSource(), options, invocationStack);
+ padOp.getSource(), options, state, invocationStack);
if (failed(maybeSrcBufferType))
return failure();
MemRefLayoutAttrInterface layout;
@@ -797,9 +803,9 @@ struct PadOpInterface
}
// Allocate a buffer for the padded result.
- FailureOr<Value> tensorAlloc =
- allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options,
- /*copy=*/false);
+ FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
+ rewriter, loc, padOp.getResult(), options, state,
+ /*copy=*/false);
if (failed(tensorAlloc))
return failure();
@@ -846,7 +852,8 @@ struct RankOpInterface
const BufferizationOptions &options,
BufferizationState &state) const {
auto rankOp = cast<tensor::RankOp>(op);
- FailureOr<Value> v = getBuffer(rewriter, rankOp.getTensor(), options);
+ FailureOr<Value> v =
+ getBuffer(rewriter, rankOp.getTensor(), options, state);
if (failed(v))
return failure();
replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op, rankOp.getType(),
@@ -885,13 +892,13 @@ struct ReshapeOpInterface
BufferizationState &state) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
FailureOr<Value> srcBuffer =
- getBuffer(rewriter, reshapeOp.getSource(), options);
+ getBuffer(rewriter, reshapeOp.getSource(), options, state);
FailureOr<Value> shapeBuffer =
- getBuffer(rewriter, reshapeOp.getShape(), options);
+ getBuffer(rewriter, reshapeOp.getShape(), options, state);
if (failed(srcBuffer) || failed(shapeBuffer))
return failure();
auto maybeResultMemRefType =
- bufferization::getBufferType(reshapeOp.getResult(), options);
+ bufferization::getBufferType(reshapeOp.getResult(), options, state);
if (failed(maybeResultMemRefType))
return failure();
@@ -901,7 +908,7 @@ struct ReshapeOpInterface
auto srcType = llvm::dyn_cast<MemRefType>(srcBuffer->getType());
if (srcType && !srcType.getLayout().isIdentity()) {
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
- rewriter, op->getLoc(), reshapeOp.getSource(), options);
+ rewriter, op->getLoc(), reshapeOp.getSource(), options, state);
if (failed(tensorAlloc))
return failure();
auto memrefType = MemRefType::get(
@@ -920,11 +927,12 @@ struct ReshapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
assert(value == reshapeOp.getResult() && "unexpected value provided");
auto maybeSourceBufferType = bufferization::getBufferType(
- reshapeOp.getSource(), options, invocationStack);
+ reshapeOp.getSource(), options, state, invocationStack);
if (failed(maybeSourceBufferType))
return failure();
return getMemRefTypeWithStaticIdentityLayout(
@@ -966,11 +974,11 @@ struct ParallelInsertSliceOpInterface
// Get source and destination buffers.
FailureOr<Value> destBuffer =
- getBuffer(rewriter, parallelInsertSliceOp.getDest(), options);
+ getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state);
if (failed(destBuffer))
return failure();
FailureOr<Value> srcBuffer =
- getBuffer(rewriter, parallelInsertSliceOp.getSource(), options);
+ getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state);
if (failed(srcBuffer))
return failure();
@@ -1016,7 +1024,8 @@ struct ParallelInsertSliceOpInterface
/// tensor.parallel_insert_slice op has implicit inplace behavior. We
/// shouldn't create copy to resolve conflict.
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &state) const {
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) const {
return success();
}
};
@@ -1038,7 +1047,7 @@ struct SplatOpInterface
// Allocate memory.
Location loc = op->getLoc();
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
- rewriter, loc, splatOp.getResult(), options,
+ rewriter, loc, splatOp.getResult(), options, state,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
@@ -1097,7 +1106,7 @@ struct ConcatOpInterface
// Allocate memory.
Location loc = op->getLoc();
FailureOr<Value> tensorAlloc = allocateTensorForShapedValue(
- rewriter, loc, concatOp.getResult(), options,
+ rewriter, loc, concatOp.getResult(), options, state,
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
@@ -1147,7 +1156,7 @@ struct ConcatOpInterface
for (auto operand : concatOp.getInputs()) {
// Get the buffer for the operand.
- FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options);
+ FailureOr<Value> srcBuffer = getBuffer(rewriter, operand, options, state);
if (failed(srcBuffer))
return failure();
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index 45b6e7c512947..a94a1d3567573 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -53,7 +53,8 @@ struct TransferReadOpInterface
auto readOp = cast<vector::TransferReadOp>(op);
assert(isa<TensorType>(readOp.getShapedType()) &&
"only tensor types expected");
- FailureOr<Value> buffer = getBuffer(rewriter, readOp.getBase(), options);
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, readOp.getBase(), options, state);
if (failed(buffer))
return failure();
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
@@ -112,7 +113,7 @@ struct TransferWriteOpInterface
// Create a new transfer_write on buffer that doesn't have a return value.
FailureOr<Value> resultBuffer =
- getBuffer(rewriter, writeOp.getBase(), options);
+ getBuffer(rewriter, writeOp.getBase(), options, state);
if (failed(resultBuffer))
return failure();
rewriter.create<vector::TransferWriteOp>(
@@ -155,7 +156,8 @@ struct GatherOpInterface
auto gatherOp = cast<vector::GatherOp>(op);
assert(isa<TensorType>(gatherOp.getBaseType()) &&
"only tensor types expected");
- FailureOr<Value> buffer = getBuffer(rewriter, gatherOp.getBase(), options);
+ FailureOr<Value> buffer =
+ getBuffer(rewriter, gatherOp.getBase(), options, state);
if (failed(buffer))
return failure();
replaceOpWithNewBufferizedOp<vector::GatherOp>(
@@ -185,9 +187,11 @@ struct MaskOpInterface
}
LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &state) const {
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
- if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state)))
+ if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
+ rewriter, analysisState, bufferizationState)))
return failure();
// TODO: Remove this function when vector.mask bodies can bufferize
@@ -302,7 +306,8 @@ struct YieldOpInterface
SmallVector<Value> newResults;
for (Value value : yieldOp.getOperands()) {
if (isa<TensorType>(value.getType())) {
- FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options);
+ FailureOr<Value> maybeBuffer =
+ getBuffer(rewriter, value, options, state);
if (failed(maybeBuffer))
return failure();
newResults.push_back(*maybeBuffer);
diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
index 2991a3c165ee2..dfaebccde7dcc 100644
--- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
+++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp
@@ -48,7 +48,11 @@ struct TestTensorCopyInsertionPass
options.defaultMemorySpaceFn =
[](TensorType t) -> std::optional<Attribute> { return std::nullopt; };
}
- if (failed(bufferization::insertTensorCopies(getOperation(), options)))
+
+ bufferization::BufferizationState bufferizationState;
+
+ if (failed(bufferization::insertTensorCopies(getOperation(), options,
+ bufferizationState)))
signalPassFailure();
}
>From fb6838eaef09776eddf971ae1cc8c796c01c96de Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Mon, 26 May 2025 15:46:37 +0200
Subject: [PATCH 2/3] Fix code format
---
.../Dialect/Bufferization/IR/UnstructuredControlFlow.h | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 7c07f705c8435..00f7799fc18c6 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,7 +34,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
@@ -82,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (bufferType == callerType)
continue;
- // If the computed buffer type does not match the computed buffer type
- // of the earlier forwarded operands, fall back to a buffer type with a
- // fully dynamic layout map.
+ // If the computed buffer type does not match the computed buffer type
+ // of the earlier forwarded operands, fall back to a buffer type with a
+ // fully dynamic layout map.
#ifndef NDEBUG
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
assert(bufferType.hasRank() && callerType.hasRank() &&
>From 3a59029fe4eff85f5d6a601123d6cb27dd9f3065 Mon Sep 17 00:00:00 2001
From: Michele Scuttari <michele.scuttari at outlook.com>
Date: Tue, 27 May 2025 11:27:36 +0200
Subject: [PATCH 3/3] Make references to BufferizationState constant
---
.../IR/BufferizableOpInterface.h | 10 +++----
.../IR/BufferizableOpInterface.td | 6 ++---
.../Bufferization/IR/BufferizationOps.td | 4 +--
.../IR/UnstructuredControlFlow.h | 2 +-
.../Bufferization/Transforms/Transforms.h | 4 +--
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../IR/BufferizableOpInterface.cpp | 13 +++++-----
.../Bufferization/IR/BufferizationOps.cpp | 2 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 4 +--
.../Transforms/TensorCopyInsertion.cpp | 4 +--
.../BufferizableOpInterfaceImpl.cpp | 2 +-
.../BufferizableOpInterfaceImpl.cpp | 26 ++++++++++---------
.../BufferizableOpInterfaceImpl.cpp | 19 +++++++-------
.../BufferizableOpInterfaceImpl.cpp | 7 ++---
14 files changed, 55 insertions(+), 50 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 328d928c9ebdb..adccbef754ec5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,14 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
- BufferizationState &state, bool copy = true);
+ const BufferizationState &state, bool copy = true);
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
- BufferizationState &state);
+ const BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
@@ -617,7 +617,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
- BufferizationState &state);
+ const BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
@@ -631,7 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack);
/// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack);
/// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 80f9b72531660..5607df8c96039 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -382,7 +382,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"const ::mlir::bufferization::AnalysisState &":$analysisState,
- "::mlir::bufferization::BufferizationState &":$bufferizationState),
+ "const ::mlir::bufferization::BufferizationState &":$bufferizationState),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
@@ -524,7 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
- "::mlir::bufferization::BufferizationState &":$state,
+ "const ::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -619,7 +619,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::AnalysisState &analysisState,
- ::mlir::bufferization::BufferizationState &bufferizationState);
+ const ::mlir::bufferization::BufferizationState &bufferizationState);
/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 0ee4f79144158..3d4dcdee2663b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,7 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack);
RankedTensorType getType() {
@@ -472,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- BufferizationState &state, SmallVector<Value> &invocationStack) {
+ const BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 00f7799fc18c6..a441b8b66659e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,7 +34,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index e587753ddebee..e17d5264a1a45 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,7 +75,7 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
- BufferizationState &bufferizationState,
+ const BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
@@ -83,7 +83,7 @@ LogicalResult insertTensorCopies(Operation *op,
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const AnalysisState &analysisState,
- BufferizationState &bufferizationState);
+ const BufferizationState &bufferizationState);
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 0389a984e169c..a57d58ab28d28 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -183,7 +183,7 @@ struct SelectOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 7d67d4a33ac32..1d6e1bdaf80f5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
- const BufferizationOptions &options, BufferizationState &state, bool copy) {
+ const BufferizationOptions &options, const BufferizationState &state,
+ bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
@@ -224,7 +225,7 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
RewriterBase &rewriter, const AnalysisState &analysisState,
- BufferizationState &bufferizationState) {
+ const BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -670,7 +671,7 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
const BufferizationOptions &options,
- BufferizationState &state) {
+ const BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
@@ -695,7 +696,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
- BufferizationState &state) {
+ const BufferizationState &state) {
SmallVector<Value> invocationStack;
return getBufferType(value, options, state, invocationStack);
}
@@ -703,7 +704,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
@@ -951,7 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
- BufferizationState &bufferizationState,
+ const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 41b86437e11cf..dc54ac94aed32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -224,7 +224,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
FailureOr<BaseMemRefType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(value == getResult() && "invalid value");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 9a0a85a71debd..a0168da44b7b3 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -213,7 +213,7 @@ struct CallOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto callOp = cast<func::CallOp>(op);
@@ -398,7 +398,7 @@ struct FuncOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto funcOp = cast<FuncOp>(op);
auto bbArg = cast<BlockArgument>(value);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index d971ed5b0f71c..784d95a5dd22a 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -28,7 +28,7 @@ using namespace mlir::bufferization;
LogicalResult mlir::bufferization::insertTensorCopies(
Operation *op, const OneShotBufferizationOptions &options,
- BufferizationState &bufferizationState,
+ const BufferizationState &bufferizationState,
BufferizationStatistics *statistics) {
OneShotAnalysisState analysisState(op, options);
// Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize
@@ -50,7 +50,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
LogicalResult mlir::bufferization::insertTensorCopies(
Operation *op, const AnalysisState &analysisState,
- BufferizationState &bufferizationState) {
+ const BufferizationState &bufferizationState) {
IRRewriter rewriter(op->getContext());
// It may be more efficient to walk in pre-order here, but the current
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index ce355e96ee694..9044d89c80bd6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ namespace {
/// Generic conversion for any DestinationStyleOpInterface on tensors.
static LogicalResult bufferizeDestinationStyleOpInterface(
RewriterBase &rewriter, DestinationStyleOpInterface op,
- const BufferizationOptions &options, BufferizationState &state) {
+ const BufferizationOptions &options, const BufferizationState &state) {
// Take a guard before anything else.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(op);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 59c240c62f934..46fa77a7dc4e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -276,7 +276,7 @@ struct IfOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto ifOp = cast<scf::IfOp>(op);
auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto switchOp = cast<scf::IndexSwitchOp>(op);
assert(value.getDefiningOp() == op && "invalid value");
@@ -524,7 +524,7 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// layout map and a cast must be inserted.
static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
- const BufferizationOptions &options, BufferizationState &state,
+ const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
auto initArgBufferType =
@@ -653,9 +653,10 @@ struct ForOpInterface
return true;
}
- LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &analysisState,
- BufferizationState &bufferizationState) const {
+ LogicalResult
+ resolveConflicts(Operation *op, RewriterBase &rewriter,
+ const AnalysisState &analysisState,
+ const BufferizationState &bufferizationState) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
rewriter, analysisState, bufferizationState)))
@@ -708,7 +709,7 @@ struct ForOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
@@ -900,9 +901,10 @@ struct WhileOpInterface
return true;
}
- LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &analysisState,
- BufferizationState &bufferizationState) const {
+ LogicalResult
+ resolveConflicts(Operation *op, RewriterBase &rewriter,
+ const AnalysisState &analysisState,
+ const BufferizationState &bufferizationState) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
rewriter, analysisState, bufferizationState)))
@@ -1042,7 +1044,7 @@ struct WhileOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
@@ -1298,7 +1300,7 @@ struct ForallOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto forallOp = cast<ForallOp>(op);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 154f12b31fc70..4b778b768d136 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -51,7 +51,7 @@ struct CastOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto castOp = cast<tensor::CastOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
@@ -142,7 +142,7 @@ struct CollapseShapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
@@ -320,7 +320,7 @@ struct ExpandShapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
auto maybeSrcBufferType = bufferization::getBufferType(
@@ -405,7 +405,7 @@ struct ExtractSliceOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
assert(value == extractSliceOp.getResult() && "invalid value");
@@ -754,7 +754,7 @@ struct PadOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Infer memory space from the source tensor.
auto padOp = cast<tensor::PadOp>(op);
@@ -927,7 +927,7 @@ struct ReshapeOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
- BufferizationState &state,
+ const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto reshapeOp = cast<tensor::ReshapeOp>(op);
assert(value == reshapeOp.getResult() && "unexpected value provided");
@@ -1023,9 +1023,10 @@ struct ParallelInsertSliceOpInterface
/// tensor.parallel_insert_slice op has implicit inplace behavior. We
/// shouldn't create copy to resolve conflict.
- LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &analysisState,
- BufferizationState &bufferizationState) const {
+ LogicalResult
+ resolveConflicts(Operation *op, RewriterBase &rewriter,
+ const AnalysisState &analysisState,
+ const BufferizationState &bufferizationState) const {
return success();
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
index a94a1d3567573..9da051150e409 100644
--- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -186,9 +186,10 @@ struct MaskOpInterface
return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}};
}
- LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter,
- const AnalysisState &analysisState,
- BufferizationState &bufferizationState) const {
+ LogicalResult
+ resolveConflicts(Operation *op, RewriterBase &rewriter,
+ const AnalysisState &analysisState,
+ const BufferizationState &bufferizationState) const {
auto bufferizableOp = cast<BufferizableOpInterface>(op);
if (failed(bufferizableOp.resolveTensorOpOperandConflicts(
rewriter, analysisState, bufferizationState)))
More information about the Mlir-commits
mailing list