[Mlir-commits] [mlir] [MLIR] make One-Shot and SCF bufferization TensorLikeType-aware (PR #189073)
llvmlistbot at llvm.org
Fri Mar 27 11:00:26 PDT 2026
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Dmitrii Makarenko (Devjiu)
<details>
<summary>Changes</summary>
Fix bufferization inconsistencies between builtin tensor types and custom TensorLikeType implementations across One-Shot analysis/module paths and SCF bufferization interfaces.
The main issue was that several places needing TensorLikeType-aware handling still checked for TensorType/RankedTensorType. This could leave function-boundary equivalence/aliasing incomplete for custom tensor-like types, leading to spurious SCF loop equivalence verification failures.
This change:
- switches relevant One-Shot analysis/module checks from TensorType/RankedTensorType to TensorLikeType;
- updates generic/default aliasing utilities to treat TensorLikeType consistently;
- updates SCF BufferizableOpInterface implementations (for/while/if/yield-related paths) to use TensorLikeType/BufferLikeType where appropriate;
- updates test custom ops to provide required aliasing/getBufferType hooks for custom tensor-like types;
- refreshes and renames custom_types SCF tests to explicitly check memref replacement after bufferization.
Potential follow-ups / known risk areas:
- SCF.Forall shared_outs still carries RankedTensorType assumptions in its signatures and code paths and should be audited for full TensorLikeType coverage.
- The resolveConflicts implementations of SCF.For and SCF.While call allocateTensorForShapedValue, which currently assumes ranked tensor/memref copy paths; this may still be a limitation for some tensor-like or unranked scenarios.
---
Patch is 58.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/189073.diff
8 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+175-180)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+6-5)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+14-14)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+10-9)
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+37-34)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+176)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+13)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+39-10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index c7775f2407ebd..c96309f3e81d5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -25,21 +25,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
alias at runtime.
}];
let cppNamespace = "::mlir::bufferization";
- let methods = [
- InterfaceMethod<
- /*desc=*/[{
+ let methods =
+ [InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given Value may bufferize to a new buffer
allocation. If it is statically unknown that the given Value
bufferizes to a buffer allocation, `true` should be returned.
}],
- /*retType=*/"bool",
- /*methodName=*/"bufferizesToAllocation",
- /*args=*/(ins "::mlir::Value":$value),
- /*methodBody=*/"",
- /*defaultImplementation=*/"return false;"
- >,
- InterfaceMethod<
- /*desc=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"bufferizesToAllocation",
+ /*args=*/(ins "::mlir::Value":$value),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/"return false;">,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given OpOperand bufferizes to a memory read. This
method will never be called on OpOperands that do not have a tensor
type.
@@ -50,18 +49,18 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
considers OpOperands of unknown ops (that do not implement this
interface) as reading OpOperands.
}],
- /*retType=*/"bool",
- /*methodName=*/"bufferizesToMemoryRead",
- /*args=*/(ins "::mlir::OpOperand &":$opOperand,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"bufferizesToMemoryRead",
+ /*args=*/
+ (ins "::mlir::OpOperand &":$opOperand,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
llvm_unreachable("bufferizesToMemoryRead not implemented");
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given OpOperand bufferizes to a memory write.
This method will never be called on OpOperands that do not have a
@@ -78,20 +77,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
considers OpOperands of unknown ops (that do not implement this
interface) as writing OpOperands.
}],
- /*retType=*/"bool",
- /*methodName=*/"bufferizesToMemoryWrite",
- /*args=*/(ins "::mlir::OpOperand &":$opOperand,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"bufferizesToMemoryWrite",
+ /*args=*/
+ (ins "::mlir::OpOperand &":$opOperand,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
// Does not have to be implemented for OpOperands that do not have an
// aliasing Value.
llvm_unreachable("bufferizesToMemoryWrite not implemented");
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the operation bufferizes to IR that performs only
element-wise accesses on the specified tensor operands. (The operands
must have the same shape.) The `bufferize` method must be implemented
@@ -135,18 +134,18 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
ignored. A conservative implementation of this interface method may
always return "false".
}],
- /*retType=*/"bool",
- /*methodName=*/"bufferizesToElementwiseAccess",
- /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state,
- "::llvm::ArrayRef<::mlir::OpOperand *>":$opOperands),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"bufferizesToElementwiseAccess",
+ /*args=*/
+ (ins "const ::mlir::bufferization::AnalysisState &":$state,
+ "::llvm::ArrayRef<::mlir::OpOperand *>":$opOperands),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
// It is always safe to assume that the op is not element-wise.
return false;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given OpResult bufferizes to a memory write.
This is the same property as `bufferizesToMemoryWrite`, but from The
perspective of OpResults.
@@ -191,20 +190,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
OpResult defined inside the scf.if op) bufferizes to a memory
write.
}],
- /*retType=*/"bool",
- /*methodName=*/"resultBufferizesToMemoryWrite",
- /*args=*/(ins "::mlir::OpResult":$opResult,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"resultBufferizesToMemoryWrite",
+ /*args=*/
+ (ins "::mlir::OpResult":$opResult,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
assert(opResult.getDefiningOp() == $_op.getOperation() &&
"invalid OpResult");
return ::mlir::bufferization::detail::defaultResultBufferizesToMemoryWrite(
opResult, state);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given OpOperand must bufferize in-place. Alias
sets and inplace attributes will be set up accordingly before making
any other bufferization decisions. This method will never be called on
@@ -214,17 +213,17 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
be extended in the future. Unranked tensors are used with external
functions only.
}],
- /*retType=*/"bool",
- /*methodName=*/"mustBufferizeInPlace",
- /*args=*/(ins "::mlir::OpOperand &":$opOperand,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"mustBufferizeInPlace",
+ /*args=*/
+ (ins "::mlir::OpOperand &":$opOperand,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
return ::llvm::isa<::mlir::UnrankedTensorType>(opOperand.get().getType());
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return the Values that may alias with a given OpOperand when
bufferized in-place. This method will never be called on OpOperands
that do not have a tensor type.
@@ -280,20 +279,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
%r = "dummy.alias_or_copy(%t) : (tensor<10xf32>) -> (tensor<10xf32>)"
```
}],
- /*retType=*/"::mlir::bufferization::AliasingValueList",
- /*methodName=*/"getAliasingValues",
- /*args=*/(ins "::mlir::OpOperand &":$opOperand,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"::mlir::bufferization::AliasingValueList",
+ /*methodName=*/"getAliasingValues",
+ /*args=*/
+ (ins "::mlir::OpOperand &":$opOperand,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
assert(::llvm::isa<::mlir::TensorType>(opOperand.get().getType()) &&
"expected OpOperand with tensor type");
llvm_unreachable("getAliasingValues not implemented");
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return the OpOperands that alias with a given Value when bufferized
in-place. This method will never be called on Values that do not
have a tensor type.
@@ -352,20 +351,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
%r = tensor.empty() : tensor<10xf32>
```
}],
- /*retType=*/"::mlir::bufferization::AliasingOpOperandList",
- /*methodName=*/"getAliasingOpOperands",
- /*args=*/(ins "::mlir::Value":$value,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
- assert(isa<::mlir::TensorType>(value.getType()) &&
+ /*retType=*/"::mlir::bufferization::AliasingOpOperandList",
+ /*methodName=*/"getAliasingOpOperands",
+ /*args=*/
+ (ins "::mlir::Value":$value,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ assert(isa<::mlir::bufferization::TensorLikeType>(value.getType()) &&
"expected tensor type");
return ::mlir::bufferization::detail::defaultGetAliasingOpOperands(
value, state);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Resolve all inplacability conflicts by inserting explicit
`bufferization.alloc_tensor` ops. Examples of inplacability conflicts
are read-after-write conflicts or writes into non-writable buffers.
@@ -378,21 +377,22 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
This method can query analysis information from the given analysis
state.
}],
- /*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"resolveConflicts",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::AnalysisState &":$analysisState,
- "const ::mlir::bufferization::BufferizationState &":$bufferizationState),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"::llvm::LogicalResult",
+ /*methodName=*/"resolveConflicts",
+ /*args=*/
+ (ins "::mlir::RewriterBase &":$rewriter,
+ "const ::mlir::bufferization::AnalysisState &":$analysisState,
+ "const ::mlir::bufferization::BufferizationState "
+ "&":$bufferizationState),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
auto bufferizableOp =
::llvm::cast<BufferizableOpInterface>($_op.getOperation());
return bufferizableOp.resolveTensorOpOperandConflicts(
rewriter, analysisState, bufferizationState);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Bufferize this op, i.e., rewrite it into a memref-based equivalent.
Buffers of tensor SSA values can be retrieved via `getBuffer`.
Uses of tensor results of the existing tensor op can be replaced with
@@ -429,19 +429,19 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
suggestion to make sure IR is valid at every point in time and could
be done differently).
}],
- /*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"bufferize",
- /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "const ::mlir::bufferization::BufferizationOptions &":$options,
- "::mlir::bufferization::BufferizationState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"::llvm::LogicalResult",
+ /*methodName=*/"bufferize",
+ /*args=*/
+ (ins "::mlir::RewriterBase &":$rewriter,
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "::mlir::bufferization::BufferizationState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
llvm_unreachable("bufferize not implemented");
return ::mlir::failure();
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given Value can be written to in-place. Value is
either an OpResult of this operation or a BlockArgument of a block of
this operation.
@@ -456,17 +456,17 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
operation. This method conservatively returns `false`. This method
will never be called on BlockArguments that do not have a tensor type.
}],
- /*retType=*/"bool",
- /*methodName=*/"isWritable",
- /*args=*/(ins "::mlir::Value":$value,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"isWritable",
+ /*args=*/
+ (ins "::mlir::Value":$value,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
return ::llvm::isa<::mlir::OpResult>(value);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the `uRead` and `uWrite` do not constitute a RaW
conflict. If they are conflicting or if it is unknown whether they are
conflicting, return `false`. This method will never be called with
@@ -478,18 +478,17 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
to be conflicting and do not force out-of-place bufferization. (There
may still be other conflicts that do.)
}],
- /*retType=*/"bool",
- /*methodName=*/"isNotConflicting",
- /*args=*/(ins "::mlir::OpOperand *":$uRead,
- "::mlir::OpOperand *":$uWrite,
- "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"isNotConflicting",
+ /*args=*/
+ (ins "::mlir::OpOperand *":$uRead, "::mlir::OpOperand *":$uWrite,
+ "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
return false;
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `failure` if this op does not pass the analysis. This method
is run during One-Shot Bufferize (after all post-analysis steps). If
the op does not pass the analysis, bufferization is aborted.
@@ -497,16 +496,15 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
This method can be used to check expected invariants and limitations
of the current bufferization implementation.
}],
- /*retType=*/"::llvm::LogicalResult",
- /*methodName=*/"verifyAnalysis",
- /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"::llvm::LogicalResult",
+ /*methodName=*/"verifyAnalysis",
+ /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
return ::mlir::success();
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return the bufferized type of the given tensor value (without
bufferizing the IR). The value is either a BlockArgument of a block
that belongs to this op or an OpResult of the given op.
@@ -525,24 +523,25 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
- /*methodName=*/"getBufferType",
- /*args=*/(ins "::mlir::Value":$value,
- "const ::mlir::bufferization::BufferizationOptions &":$options,
- "const ::mlir::bufferization::BufferizationState &":$state,
- "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/
+ "::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+ /*methodName=*/"getBufferType",
+ /*args=*/
+ (ins "::mlir::Value":$value,
+ "const ::mlir::bufferization::BufferizationOptions &":$options,
+ "const ::mlir::bufferization::BufferizationState &":$state,
+ "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
assert(getOwnerOfValue(value) == $_op.getOperation() &&
"expected that value belongs to this op");
assert(invocationStack.back() == value &&
"inconsistant invocation stack");
return ::mlir::bufferization::detail::defaultGetBufferType(
value, options, state, invocationStack);
- }]
- >,
- InterfaceMethod<
- /*desc=*/[{
+ }]>,
+ InterfaceMethod<
+ /*desc=*/[{
Return `true` if the given region of this op is repetitive. By default
this information is queried from the `RegionBranchOpInterface`. Ops
that do not implement this inferface can override this method to
@@ -555,17 +554,16 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
that are executed multiple times. This is described in more detail in
documentation of One-Shot Analysis.
}],
- /*retType=*/"bool",
- /*methodName=*/"isRepetitiveRegion",
- /*args=*/(ins "unsigned":$index),
- /*methodBody=*/"",
- /*defaultImplementation=*/[{
+ /*retType=*/"bool",
+ /*methodName=*/"isRepetitiveRegion",
+ /*args=*/(ins "unsigned":$index),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
return ::mlir::bufferization::detail::defaultIsRepetitiveRegion(
::llvm::cast<BufferizableOpInterface>($_op.getOperation()), index);
- }]
- >,
- InterfaceMet...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/189073