[Mlir-commits] [mlir] [MLIR] make One-Shot and SCF bufferization TensorLikeType-aware (PR #189073)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Mar 27 11:00:26 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Dmitrii Makarenko (Devjiu)

<details>
<summary>Changes</summary>

Fix bufferization inconsistencies between builtin tensor types and custom TensorLikeType implementations across One-Shot analysis/module paths and SCF bufferization interfaces.

The root cause was that several code paths checked for TensorType/RankedTensorType where they should accept any TensorLikeType. As a result, function-boundary equivalence and aliasing information could be left incomplete for custom tensor-like types, which led to spurious SCF loop equivalence verification failures.
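To illustrate why a concrete-type check loses information, here is a deliberately simplified, self-contained C++ model. It is not MLIR code. All names (`Type`, `TensorLike`, `BuiltinTensor`, `CustomTensor`, `ReturnedValue`, `collectEquivalences`) are hypothetical stand-ins. `TensorLike` plays the role of the TensorLikeType interface, `BuiltinTensor` the builtin TensorType, and `CustomTensor` a dialect type that only implements the interface. A routine that records "result i is equivalent to argument j" only for values passing a type predicate silently drops custom types when the predicate is the narrow one.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical stand-ins for the type hierarchy (NOT the MLIR classes).
struct Type { virtual ~Type() = default; };
struct TensorLike : Type {};          // models the TensorLikeType interface
struct BuiltinTensor : TensorLike {}; // models the builtin TensorType
struct CustomTensor : TensorLike {};  // models a custom tensor-like type

// A function result that is returned unchanged from a block argument.
struct ReturnedValue {
  const Type *type;
  int argIndex; // index of the block argument being returned
};

// Records "result i is equivalent to argument argIndex" for every returned
// value whose type passes `isCandidate`; all other results are skipped.
template <typename Pred>
std::map<int, int> collectEquivalences(const std::vector<ReturnedValue> &rets,
                                       Pred isCandidate) {
  std::map<int, int> equiv;
  for (int i = 0; i < static_cast<int>(rets.size()); ++i)
    if (isCandidate(rets[i].type))
      equiv[i] = rets[i].argIndex;
  return equiv;
}

// Narrow check: only the builtin tensor class matches.
bool isBuiltinTensor(const Type *t) {
  return dynamic_cast<const BuiltinTensor *>(t) != nullptr;
}

// Interface-style check: anything tensor-like matches.
bool isTensorLike(const Type *t) {
  return dynamic_cast<const TensorLike *>(t) != nullptr;
}
```

With one builtin and one custom result, the narrow predicate records only the builtin equivalence, while the interface-style predicate records both. In the model, that missing entry is what a later verification step would see as "not equivalent".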

This change:
- switches relevant One-Shot analysis/module checks from TensorType/RankedTensorType to TensorLikeType;
- updates generic/default aliasing utilities to treat TensorLikeType consistently;
- updates SCF BufferizableOpInterface implementations (for/while/if/yield-related paths) to use TensorLikeType/BufferLikeType where appropriate;
- updates the custom test ops to provide the aliasing and getBufferType hooks required for custom tensor-like types;
- refreshes and renames custom_types SCF tests to explicitly check memref replacement after bufferization.
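The aliasing-utility change in the list above can be pictured with a second simplified, self-contained C++ model (again, not MLIR code). It assumes that a default "inverse" aliasing query walks an op's operands, skips operands whose type fails a filter, and collects those whose forward aliasing set contains the queried result. That shape is an assumption for illustration only. `Operand`, `aliasingOperands`, and the boolean type flags are hypothetical.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical operand model: two flags describe the operand's type, and
// `aliasingResults` is the forward query (results this operand may alias).
struct Operand {
  bool isTensorLike;    // type implements the tensor-like interface
  bool isBuiltinTensor; // type is a builtin tensor
  std::vector<int> aliasingResults;
};

// Inverse query: which operands may alias result `resultIdx`?
// Operands rejected by `accept` are never consulted, so any aliasing they
// would have reported is silently dropped.
std::vector<int> aliasingOperands(
    const std::vector<Operand> &ops, int resultIdx,
    const std::function<bool(const Operand &)> &accept) {
  std::vector<int> out;
  for (int i = 0; i < static_cast<int>(ops.size()); ++i) {
    if (!accept(ops[i]))
      continue;
    for (int r : ops[i].aliasingResults)
      if (r == resultIdx)
        out.push_back(i);
  }
  return out;
}
```

In this model, an operand with a custom tensor-like type that aliases result 0 is invisible to a builtin-only filter, so the inverse query returns nothing. The interface-based filter returns that operand as expected.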

Potential follow-ups / known risk areas:
- SCF.Forall shared_outs still has RankedTensorType assumptions in signatures and code paths and should be audited for full TensorLikeType coverage.
- SCF.For and SCF.While resolveConflicts call allocateTensorForShapedValue, which currently assumes ranked tensor/memref copy paths; this may still be a limitation for some tensor-like/unranked scenarios.

---

Patch is 58.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/189073.diff


8 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+175-180) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+6-5) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+14-14) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp (+10-9) 
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+37-34) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir (+176) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+13) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+39-10) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index c7775f2407ebd..c96309f3e81d5 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -25,21 +25,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
     alias at runtime.
   }];
   let cppNamespace = "::mlir::bufferization";
-  let methods = [
-      InterfaceMethod<
-        /*desc=*/[{
+  let methods =
+      [InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given Value may bufferize to a new buffer
           allocation. If it is statically unknown that the given Value
           bufferizes to a buffer allocation, `true` should be returned.
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"bufferizesToAllocation",
-        /*args=*/(ins "::mlir::Value":$value),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/"return false;"
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"bufferizesToAllocation",
+           /*args=*/(ins "::mlir::Value":$value),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/"return false;">,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given OpOperand bufferizes to a memory read. This
           method will never be called on OpOperands that do not have a tensor
           type.
@@ -50,18 +49,18 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           considers OpOperands of unknown ops (that do not implement this
           interface) as reading OpOperands.
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"bufferizesToMemoryRead",
-        /*args=*/(ins "::mlir::OpOperand &":$opOperand,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"bufferizesToMemoryRead",
+           /*args=*/
+           (ins "::mlir::OpOperand &":$opOperand,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
           llvm_unreachable("bufferizesToMemoryRead not implemented");
-         }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+         }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given OpOperand bufferizes to a memory write.
 
           This method will never be called on OpOperands that do not have a
@@ -78,20 +77,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           considers OpOperands of unknown ops (that do not implement this
           interface) as writing OpOperands.
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"bufferizesToMemoryWrite",
-        /*args=*/(ins "::mlir::OpOperand &":$opOperand,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"bufferizesToMemoryWrite",
+           /*args=*/
+           (ins "::mlir::OpOperand &":$opOperand,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
           // Does not have to be implemented for OpOperands that do not have an
           // aliasing Value.
           llvm_unreachable("bufferizesToMemoryWrite not implemented");
-         }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+         }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the operation bufferizes to IR that performs only
           element-wise accesses on the specified tensor operands. (The operands
           must have the same shape.) The `bufferize` method must be implemented
@@ -135,18 +134,18 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           ignored. A conservative implementation of this interface method may
           always return "false".
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"bufferizesToElementwiseAccess",
-        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state,
-                      "::llvm::ArrayRef<::mlir::OpOperand *>":$opOperands),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"bufferizesToElementwiseAccess",
+           /*args=*/
+           (ins "const ::mlir::bufferization::AnalysisState &":$state,
+               "::llvm::ArrayRef<::mlir::OpOperand *>":$opOperands),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           // It is always safe to assume that the op is not element-wise.
           return false;
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given OpResult bufferizes to a memory write.
           This is the same property as `bufferizesToMemoryWrite`, but from The
           perspective of OpResults.
@@ -191,20 +190,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
              OpResult defined inside the scf.if op) bufferizes to a memory
              write.
           }],
-        /*retType=*/"bool",
-        /*methodName=*/"resultBufferizesToMemoryWrite",
-        /*args=*/(ins "::mlir::OpResult":$opResult,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"resultBufferizesToMemoryWrite",
+           /*args=*/
+           (ins "::mlir::OpResult":$opResult,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           assert(opResult.getDefiningOp() == $_op.getOperation() &&
                  "invalid OpResult");
           return ::mlir::bufferization::detail::defaultResultBufferizesToMemoryWrite(
               opResult, state);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given OpOperand must bufferize in-place. Alias
           sets and inplace attributes will be set up accordingly before making
           any other bufferization decisions. This method will never be called on
@@ -214,17 +213,17 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           be extended in the future. Unranked tensors are used with external
           functions only.
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"mustBufferizeInPlace",
-        /*args=*/(ins "::mlir::OpOperand &":$opOperand,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"mustBufferizeInPlace",
+           /*args=*/
+           (ins "::mlir::OpOperand &":$opOperand,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           return ::llvm::isa<::mlir::UnrankedTensorType>(opOperand.get().getType());
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return the Values that may alias with a given OpOperand when
           bufferized in-place. This method will never be called on OpOperands
           that do not have a tensor type.
@@ -280,20 +279,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           %r = "dummy.alias_or_copy(%t) : (tensor<10xf32>) -> (tensor<10xf32>)"
           ```
         }],
-        /*retType=*/"::mlir::bufferization::AliasingValueList",
-        /*methodName=*/"getAliasingValues",
-        /*args=*/(ins "::mlir::OpOperand &":$opOperand,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"::mlir::bufferization::AliasingValueList",
+           /*methodName=*/"getAliasingValues",
+           /*args=*/
+           (ins "::mlir::OpOperand &":$opOperand,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
           assert(::llvm::isa<::mlir::TensorType>(opOperand.get().getType()) &&
                  "expected OpOperand with tensor type");
           llvm_unreachable("getAliasingValues not implemented");
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return the OpOperands that alias with a given Value when bufferized
           in-place. This method will never be called on Values that do not
           have a tensor type.
@@ -352,20 +351,20 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           %r = tensor.empty() : tensor<10xf32>
           ```
         }],
-        /*retType=*/"::mlir::bufferization::AliasingOpOperandList",
-        /*methodName=*/"getAliasingOpOperands",
-        /*args=*/(ins "::mlir::Value":$value,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
-          assert(isa<::mlir::TensorType>(value.getType()) &&
+           /*retType=*/"::mlir::bufferization::AliasingOpOperandList",
+           /*methodName=*/"getAliasingOpOperands",
+           /*args=*/
+           (ins "::mlir::Value":$value,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
+          assert(isa<::mlir::bufferization::TensorLikeType>(value.getType()) &&
                  "expected tensor type");
           return ::mlir::bufferization::detail::defaultGetAliasingOpOperands(
               value, state);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Resolve all inplacability conflicts by inserting explicit
           `bufferization.alloc_tensor` ops. Examples of inplacability conflicts
           are read-after-write conflicts or writes into non-writable buffers.
@@ -378,21 +377,22 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           This method can query analysis information from the given analysis
           state.
         }],
-        /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"resolveConflicts",
-        /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::AnalysisState &":$analysisState,
-                      "const ::mlir::bufferization::BufferizationState &":$bufferizationState),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"::llvm::LogicalResult",
+           /*methodName=*/"resolveConflicts",
+           /*args=*/
+           (ins "::mlir::RewriterBase &":$rewriter,
+               "const ::mlir::bufferization::AnalysisState &":$analysisState,
+               "const ::mlir::bufferization::BufferizationState "
+               "&":$bufferizationState),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           auto bufferizableOp =
               ::llvm::cast<BufferizableOpInterface>($_op.getOperation());
           return bufferizableOp.resolveTensorOpOperandConflicts(
               rewriter, analysisState, bufferizationState);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
           Buffers of tensor SSA values can be retrieved via `getBuffer`.
           Uses of tensor results of the existing tensor op can be replaced with
@@ -429,19 +429,19 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           suggestion to make sure IR is valid at every point in time and could
           be done differently).
         }],
-        /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"bufferize",
-        /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "::mlir::bufferization::BufferizationState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"::llvm::LogicalResult",
+           /*methodName=*/"bufferize",
+           /*args=*/
+           (ins "::mlir::RewriterBase &":$rewriter,
+               "const ::mlir::bufferization::BufferizationOptions &":$options,
+               "::mlir::bufferization::BufferizationState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           llvm_unreachable("bufferize not implemented");
           return ::mlir::failure();
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given Value can be written to in-place. Value is
           either an OpResult of this operation or a BlockArgument of a block of
           this operation.
@@ -456,17 +456,17 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           operation. This method conservatively returns `false`. This method
           will never be called on BlockArguments that do not have a tensor type.
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"isWritable",
-        /*args=*/(ins "::mlir::Value":$value,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"isWritable",
+           /*args=*/
+           (ins "::mlir::Value":$value,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           return ::llvm::isa<::mlir::OpResult>(value);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the `uRead` and `uWrite` do not constitute a RaW
           conflict. If they are conflicting or if it is unknown whether they are
           conflicting, return `false`. This method will never be called with
@@ -478,18 +478,17 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           to be conflicting and do not force out-of-place bufferization. (There
           may still be other conflicts that do.)
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"isNotConflicting",
-        /*args=*/(ins "::mlir::OpOperand *":$uRead,
-                      "::mlir::OpOperand *":$uWrite,
-                      "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"isNotConflicting",
+           /*args=*/
+           (ins "::mlir::OpOperand *":$uRead, "::mlir::OpOperand *":$uWrite,
+               "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           return false;
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `failure` if this op does not pass the analysis. This method
           is run during One-Shot Bufferize (after all post-analysis steps). If
           the op does not pass the analysis, bufferization is aborted.
@@ -497,16 +496,15 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           This method can be used to check expected invariants and limitations
           of the current bufferization implementation.
         }],
-        /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"verifyAnalysis",
-        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"::llvm::LogicalResult",
+           /*methodName=*/"verifyAnalysis",
+           /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           return ::mlir::success();
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return the bufferized type of the given tensor value (without
           bufferizing the IR). The value is either a BlockArgument of a block
           that belongs to this op or an OpResult of the given op.
@@ -525,24 +523,25 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
-        /*methodName=*/"getBufferType",
-        /*args=*/(ins "::mlir::Value":$value,
-                      "const ::mlir::bufferization::BufferizationOptions &":$options,
-                      "const ::mlir::bufferization::BufferizationState &":$state,
-                      "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/
+           "::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
+           /*methodName=*/"getBufferType",
+           /*args=*/
+           (ins "::mlir::Value":$value,
+               "const ::mlir::bufferization::BufferizationOptions &":$options,
+               "const ::mlir::bufferization::BufferizationState &":$state,
+               "::llvm::SmallVector<::mlir::Value> &":$invocationStack),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           assert(getOwnerOfValue(value) == $_op.getOperation() &&
                  "expected that value belongs to this op");
           assert(invocationStack.back() == value &&
                  "inconsistant invocation stack");
           return ::mlir::bufferization::detail::defaultGetBufferType(
               value, options, state, invocationStack);
-        }]
-      >,
-      InterfaceMethod<
-        /*desc=*/[{
+        }]>,
+       InterfaceMethod<
+           /*desc=*/[{
           Return `true` if the given region of this op is repetitive. By default
           this information is queried from the `RegionBranchOpInterface`. Ops
           that do not implement this inferface can override this method to
@@ -555,17 +554,16 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           that are executed multiple times. This is described in more detail in
           documentation of One-Shot Analysis.
         }],
-        /*retType=*/"bool",
-        /*methodName=*/"isRepetitiveRegion",
-        /*args=*/(ins "unsigned":$index),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/[{
+           /*retType=*/"bool",
+           /*methodName=*/"isRepetitiveRegion",
+           /*args=*/(ins "unsigned":$index),
+           /*methodBody=*/"",
+           /*defaultImplementation=*/[{
           return ::mlir::bufferization::detail::defaultIsRepetitiveRegion(
               ::llvm::cast<BufferizableOpInterface>($_op.getOperation()), index);
-        }]
-      >,
-      InterfaceMet...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/189073

