[Mlir-commits] [mlir] [MLIR] make One-Shot and SCF bufferization TensorLikeType-aware (PR #189073)

Dmitrii Makarenko llvmlistbot at llvm.org
Wed Apr 15 09:28:56 PDT 2026


https://github.com/Devjiu updated https://github.com/llvm/llvm-project/pull/189073

>From 5e60afb3580da6728b66fdac07ce719d819fb9bc Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
Date: Fri, 27 Mar 2026 17:42:11 +0000
Subject: [PATCH] [MLIR] make One-Shot and SCF bufferization
 TensorLikeType-aware

Fix bufferization inconsistencies between builtin tensor types and custom
TensorLikeType implementations across One-Shot analysis/module paths and SCF
bufferization interfaces.

The main issue was that several code paths checked for TensorType or
RankedTensorType where TensorLikeType-aware handling is required. This could
leave function-boundary equivalence/aliasing information incomplete for custom
tensor-like types, leading to spurious SCF loop equivalence verification
failures.

This change:
- switches relevant One-Shot analysis/module checks from TensorType/
  RankedTensorType to TensorLikeType;
- updates generic/default aliasing utilities to treat TensorLikeType
  consistently;
- updates SCF BufferizableOpInterface implementations (for/while/if/yield
  related paths) to use TensorLikeType/BufferLikeType where appropriate;
- updates test custom ops to provide required aliasing/getBufferType hooks for
  custom tensor-like types;
- refreshes and renames custom_types SCF tests to explicitly check memref
  replacement after bufferization.

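Potential follow-ups / known risk areas:
- SCF.Forall shared_outs still has RankedTensorType assumptions in signatures/
  paths and should be audited for full TensorLikeType coverage.
- SCF.For and SCF.While resolveConflicts call allocateTensorForShapedValue,
  which currently assumes ranked tensor/memref copy paths; this may still be a
  limitation for some tensor-like/unranked scenarios.
- SCF.If getBufferType: when the then/else buffer types differ, the fallback
  still does cast<TensorType> on the result type and reads the memory space
  from a BaseMemRefType, so mismatching custom buffer-like types are not
  handled on that path yet.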

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
---
 .../IR/BufferizableOpInterface.td             |   8 +-
 .../IR/BufferizableOpInterface.cpp            |  11 +-
 .../Transforms/OneShotAnalysis.cpp            |  28 +--
 .../Transforms/OneShotModuleBufferize.cpp     |  19 +-
 .../BufferizableOpInterfaceImpl.cpp           |  71 +++----
 .../Transforms/one-shot-module-bufferize.mlir | 176 ++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     |  13 ++
 mlir/test/lib/Dialect/Test/TestOps.td         |  49 ++++-
 8 files changed, 299 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index c7775f2407ebd..34aaf7432bdfd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -287,8 +287,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // Does not have to be implemented for ops without tensor OpOperands.
-          assert(::llvm::isa<::mlir::TensorType>(opOperand.get().getType()) &&
-                 "expected OpOperand with tensor type");
+          assert(::llvm::isa<::mlir::bufferization::TensorLikeType>(opOperand.get().getType()) &&
+                 "expected OpOperand with tensor like type");
           llvm_unreachable("getAliasingValues not implemented");
         }]
       >,
@@ -358,8 +358,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
                       "const ::mlir::bufferization::AnalysisState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          assert(isa<::mlir::TensorType>(value.getType()) &&
-                 "expected tensor type");
+          assert(isa<::mlir::bufferization::TensorLikeType>(value.getType()) &&
+                 "expected tensor like type");
           return ::mlir::bufferization::detail::defaultGetAliasingOpOperands(
               value, state);
         }]
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 08319ef9df79a..1696e527511a7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -508,7 +508,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
 bool AnalysisState::isValueRead(Value value) const {
-  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+  assert(llvm::isa<TensorLikeType>(value.getType()) &&
+         "expected TensorLikeType");
   SmallVector<OpOperand *> workingSet;
   DenseSet<OpOperand *> visited;
   for (OpOperand &use : value.getUses())
@@ -948,7 +949,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   Operation *op = getOwnerOfValue(value);
   SmallVector<AliasingOpOperand> result;
   for (OpOperand &opOperand : op->getOpOperands()) {
-    if (!llvm::isa<TensorType>(opOperand.get().getType()))
+    if (!llvm::isa<TensorLikeType>(opOperand.get().getType()))
       continue;
     AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
     for (const auto &it : aliasingValues)
@@ -1027,7 +1028,7 @@ bufferization::detail::unknownGetAliasingOpOperands(Value value) {
   // with every OpOperand.
   AliasingOpOperandList r;
   for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
-    if (isa<TensorType>(operand.get().getType()))
+    if (isa<TensorLikeType>(operand.get().getType()))
       r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
   return r;
 }
@@ -1040,12 +1041,12 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
   // with every OpOperand.
   AliasingValueList r;
   for (OpResult result : opOperand.getOwner()->getOpResults())
-    if (llvm::isa<TensorType>(result.getType()))
+    if (llvm::isa<TensorLikeType>(result.getType()))
       r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
   for (Region &region : opOperand.getOwner()->getRegions())
     if (!region.getBlocks().empty())
       for (BlockArgument bbArg : region.getBlocks().front().getArguments())
-        if (isa<TensorType>(bbArg.getType()))
+        if (isa<TensorLikeType>(bbArg.getType()))
           r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
   return r;
 }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index b57811868a725..a70a9a6232052 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -67,7 +67,7 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
 using namespace mlir;
 using namespace mlir::bufferization;
 
-static bool isaTensor(Type t) { return isa<TensorType>(t); }
+static bool isaTensor(Type t) { return isa<TensorLikeType>(t); }
 
 //===----------------------------------------------------------------------===//
 // Bufferization-specific attribute manipulation.
@@ -100,7 +100,7 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
   } else {
     inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
     for (OpOperand &opOperand : op->getOpOperands())
-      if (isa<TensorType>(opOperand.get().getType()))
+      if (isa<TensorLikeType>(opOperand.get().getType()))
         inPlaceVector[opOperand.getOperandNumber()] = "false";
   }
   inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
@@ -118,12 +118,12 @@ OneShotAnalysisState::OneShotAnalysisState(
   // Set up alias sets.
   op->walk([&](Operation *op) {
     for (Value v : op->getResults())
-      if (isa<TensorType>(v.getType()))
+      if (isa<TensorLikeType>(v.getType()))
         createAliasInfoEntry(v);
     for (Region &r : op->getRegions())
       for (Block &b : r.getBlocks())
         for (auto bbArg : b.getArguments())
-          if (isa<TensorType>(bbArg.getType()))
+          if (isa<TensorLikeType>(bbArg.getType()))
             createAliasInfoEntry(bbArg);
   });
 
@@ -132,7 +132,7 @@ OneShotAnalysisState::OneShotAnalysisState(
     if (!options.isOpAllowed(bufferizableOp))
       return WalkResult::skip();
     for (OpOperand &opOperand : bufferizableOp->getOpOperands())
-      if (isa<TensorType>(opOperand.get().getType()))
+      if (isa<TensorLikeType>(opOperand.get().getType()))
         if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
           bufferizeInPlace(opOperand);
     return WalkResult::advance();
@@ -195,7 +195,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
 
     // Check all tensor OpResults.
     for (OpResult opResult : op->getOpResults()) {
-      if (!isa<TensorType>(opResult.getType()))
+      if (!isa<TensorLikeType>(opResult.getType()))
         continue;
 
       // If there is no preceding definition, the tensor contents are
@@ -1001,7 +1001,7 @@ LogicalResult
 OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                       const DominanceInfo &domInfo) {
   for (OpOperand &opOperand : op->getOpOperands())
-    if (isa<TensorType>(opOperand.get().getType()))
+    if (isa<TensorLikeType>(opOperand.get().getType()))
       if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
         return failure();
   return success();
@@ -1013,7 +1013,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
   for (Operation *op : ops) {
     if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
       for (OpResult opResult : op->getOpResults()) {
-        if (!isa<TensorType>(opResult.getType()))
+        if (!isa<TensorLikeType>(opResult.getType()))
           continue;
         AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
         if (aliases.getNumAliases() == 0)
@@ -1085,7 +1085,7 @@ bottomUpFromTerminatorsHeuristic(Operation *op,
     // we stay within the same region.
     SmallVector<OpResult> worklist;
     for (Value v : term->getOperands()) {
-      if (!isa<TensorType>(v.getType()))
+      if (!isa<TensorLikeType>(v.getType()))
         continue;
       auto opResult = dyn_cast<OpResult>(v);
       if (!opResult)
@@ -1102,7 +1102,7 @@ bottomUpFromTerminatorsHeuristic(Operation *op,
       AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
       for (auto alias : aliases) {
         Value v = alias.opOperand->get();
-        if (!isa<TensorType>(v.getType()))
+        if (!isa<TensorLikeType>(v.getType()))
           continue;
         auto opResult = dyn_cast<OpResult>(v);
         if (!opResult)
@@ -1222,7 +1222,7 @@ checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
     }
 
     for (OpOperand &opOperand : op->getOpOperands()) {
-      if (isa<TensorType>(opOperand.get().getType())) {
+      if (isa<TensorLikeType>(opOperand.get().getType())) {
         if (wouldCreateReadAfterWriteInterference(
                 opOperand, domInfo, state,
                 /*checkConsistencyOnly=*/true)) {
@@ -1259,7 +1259,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
   // Add __inplace_operands_attr__.
   op->walk([&](Operation *op) {
     for (OpOperand &opOperand : op->getOpOperands())
-      if (isa<TensorType>(opOperand.get().getType()))
+      if (isa<TensorLikeType>(opOperand.get().getType()))
         setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
   });
 }
@@ -1284,7 +1284,7 @@ static void annotateOpsWithAliasSets(Operation *op,
     // Build alias set array for every OpResult.
     SmallVector<Attribute> opResultAliasSets;
     for (OpResult opResult : op->getOpResults()) {
-      if (llvm::isa<TensorType>(opResult.getType())) {
+      if (llvm::isa<TensorLikeType>(opResult.getType())) {
         opResultAliasSets.push_back(buildAliasesArray(opResult));
       }
     }
@@ -1299,7 +1299,7 @@ static void annotateOpsWithAliasSets(Operation *op,
       for (Block &block : r.getBlocks()) {
         SmallVector<Attribute> bbArgAliasSets;
         for (BlockArgument bbArg : block.getArguments()) {
-          if (llvm::isa<TensorType>(bbArg.getType())) {
+          if (llvm::isa<TensorLikeType>(bbArg.getType())) {
             bbArgAliasSets.push_back(buildAliasesArray(bbArg));
             hasTensorBbArg = true;
           }
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 6c5719ce6df8e..4d044bbb74df1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -62,6 +62,7 @@
 #include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
 #include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
@@ -122,10 +123,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
     // return value may alias with any tensor bbArg.
     FunctionType type = funcOp.getFunctionType();
     for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
-      if (!isa<TensorType>(inputIt.value()))
+      if (!isa<TensorLikeType>(inputIt.value()))
         continue;
       for (const auto &resultIt : llvm::enumerate(type.getResults())) {
-        if (!isa<TensorType>(resultIt.value()))
+        if (!isa<TensorLikeType>(resultIt.value()))
           continue;
         int64_t returnIdx = resultIt.index();
         int64_t bbArgIdx = inputIt.index();
@@ -145,13 +146,13 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
 
   // Build alias sets. Merge all aliases from all func.return ops.
   for (BlockArgument bbArg : funcOp.getArguments()) {
-    if (isa<RankedTensorType>(bbArg.getType())) {
+    if (isa<TensorLikeType>(bbArg.getType())) {
       int64_t bbArgIdx = bbArg.getArgNumber();
       // Store aliases in a set, so that we don't add the same alias twice.
       SetVector<int64_t> aliases;
       for (func::ReturnOp returnOp : returnOps) {
         for (OpOperand &returnVal : returnOp->getOpOperands()) {
-          if (isa<RankedTensorType>(returnVal.get().getType())) {
+          if (isa<TensorLikeType>(returnVal.get().getType())) {
             int64_t returnIdx = returnVal.getOperandNumber();
             if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
               aliases.insert(returnIdx);
@@ -170,10 +171,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   auto findEquivalentBlockArgIdx =
       [&](OpOperand &opOperand) -> std::optional<int64_t> {
     Value v = opOperand.get();
-    if (!isa<TensorType>(v.getType()))
+    if (!isa<TensorLikeType>(v.getType()))
       return std::nullopt;
     for (BlockArgument bbArg : funcOp.getArguments()) {
-      if (isa<RankedTensorType>(bbArg.getType())) {
+      if (isa<TensorLikeType>(bbArg.getType())) {
         if (state.areEquivalentBufferizedValues(v, bbArg)) {
           if (state.getOptions().testAnalysisOnly)
             annotateEquivalentReturnBbArg(opOperand, bbArg);
@@ -243,7 +244,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
   for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
        ++idx) {
     // Skip non-tensor arguments.
-    if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+    if (!isa<TensorLikeType>(funcOp.getFunctionType().getInput(idx)))
       continue;
     bool isRead;
     bool isWritten;
@@ -297,9 +298,9 @@ getCalledFunction(func::CallOp callOp,
 /// Return "true" if the given function signature has tensor semantics.
 static bool hasTensorSignature(func::FuncOp funcOp) {
   return llvm::any_of(funcOp.getFunctionType().getInputs(),
-                      llvm::IsaPred<TensorType>) ||
+                      llvm::IsaPred<TensorLikeType>) ||
          llvm::any_of(funcOp.getFunctionType().getResults(),
-                      llvm::IsaPred<TensorType>);
+                      llvm::IsaPred<TensorLikeType>);
 }
 
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 9b6a5a96fbc6b..32aba47f50866 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -32,11 +32,12 @@ namespace {
 /// Helper function for loop bufferization. Cast the given buffer to the given
 /// memref type.
 static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
-  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
-  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
   // If the buffer already has the correct type, no cast is needed.
   if (buffer.getType() == type)
     return buffer;
+
+  assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
+  assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
   // TODO: In case `type` has a layout map that is not the fully dynamic
   // one, we may not be able to cast the buffer. In that case, the loop
   // iter_arg's layout map must be changed (see uses of `castBuffer`).
@@ -102,7 +103,7 @@ struct ConditionOpInterface
     SmallVector<Value> newArgs;
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
       Value value = it.value();
-      if (isa<TensorType>(value.getType())) {
+      if (isa<TensorLikeType>(value.getType())) {
         FailureOr<Value> maybeBuffer =
             getBuffer(rewriter, value, options, state);
         if (failed(maybeBuffer))
@@ -247,7 +248,7 @@ struct IfOpInterface
     // Compute bufferized result types.
     SmallVector<Type> newTypes;
     for (Value result : ifOp.getResults()) {
-      if (!isa<TensorType>(result.getType())) {
+      if (!isa<TensorLikeType>(result.getType())) {
         newTypes.push_back(result.getType());
         continue;
       }
@@ -286,25 +287,23 @@ struct IfOpInterface
     auto opResult = cast<OpResult>(value);
     auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
     auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
-    BaseMemRefType thenBufferType, elseBufferType;
-    if (isa<BaseMemRefType>(thenValue.getType())) {
+    BufferLikeType thenBufferType, elseBufferType;
+    if (isa<BufferLikeType>(thenValue.getType())) {
       // True branch was already bufferized.
-      thenBufferType = cast<BaseMemRefType>(thenValue.getType());
+      thenBufferType = cast<BufferLikeType>(thenValue.getType());
     } else {
-      auto maybeBufferType =
-          bufferization::detail::asMemRefType(bufferization::getBufferType(
-              thenValue, options, state, invocationStack));
+      auto maybeBufferType = bufferization::getBufferType(
+          thenValue, options, state, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       thenBufferType = *maybeBufferType;
     }
-    if (isa<BaseMemRefType>(elseValue.getType())) {
+    if (isa<BufferLikeType>(elseValue.getType())) {
       // False branch was already bufferized.
-      elseBufferType = cast<BaseMemRefType>(elseValue.getType());
+      elseBufferType = cast<BufferLikeType>(elseValue.getType());
     } else {
-      auto maybeBufferType =
-          bufferization::detail::asMemRefType(bufferization::getBufferType(
-              elseValue, options, state, invocationStack));
+      auto maybeBufferType = bufferization::getBufferType(
+          elseValue, options, state, invocationStack);
       if (failed(maybeBufferType))
         return failure();
       elseBufferType = *maybeBufferType;
@@ -315,12 +314,17 @@ struct IfOpInterface
       return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
-    if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
+    auto thenBaseMemRefType = dyn_cast<BaseMemRefType>(thenBufferType);
+    auto elseBaseMemRefType = dyn_cast<BaseMemRefType>(elseBufferType);
+    if (thenBaseMemRefType && elseBaseMemRefType &&
+        thenBaseMemRefType.getMemorySpace() !=
+            elseBaseMemRefType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
     return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
+        cast<TensorType>(opResult.getType()),
+        thenBaseMemRefType.getMemorySpace()));
   }
 };
 
@@ -444,7 +448,7 @@ struct IndexSwitchOpInterface
 static DenseSet<int64_t> getTensorIndices(ValueRange values) {
   DenseSet<int64_t> result;
   for (const auto &it : llvm::enumerate(values))
-    if (isa<TensorType>(it.value().getType()))
+    if (isa<TensorLikeType>(it.value().getType()))
       result.insert(it.index());
   return result;
 }
@@ -457,8 +461,8 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
   unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
   DenseSet<int64_t> result;
   for (unsigned int i = 0; i < minSize; ++i) {
-    if (!isa<TensorType>(bbArgs[i].getType()) ||
-        !isa<TensorType>(yieldedValues[i].getType()))
+    if (!isa<TensorLikeType>(bbArgs[i].getType()) ||
+        !isa<TensorLikeType>(yieldedValues[i].getType()))
       continue;
     if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
       result.insert(i);
@@ -473,7 +477,7 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
            const BufferizationOptions &options, BufferizationState &state) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
-    if (isa<TensorType>(opOperand.get().getType())) {
+    if (isa<TensorLikeType>(opOperand.get().getType())) {
       FailureOr<Value> resultBuffer =
           getBuffer(rewriter, opOperand.get(), options, state);
       if (failed(resultBuffer))
@@ -547,7 +551,7 @@ static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
 
   // Compute the buffer type of the yielded value.
   BufferLikeType yieldedValueBufferType;
-  if (isa<BaseMemRefType>(yieldedValue.getType())) {
+  if (isa<BufferLikeType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
     yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
@@ -712,7 +716,7 @@ struct ForOpInterface
                 SmallVector<Value> &invocationStack) const {
     auto forOp = cast<scf::ForOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
-    assert(isa<TensorType>(value.getType()) && "expected tensor type");
+    assert(isa<TensorLikeType>(value.getType()) && "expected tensor type");
 
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
@@ -757,7 +761,7 @@ struct ForOpInterface
       Value initArg = it.value();
       Value result = forOp->getResult(it.index());
       // If the type is not a tensor, bufferization doesn't need to touch it.
-      if (!isa<TensorType>(result.getType())) {
+      if (!isa<TensorLikeType>(result.getType())) {
         castedInitArgs.push_back(initArg);
         continue;
       }
@@ -809,9 +813,8 @@ struct ForOpInterface
     auto forOp = cast<scf::ForOp>(op);
     auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     for (OpResult opResult : op->getOpResults()) {
-      if (!isa<TensorType>(opResult.getType()))
+      if (!isa<TensorLikeType>(opResult.getType()))
         continue;
-
       // Note: This is overly strict. We should check for aliasing bufferized
       // values. But we don't have a "must-alias" analysis yet.
       if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
@@ -938,7 +941,7 @@ struct WhileOpInterface
     for (int64_t idx = 0;
          idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
       Value value = conditionOp.getArgs()[idx];
-      if (!isa<TensorType>(value.getType()) ||
+      if (!isa<TensorLikeType>(value.getType()) ||
           (equivalentYieldsAfter.contains(idx) &&
            equivalentYieldsBefore.contains(idx))) {
         beforeYieldValues.push_back(value);
@@ -982,7 +985,7 @@ struct WhileOpInterface
       Value initArg = it.value();
       Value beforeArg = whileOp.getBeforeArguments()[it.index()];
       // If the type is not a tensor, bufferization doesn't need to touch it.
-      if (!isa<TensorType>(beforeArg.getType())) {
+      if (!isa<TensorLikeType>(beforeArg.getType())) {
         castedInitArgs.push_back(initArg);
         continue;
       }
@@ -995,7 +998,7 @@ struct WhileOpInterface
     // The result types of a WhileOp are the same as the "after" bbArg types.
     SmallVector<Type> argsTypesAfter = llvm::map_to_vector(
         whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
-          if (!isa<TensorType>(bbArg.getType()))
+          if (!isa<TensorLikeType>(bbArg.getType()))
             return bbArg.getType();
           // TODO: error handling
           return llvm::cast<Type>(
@@ -1048,7 +1051,7 @@ struct WhileOpInterface
                 SmallVector<Value> &invocationStack) const {
     auto whileOp = cast<scf::WhileOp>(op);
     assert(getOwnerOfValue(value) == op && "invalid value");
-    assert(isa<TensorType>(value.getType()) && "expected tensor type");
+    assert(isa<TensorLikeType>(value.getType()) && "expected tensor type");
 
     // Case 1: Block argument of the "before" region.
     if (auto bbArg = dyn_cast<BlockArgument>(value)) {
@@ -1074,7 +1077,7 @@ struct WhileOpInterface
       llvm_unreachable("invalid value");
     }
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
-    if (!isa<TensorType>(conditionYieldedVal.getType())) {
+    if (!isa<TensorLikeType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
       return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
@@ -1103,7 +1106,7 @@ struct WhileOpInterface
     auto conditionOp = whileOp.getConditionOp();
     for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
       Block *block = conditionOp->getBlock();
-      if (!isa<TensorType>(it.value().getType()))
+      if (!isa<TensorLikeType>(it.value().getType()))
         continue;
       if (it.index() >= block->getNumArguments() ||
           !state.areEquivalentBufferizedValues(it.value(),
@@ -1116,7 +1119,7 @@ struct WhileOpInterface
     auto yieldOp = whileOp.getYieldOp();
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Block *block = yieldOp->getBlock();
-      if (!isa<TensorType>(it.value().getType()))
+      if (!isa<TensorLikeType>(it.value().getType()))
         continue;
       if (it.index() >= block->getNumArguments() ||
           !state.areEquivalentBufferizedValues(it.value(),
@@ -1176,7 +1179,7 @@ struct YieldOpInterface
     SmallVector<Value> newResults;
     for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
       Value value = it.value();
-      if (isa<TensorType>(value.getType())) {
+      if (isa<TensorLikeType>(value.getType())) {
         FailureOr<Value> maybeBuffer =
             getBuffer(rewriter, value, options, state);
         if (failed(maybeBuffer))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d5cb7a0f14f5a..f8d5a1310ebdf 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -905,3 +905,179 @@ func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tenso
   return %b : tensor<64x20x40xf32>
 }
 func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @custom_types_scf_for_inplace(
+// CHECK-SAME:    %[[arg:.+]]: !test.test_memref<[4, 4], f64>,
+// CHECK-SAME:    %[[lb:.+]]: index, %[[ub:.+]]: index, %[[step:.+]]: index
+// CHECK-SAME:  ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_for_inplace(
+    %arg: !test.test_tensor<[4, 4], f64>,
+    %lb: index, %ub: index, %step: index)
+    -> !test.test_tensor<[4, 4], f64> {
+  // CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]]
+  // CHECK-SAME: iter_args(%[[iter:.+]] = %[[arg]]) -> (!test.test_memref<[4, 4], f64>) {
+  // CHECK: %[[call:.+]] = "test.dummy_memref_op"(%[[iter]])
+  // CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64>
+  %loop = scf.for %i = %lb to %ub step %step
+      iter_args(%iter = %arg) -> (!test.test_tensor<[4, 4], f64>) {
+    // Inside the loop: apply the op directly to the iter_arg, with no copy.
+    %call = "test.dummy_tensor_op"(%iter) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+    // Yield the result of the op that was applied to the iter_arg.
+    scf.yield %call : !test.test_tensor<[4, 4], f64>
+  }
+
+  // CHECK: return %[[loop]] : !test.test_memref<[4, 4], f64>
+  return %loop : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// Same as @custom_types_scf_for_inplace, but with an inner call to test alias analysis
+// through function boundaries.
+// CHECK-LABEL: func.func @custom_types_scf_for_inplace_with_call(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: %[[lb:.+]]: index, %[[ub:.+]]: index, %[[step:.+]]: index
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+// CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[iter:.+]] = %[[arg]]) -> (!test.test_memref<[4, 4], f64>) {
+// CHECK: %[[call:.+]] = func.call @custom_types_identity_2d(%[[iter]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64>
+// CHECK: return %[[loop]] : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_for_inplace_with_call(
+    %arg: !test.test_tensor<[4, 4], f64>,
+    %lb: index, %ub: index, %step: index)
+    -> !test.test_tensor<[4, 4], f64> {
+  %loop = scf.for %i = %lb to %ub step %step
+      iter_args(%iter = %arg) -> (!test.test_tensor<[4, 4], f64>) {
+    %call = func.call @custom_types_identity_2d(%iter)
+      : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64>
+    scf.yield %call : !test.test_tensor<[4, 4], f64>
+  }
+
+  return %loop : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @custom_types_scf_if_inplace(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: %[[cond:.+]]: i1
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+// CHECK: %[[res:.+]] = scf.if %[[cond]] -> (!test.test_memref<[4, 4], f64>) {
+// CHECK: %[[dummy:.+]] = "test.dummy_memref_op"(%[[arg]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield %[[dummy]] : !test.test_memref<[4, 4], f64>
+// CHECK: } else {
+// CHECK: scf.yield %[[arg]] : !test.test_memref<[4, 4], f64>
+// CHECK: }
+// CHECK: return %[[res]] : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_if_inplace(
+    %arg: !test.test_tensor<[4, 4], f64>,
+    %cond: i1)
+    -> !test.test_tensor<[4, 4], f64> {
+  %res = scf.if %cond -> (!test.test_tensor<[4, 4], f64>) {
+    %dummy = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+    scf.yield %dummy : !test.test_tensor<[4, 4], f64>
+  } else {
+    scf.yield %arg : !test.test_tensor<[4, 4], f64>
+  }
+  return %res : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK-LABEL: func.func @custom_types_scf_if_inplace_with_call(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: %[[cond:.+]]: i1
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+// CHECK: %[[res:.+]] = scf.if %[[cond]] -> (!test.test_memref<[4, 4], f64>) {
+// CHECK: %[[call:.+]] = func.call @custom_types_identity_2d(%[[arg]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64>
+// CHECK: } else {
+// CHECK: scf.yield %[[arg]] : !test.test_memref<[4, 4], f64>
+// CHECK: }
+// CHECK: return %[[res]] : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_if_inplace_with_call(
+    %arg: !test.test_tensor<[4, 4], f64>,
+    %cond: i1)
+    -> !test.test_tensor<[4, 4], f64> {
+  %res = scf.if %cond -> (!test.test_tensor<[4, 4], f64>) {
+    %call = func.call @custom_types_identity_2d(%arg)
+      : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64>
+    scf.yield %call : !test.test_tensor<[4, 4], f64>
+  } else {
+    scf.yield %arg : !test.test_tensor<[4, 4], f64>
+  }
+  return %res : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @scf_while_inplace(
+// CHECK-SAME: !test.test_memref<[4, 4], f64>
+// CHECK: scf.while
+// CHECK: scf.condition
+// CHECK: scf.yield
+// CHECK: return
+func.func @scf_while_inplace(
+    %arg: !test.test_tensor<[4, 4], f64>,
+    %cond: i1)
+    -> !test.test_tensor<[4, 4], f64> {
+  %loop = scf.while (%iter = %arg)
+      : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> {
+    scf.condition(%cond) %iter : !test.test_tensor<[4, 4], f64>
+  } do {
+  ^bb0(%current: !test.test_tensor<[4, 4], f64>):
+    %dummy = "test.dummy_tensor_op"(%current) : (!test.test_tensor<[4, 4], f64>)
+      -> !test.test_tensor<[4, 4], f64>
+    scf.yield %dummy : !test.test_tensor<[4, 4], f64>
+  }
+  return %loop : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64> {
+  %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+    -> !test.test_tensor<[4, 4], f64>
+  return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK-LABEL: func.func @scf_while_inplace_with_call(
+// CHECK-SAME: !test.test_memref<[4, 4], f64>
+// CHECK: scf.while
+// CHECK: scf.condition
+// CHECK: scf.yield
+// CHECK: return
+func.func @scf_while_inplace_with_call(
+    %arg: !test.test_tensor<[4, 4], f64>,
+    %cond: i1)
+    -> !test.test_tensor<[4, 4], f64> {
+  %loop = scf.while (%iter = %arg)
+      : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> {
+    scf.condition(%cond) %iter : !test.test_tensor<[4, 4], f64>
+  } do {
+  ^bb0(%current: !test.test_tensor<[4, 4], f64>):
+    %call = func.call @custom_types_identity_2d(%current)
+      : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64>
+    scf.yield %call : !test.test_tensor<[4, 4], f64>
+  }
+  return %loop : !test.test_tensor<[4, 4], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 5fee060689d24..340b44b14dd96 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1796,6 +1796,19 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
   return mlir::success();
 }
 
+mlir::FailureOr<mlir::bufferization::BufferLikeType>
+test::TestDummyTensorOp::getBufferType(
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    const mlir::bufferization::BufferizationState &,
+    llvm::SmallVector<::mlir::Value> &) {
+  // TestTensorType -> TestMemrefType with identical shape and element type.
+  if (auto tensorType = dyn_cast<test::TestTensorType>(value.getType()))
+    return cast<mlir::bufferization::BufferLikeType>(
+        test::TestMemrefType::get(getContext(), tensorType.getShape(),
+                                  tensorType.getElementType(), nullptr));
+  return failure();
+}
+
 ::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
     ::mlir::RewriterBase &rewriter,
     const ::mlir::bufferization::BufferizationOptions &options,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 774329b9d2736..348ff5d7f4ea0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3936,10 +3936,13 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
 // Test Ops bufferization
 //===----------------------------------------------------------------------===//
 
-def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
-    [DeclareOpInterfaceMethods<BufferizableOpInterface,
-        ["bufferize", "bufferizesToMemoryRead",
-         "bufferizesToMemoryWrite", "getAliasingValues"]>]> {
+def TestDummyTensorOp
+    : TEST_Op<"dummy_tensor_op",
+              [DeclareOpInterfaceMethods<
+                  BufferizableOpInterface,
+                  ["bufferize", "getBufferType", "bufferizesToMemoryRead",
+                   "bufferizesToMemoryWrite", "getAliasingValues",
+                   "getAliasingOpOperands"]>]> {
   let arguments = (ins
     Arg<Bufferization_TensorLikeTypeInterface>:$input
   );
@@ -3959,7 +3962,23 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
     ::mlir::bufferization::AliasingValueList
     test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&,
         const ::mlir::bufferization::AnalysisState&) {
-      return {};
+      auto relation = getInput().getType() == getOutput().getType()
+                          ? ::mlir::bufferization::BufferRelation::Equivalent
+                          : ::mlir::bufferization::BufferRelation::Unknown;
+      return {{getOutput(), relation, /*isDefinite=*/true}};
+    }
+
+    ::mlir::bufferization::AliasingOpOperandList
+    test::TestDummyTensorOp::getAliasingOpOperands(::mlir::Value value,
+        const ::mlir::bufferization::AnalysisState&) {
+      // Only the op result aliases the (single) input operand.
+      if (value != getOutput())
+        return {};
+      using ::mlir::bufferization::BufferRelation;
+      const bool sameType = getInput().getType() == getOutput().getType();
+      return {{&getOperation()->getOpOperand(0),
+               sameType ? BufferRelation::Equivalent : BufferRelation::Unknown,
+               /*isDefinite=*/true}};
+    }
   }];
 }
@@ -3973,11 +3992,13 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
   );
 }
 
-def TestCreateTensorOp : TEST_Op<"create_tensor_op",
-    [DeclareOpInterfaceMethods<BufferizableOpInterface,
-        ["bufferize", "getBufferType", "bufferizesToMemoryRead",
-         "bufferizesToMemoryWrite", "getAliasingValues",
-         "bufferizesToAllocation"]>]> {
+def TestCreateTensorOp
+    : TEST_Op<"create_tensor_op",
+              [DeclareOpInterfaceMethods<
+                  BufferizableOpInterface,
+                  ["bufferize", "getBufferType", "bufferizesToMemoryRead",
+                   "bufferizesToMemoryWrite", "getAliasingValues",
+                   "getAliasingOpOperands", "bufferizesToAllocation"]>]> {
   let arguments = (ins);
   let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
   let extraClassDefinition = [{
@@ -3998,6 +4019,14 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
         const ::mlir::bufferization::AnalysisState&) {
       return {};
     }
+
+    // This op has no operands, so no OpOperand can alias its result.
+    ::mlir::bufferization::AliasingOpOperandList
+    test::TestCreateTensorOp::getAliasingOpOperands(
+        ::mlir::Value,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
   }];
 }
 



More information about the Mlir-commits mailing list