[Mlir-commits] [mlir] [MLIR] Add bufferization state to `getBufferType` and `resolveConflicts` interface methods (PR #141466)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon May 26 02:32:21 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-shape
Author: Michele Scuttari (mscuttari)
<details>
<summary>Changes</summary>
The PR continues the work started in #141019 by also adding the `BufferizationState` class to the `getBufferType` and `resolveConflicts` interface methods, together with the supporting functions used throughout the bufferization infrastructure.
---
Patch is 76.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141466.diff
22 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+7-3)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+7-4)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+2-1)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h (+2-1)
- (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h (+4-1)
- (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+10-8)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+30-21)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10-7)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+6-5)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-5)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+1-1)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+3-2)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp (+12-9)
- (modified) mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp (+13-13)
- (modified) mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp (+2-1)
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+61-44)
- (modified) mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+4-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+46-37)
- (modified) mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp (+11-6)
- (modified) mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (+5-1)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 43c97d57e1834..328d928c9ebdb 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -598,13 +598,14 @@ class BufferizationState {
FailureOr<Value>
allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue,
const BufferizationOptions &options,
- bool copy = true);
+ BufferizationState &state, bool copy = true);
/// Lookup the buffer for the given value. If the value was not bufferized
/// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp,
/// from which the memref operand is returned.
FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR.
@@ -615,7 +616,8 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
FailureOr<BaseMemRefType> getBufferType(Value value,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
/// Return the buffer type for a given Value (tensor) after bufferization
/// without bufferizing any IR. This function (and not the other overload
@@ -629,6 +631,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
FailureOr<BaseMemRefType> getBufferType(Value value,
const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// Return "true" if the given op has tensor semantics and should be bufferized.
@@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// places.
FailureOr<BaseMemRefType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
/// This is the default implementation of
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index b599a9f053215..80f9b72531660 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"resolveConflicts",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::AnalysisState &":$state),
+ "const ::mlir::bufferization::AnalysisState &":$analysisState,
+ "::mlir::bufferization::BufferizationState &":$bufferizationState),
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
- rewriter, state);
+ rewriter, analysisState, bufferizationState);
}]
>,
InterfaceMethod<
@@ -523,6 +524,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state,
"::llvm::SmallVector<::mlir::Value> &":$invocationStack),
/*methodBody=*/"",
/*defaultImplementation=*/[{
@@ -531,7 +533,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
- value, options, invocationStack);
+ value, options, state, invocationStack);
}]
>,
InterfaceMethod<
@@ -616,7 +618,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/// form of `bufferization.alloc_tensor` ops.
::llvm::LogicalResult resolveTensorOpOperandConflicts(
::mlir::RewriterBase &rewriter,
- const ::mlir::bufferization::AnalysisState &state);
+ const ::mlir::bufferization::AnalysisState &analysisState,
+ ::mlir::bufferization::BufferizationState &bufferizationState);
/// Return `true` if the given OpOperand creates an alias but does neither
/// read nor write. This implies that `bufferizesToMemoryRead` and
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index dafa4b9b183f2..0ee4f79144158 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack);
RankedTensorType getType() {
@@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
FailureOr<BaseMemRefType> getBufferType(
Value value, const BufferizationOptions &options,
- SmallVector<Value> &invocationStack) {
+ BufferizationState &state, SmallVector<Value> &invocationStack) {
return ::llvm::cast<BaseMemRefType>(getMemref().getType());
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index cf86b9a23f59e..7c07f705c8435 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
// case the bufferized result type is different from the bufferized type of
// the aliasing OpOperand (if any).
if (isa<OpResult>(value))
- return bufferization::detail::defaultGetBufferType(value, options,
+ return bufferization::detail::defaultGetBufferType(value, options, state,
invocationStack);
// Compute the buffer type of the block argument by computing the bufferized
@@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
callerType = memrefType;
} else {
FailureOr<BaseMemRefType> maybeCallerType =
- bufferization::getBufferType(opOperand->get(), options,
+ bufferization::getBufferType(opOperand->get(), options, state,
invocationStack);
if (failed(maybeCallerType))
return failure();
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
index 70e3defee0867..c1f5654abbf9b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h
@@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
/// `BufferizableOpInterface`. The buffer types of tensor block arguments are
/// computed with `BufferizableOpIntercace::getBufferType`.
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter,
- const BufferizationOptions &options);
+ const BufferizationOptions &options,
+ BufferizationState &state);
} // namespace bufferization
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
index a4ee893ca5341..e587753ddebee 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h
@@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op);
/// additional buffer allocations.
LogicalResult insertTensorCopies(Operation *op,
const OneShotBufferizationOptions &options,
+ BufferizationState &bufferizationState,
BufferizationStatistics *statistics = nullptr);
/// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops.
/// After applying this transform, the IR can be bufferized without inserting
/// additional buffer allocations.
-LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state);
+LogicalResult insertTensorCopies(Operation *op,
+ const AnalysisState &analysisState,
+ BufferizationState &bufferizationState);
/// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor
/// ops.
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index f646326ffc58f..0389a984e169c 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -90,7 +90,8 @@ struct IndexCastOpInterface
auto castOp = cast<arith::IndexCastOp>(op);
auto resultTensorType = cast<TensorType>(castOp.getType());
- FailureOr<Value> source = getBuffer(rewriter, castOp.getIn(), options);
+ FailureOr<Value> source =
+ getBuffer(rewriter, castOp.getIn(), options, state);
if (failed(source))
return failure();
auto sourceType = cast<BaseMemRefType>(source->getType());
@@ -151,9 +152,9 @@ struct SelectOpInterface
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
FailureOr<Value> maybeTrueBuffer =
- getBuffer(rewriter, selectOp.getTrueValue(), options);
+ getBuffer(rewriter, selectOp.getTrueValue(), options, state);
FailureOr<Value> maybeFalseBuffer =
- getBuffer(rewriter, selectOp.getFalseValue(), options);
+ getBuffer(rewriter, selectOp.getFalseValue(), options, state);
if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer))
return failure();
Value trueBuffer = *maybeTrueBuffer;
@@ -164,7 +165,7 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto targetType =
- bufferization::getBufferType(selectOp.getResult(), options);
+ bufferization::getBufferType(selectOp.getResult(), options, state);
if (failed(targetType))
return failure();
if (trueBuffer.getType() != *targetType)
@@ -182,13 +183,14 @@ struct SelectOpInterface
FailureOr<BaseMemRefType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
assert(value == selectOp.getResult() && "invalid value");
- auto trueType = bufferization::getBufferType(selectOp.getTrueValue(),
- options, invocationStack);
- auto falseType = bufferization::getBufferType(selectOp.getFalseValue(),
- options, invocationStack);
+ auto trueType = bufferization::getBufferType(
+ selectOp.getTrueValue(), options, state, invocationStack);
+ auto falseType = bufferization::getBufferType(
+ selectOp.getFalseValue(), options, state, invocationStack);
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 14fa4c1ed8159..7d67d4a33ac32 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -165,7 +165,7 @@ Operation *bufferization::getOwnerOfValue(Value value) {
/// allocated.
FailureOr<Value> bufferization::allocateTensorForShapedValue(
OpBuilder &b, Location loc, Value shapedValue,
- const BufferizationOptions &options, bool copy) {
+ const BufferizationOptions &options, BufferizationState &state, bool copy) {
Value tensor;
if (llvm::isa<RankedTensorType>(shapedValue.getType())) {
tensor = shapedValue;
@@ -210,7 +210,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BaseMemRefType> copyBufferType =
+ getBufferType(tensor, options, state);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -222,7 +223,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
}
LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
- RewriterBase &rewriter, const AnalysisState &state) {
+ RewriterBase &rewriter, const AnalysisState &analysisState,
+ BufferizationState &bufferizationState) {
OpBuilder::InsertionGuard g(rewriter);
Operation *op = getOperation();
SmallVector<OpOperand *> outOfPlaceOpOperands;
@@ -235,16 +237,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
Type operandType = opOperand.get().getType();
if (!llvm::isa<TensorType>(operandType))
continue;
- if (state.isInPlace(opOperand))
+ if (analysisState.isInPlace(opOperand))
continue;
if (llvm::isa<UnrankedTensorType>(operandType))
return op->emitError("copying of unranked tensors is not implemented");
- AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
+ AliasingValueList aliasingValues =
+ analysisState.getAliasingValues(opOperand);
if (aliasingValues.getNumAliases() == 1 &&
isa<OpResult>(aliasingValues.getAliases()[0].value) &&
- !state.bufferizesToMemoryWrite(opOperand) &&
- state.getAliasingOpOperands(aliasingValues.getAliases()[0].value)
+ !analysisState.bufferizesToMemoryWrite(opOperand) &&
+ analysisState
+ .getAliasingOpOperands(aliasingValues.getAliases()[0].value)
.getNumAliases() == 1 &&
!isa<UnrankedTensorType>(
aliasingValues.getAliases()[0].value.getType())) {
@@ -256,12 +260,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// cannot be copied at the moment).
Value value = aliasingValues.getAliases()[0].value;
outOfPlaceValues.push_back(value);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpValues.insert(value);
} else {
// In all other cases, make a copy of the OpOperand.
outOfPlaceOpOperands.push_back(&opOperand);
- if (!state.canOmitTensorCopy(opOperand))
+ if (!analysisState.canOmitTensorCopy(opOperand))
copiedOpOperands.insert(&opOperand);
}
}
@@ -270,8 +274,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPoint(op);
for (OpOperand *opOperand : outOfPlaceOpOperands) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), opOperand->get(), state.getOptions(),
- copiedOpOperands.contains(opOperand));
+ rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(),
+ bufferizationState, copiedOpOperands.contains(opOperand));
if (failed(copy))
return failure();
rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); });
@@ -281,8 +285,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
rewriter.setInsertionPointAfter(op);
for (Value value : outOfPlaceValues) {
FailureOr<Value> copy = allocateTensorForShapedValue(
- rewriter, op->getLoc(), value, state.getOptions(),
- copiedOpValues.count(value));
+ rewriter, op->getLoc(), value, analysisState.getOptions(),
+ bufferizationState, copiedOpValues.count(value));
if (failed(copy))
return failure();
SmallVector<OpOperand *> uses = llvm::to_vector(
@@ -665,7 +669,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) {
}
FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
- const BufferizationOptions &options) {
+ const BufferizationOptions &options,
+ BufferizationState &state) {
#ifndef NDEBUG
auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
assert(tensorType && "unexpected non-tensor type");
@@ -678,7 +683,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
// Insert to_buffer op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, value);
- FailureOr<BaseMemRefType> memrefType = getBufferType(value, options);
+ FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
if (failed(memrefType))
return failure();
ensureToBufferOpIsValid(value, *memrefType);
@@ -689,14 +694,16 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
-bufferization::getBufferType(Value value, const BufferizationOptions &options) {
+bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state) {
SmallVector<Value> invocationStack;
- return getBufferType(value, options, invocationStack);
+ return getBufferType(value, options, state, invocationStack);
}
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options,
+ BufferizationState &state,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) &&
"unexpected non-tensor type");
@@ -708,7 +715,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
Operation *op = getOwnerOfValue(value);
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (bufferizableOp)
- return bufferizableOp.getBufferType(value, options, invocationStack);
+ return bufferizableOp.getBufferType(value, options, state, invocationStack);
// Op is not bufferizable.
auto memSpace =
@@ -944,6 +951,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Failur...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/141466
More information about the Mlir-commits mailing list