[Mlir-commits] [mlir] [MLIR] make One-Shot and SCF bufferization TensorLikeType-aware (PR #189073)
Dmitrii Makarenko
llvmlistbot at llvm.org
Wed Apr 15 09:28:56 PDT 2026
https://github.com/Devjiu updated https://github.com/llvm/llvm-project/pull/189073
>From 5e60afb3580da6728b66fdac07ce719d819fb9bc Mon Sep 17 00:00:00 2001
From: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
Date: Fri, 27 Mar 2026 17:42:11 +0000
Subject: [PATCH] [MLIR] make One-Shot and SCF bufferization
TensorLikeType-aware
Fix bufferization inconsistencies between builtin tensor types and custom
TensorLikeType implementations across One-Shot analysis/module paths and SCF
bufferization interfaces.
The main issue was a mix of TensorType/RankedTensorType checks in places that
need TensorLikeType-aware handling. This could leave function-boundary
equivalence/aliasing incomplete for custom tensor-like types, leading to
spurious SCF loop equivalence verification failures.
This change:
- switches relevant One-Shot analysis/module checks from TensorType/
RankedTensorType to TensorLikeType;
- updates generic/default aliasing utilities to treat TensorLikeType
consistently;
- updates SCF BufferizableOpInterface implementations (for/while/if/yield
related paths) to use TensorLikeType/BufferLikeType where appropriate;
- updates test custom ops to provide required aliasing/getBufferType hooks for
custom tensor-like types;
- adds custom tensor-like type SCF tests (scf.for, scf.if and scf.while, each
with and without a nested func.call) that explicitly check memref-like type
replacement after bufferization.
Potential follow-ups / known risk areas:
- scf.forall shared_outs still has RankedTensorType assumptions in signatures
and code paths and should be audited for full TensorLikeType coverage.
- The scf.for and scf.while resolveConflicts implementations call
allocateTensorForShapedValue, which currently assumes ranked tensor/memref copy
paths; this may still be a limitation for some tensor-like/unranked scenarios.
Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko at intel.com>
---
.../IR/BufferizableOpInterface.td | 8 +-
.../IR/BufferizableOpInterface.cpp | 11 +-
.../Transforms/OneShotAnalysis.cpp | 28 +--
.../Transforms/OneShotModuleBufferize.cpp | 19 +-
.../BufferizableOpInterfaceImpl.cpp | 71 +++----
.../Transforms/one-shot-module-bufferize.mlir | 176 ++++++++++++++++++
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 13 ++
mlir/test/lib/Dialect/Test/TestOps.td | 49 ++++-
8 files changed, 299 insertions(+), 76 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index c7775f2407ebd..34aaf7432bdfd 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -287,8 +287,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/[{
// Does not have to be implemented for ops without tensor OpOperands.
- assert(::llvm::isa<::mlir::TensorType>(opOperand.get().getType()) &&
- "expected OpOperand with tensor type");
+ assert(::llvm::isa<::mlir::bufferization::TensorLikeType>(opOperand.get().getType()) &&
+ "expected OpOperand with tensor like type");
llvm_unreachable("getAliasingValues not implemented");
}]
>,
@@ -358,8 +358,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
"const ::mlir::bufferization::AnalysisState &":$state),
/*methodBody=*/"",
/*defaultImplementation=*/[{
- assert(isa<::mlir::TensorType>(value.getType()) &&
- "expected tensor type");
+ assert(isa<::mlir::bufferization::TensorLikeType>(value.getType()) &&
+ "expected tensor like type");
return ::mlir::bufferization::detail::defaultGetAliasingOpOperands(
value, state);
}]
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 08319ef9df79a..1696e527511a7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -508,7 +508,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
- assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+ assert(llvm::isa<TensorLikeType>(value.getType()) &&
+ "expected TensorLikeType");
SmallVector<OpOperand *> workingSet;
DenseSet<OpOperand *> visited;
for (OpOperand &use : value.getUses())
@@ -948,7 +949,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
Operation *op = getOwnerOfValue(value);
SmallVector<AliasingOpOperand> result;
for (OpOperand &opOperand : op->getOpOperands()) {
- if (!llvm::isa<TensorType>(opOperand.get().getType()))
+ if (!llvm::isa<TensorLikeType>(opOperand.get().getType()))
continue;
AliasingValueList aliasingValues = state.getAliasingValues(opOperand);
for (const auto &it : aliasingValues)
@@ -1027,7 +1028,7 @@ bufferization::detail::unknownGetAliasingOpOperands(Value value) {
// with every OpOperand.
AliasingOpOperandList r;
for (OpOperand &operand : value.getDefiningOp()->getOpOperands())
- if (isa<TensorType>(operand.get().getType()))
+ if (isa<TensorLikeType>(operand.get().getType()))
r.addAlias({&operand, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
@@ -1040,12 +1041,12 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
// with every OpOperand.
AliasingValueList r;
for (OpResult result : opOperand.getOwner()->getOpResults())
- if (llvm::isa<TensorType>(result.getType()))
+ if (llvm::isa<TensorLikeType>(result.getType()))
r.addAlias({result, BufferRelation::Unknown, /*isDefinite=*/false});
for (Region &region : opOperand.getOwner()->getRegions())
if (!region.getBlocks().empty())
for (BlockArgument bbArg : region.getBlocks().front().getArguments())
- if (isa<TensorType>(bbArg.getType()))
+ if (isa<TensorLikeType>(bbArg.getType()))
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
return r;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index b57811868a725..a70a9a6232052 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -67,7 +67,7 @@ MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
using namespace mlir;
using namespace mlir::bufferization;
-static bool isaTensor(Type t) { return isa<TensorType>(t); }
+static bool isaTensor(Type t) { return isa<TensorLikeType>(t); }
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
@@ -100,7 +100,7 @@ static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
} else {
inPlaceVector = SmallVector<StringRef>(op->getNumOperands(), "none");
for (OpOperand &opOperand : op->getOpOperands())
- if (isa<TensorType>(opOperand.get().getType()))
+ if (isa<TensorLikeType>(opOperand.get().getType()))
inPlaceVector[opOperand.getOperandNumber()] = "false";
}
inPlaceVector[opOperand.getOperandNumber()] = inPlace ? "true" : "false";
@@ -118,12 +118,12 @@ OneShotAnalysisState::OneShotAnalysisState(
// Set up alias sets.
op->walk([&](Operation *op) {
for (Value v : op->getResults())
- if (isa<TensorType>(v.getType()))
+ if (isa<TensorLikeType>(v.getType()))
createAliasInfoEntry(v);
for (Region &r : op->getRegions())
for (Block &b : r.getBlocks())
for (auto bbArg : b.getArguments())
- if (isa<TensorType>(bbArg.getType()))
+ if (isa<TensorLikeType>(bbArg.getType()))
createAliasInfoEntry(bbArg);
});
@@ -132,7 +132,7 @@ OneShotAnalysisState::OneShotAnalysisState(
if (!options.isOpAllowed(bufferizableOp))
return WalkResult::skip();
for (OpOperand &opOperand : bufferizableOp->getOpOperands())
- if (isa<TensorType>(opOperand.get().getType()))
+ if (isa<TensorLikeType>(opOperand.get().getType()))
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
bufferizeInPlace(opOperand);
return WalkResult::advance();
@@ -195,7 +195,7 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// Check all tensor OpResults.
for (OpResult opResult : op->getOpResults()) {
- if (!isa<TensorType>(opResult.getType()))
+ if (!isa<TensorLikeType>(opResult.getType()))
continue;
// If there is no preceding definition, the tensor contents are
@@ -1001,7 +1001,7 @@ LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
const DominanceInfo &domInfo) {
for (OpOperand &opOperand : op->getOpOperands())
- if (isa<TensorType>(opOperand.get().getType()))
+ if (isa<TensorLikeType>(opOperand.get().getType()))
if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
return failure();
return success();
@@ -1013,7 +1013,7 @@ static void equivalenceAnalysis(SmallVector<Operation *> &ops,
for (Operation *op : ops) {
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
for (OpResult opResult : op->getOpResults()) {
- if (!isa<TensorType>(opResult.getType()))
+ if (!isa<TensorLikeType>(opResult.getType()))
continue;
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
if (aliases.getNumAliases() == 0)
@@ -1085,7 +1085,7 @@ bottomUpFromTerminatorsHeuristic(Operation *op,
// we stay within the same region.
SmallVector<OpResult> worklist;
for (Value v : term->getOperands()) {
- if (!isa<TensorType>(v.getType()))
+ if (!isa<TensorLikeType>(v.getType()))
continue;
auto opResult = dyn_cast<OpResult>(v);
if (!opResult)
@@ -1102,7 +1102,7 @@ bottomUpFromTerminatorsHeuristic(Operation *op,
AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
for (auto alias : aliases) {
Value v = alias.opOperand->get();
- if (!isa<TensorType>(v.getType()))
+ if (!isa<TensorLikeType>(v.getType()))
continue;
auto opResult = dyn_cast<OpResult>(v);
if (!opResult)
@@ -1222,7 +1222,7 @@ checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
}
for (OpOperand &opOperand : op->getOpOperands()) {
- if (isa<TensorType>(opOperand.get().getType())) {
+ if (isa<TensorLikeType>(opOperand.get().getType())) {
if (wouldCreateReadAfterWriteInterference(
opOperand, domInfo, state,
/*checkConsistencyOnly=*/true)) {
@@ -1259,7 +1259,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
// Add __inplace_operands_attr__.
op->walk([&](Operation *op) {
for (OpOperand &opOperand : op->getOpOperands())
- if (isa<TensorType>(opOperand.get().getType()))
+ if (isa<TensorLikeType>(opOperand.get().getType()))
setInPlaceOpOperand(opOperand, state.isInPlace(opOperand));
});
}
@@ -1284,7 +1284,7 @@ static void annotateOpsWithAliasSets(Operation *op,
// Build alias set array for every OpResult.
SmallVector<Attribute> opResultAliasSets;
for (OpResult opResult : op->getOpResults()) {
- if (llvm::isa<TensorType>(opResult.getType())) {
+ if (llvm::isa<TensorLikeType>(opResult.getType())) {
opResultAliasSets.push_back(buildAliasesArray(opResult));
}
}
@@ -1299,7 +1299,7 @@ static void annotateOpsWithAliasSets(Operation *op,
for (Block &block : r.getBlocks()) {
SmallVector<Attribute> bbArgAliasSets;
for (BlockArgument bbArg : block.getArguments()) {
- if (llvm::isa<TensorType>(bbArg.getType())) {
+ if (llvm::isa<TensorLikeType>(bbArg.getType())) {
bbArgAliasSets.push_back(buildAliasesArray(bbArg));
hasTensorBbArg = true;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index 6c5719ce6df8e..4d044bbb74df1 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -62,6 +62,7 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
@@ -122,10 +123,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// return value may alias with any tensor bbArg.
FunctionType type = funcOp.getFunctionType();
for (const auto &inputIt : llvm::enumerate(type.getInputs())) {
- if (!isa<TensorType>(inputIt.value()))
+ if (!isa<TensorLikeType>(inputIt.value()))
continue;
for (const auto &resultIt : llvm::enumerate(type.getResults())) {
- if (!isa<TensorType>(resultIt.value()))
+ if (!isa<TensorLikeType>(resultIt.value()))
continue;
int64_t returnIdx = resultIt.index();
int64_t bbArgIdx = inputIt.index();
@@ -145,13 +146,13 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
// Build alias sets. Merge all aliases from all func.return ops.
for (BlockArgument bbArg : funcOp.getArguments()) {
- if (isa<RankedTensorType>(bbArg.getType())) {
+ if (isa<TensorLikeType>(bbArg.getType())) {
int64_t bbArgIdx = bbArg.getArgNumber();
// Store aliases in a set, so that we don't add the same alias twice.
SetVector<int64_t> aliases;
for (func::ReturnOp returnOp : returnOps) {
for (OpOperand &returnVal : returnOp->getOpOperands()) {
- if (isa<RankedTensorType>(returnVal.get().getType())) {
+ if (isa<TensorLikeType>(returnVal.get().getType())) {
int64_t returnIdx = returnVal.getOperandNumber();
if (state.areAliasingBufferizedValues(returnVal.get(), bbArg))
aliases.insert(returnIdx);
@@ -170,10 +171,10 @@ aliasingFuncOpBBArgsAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
auto findEquivalentBlockArgIdx =
[&](OpOperand &opOperand) -> std::optional<int64_t> {
Value v = opOperand.get();
- if (!isa<TensorType>(v.getType()))
+ if (!isa<TensorLikeType>(v.getType()))
return std::nullopt;
for (BlockArgument bbArg : funcOp.getArguments()) {
- if (isa<RankedTensorType>(bbArg.getType())) {
+ if (isa<TensorLikeType>(bbArg.getType())) {
if (state.areEquivalentBufferizedValues(v, bbArg)) {
if (state.getOptions().testAnalysisOnly)
annotateEquivalentReturnBbArg(opOperand, bbArg);
@@ -243,7 +244,7 @@ funcOpBbArgReadWriteAnalysis(FuncOp funcOp, OneShotAnalysisState &state,
for (int64_t idx = 0, e = funcOp.getFunctionType().getNumInputs(); idx < e;
++idx) {
// Skip non-tensor arguments.
- if (!isa<TensorType>(funcOp.getFunctionType().getInput(idx)))
+ if (!isa<TensorLikeType>(funcOp.getFunctionType().getInput(idx)))
continue;
bool isRead;
bool isWritten;
@@ -297,9 +298,9 @@ getCalledFunction(func::CallOp callOp,
/// Return "true" if the given function signature has tensor semantics.
static bool hasTensorSignature(func::FuncOp funcOp) {
return llvm::any_of(funcOp.getFunctionType().getInputs(),
- llvm::IsaPred<TensorType>) ||
+ llvm::IsaPred<TensorLikeType>) ||
llvm::any_of(funcOp.getFunctionType().getResults(),
- llvm::IsaPred<TensorType>);
+ llvm::IsaPred<TensorLikeType>);
}
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 9b6a5a96fbc6b..32aba47f50866 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -32,11 +32,12 @@ namespace {
/// Helper function for loop bufferization. Cast the given buffer to the given
/// memref type.
static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
- assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
- assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
// If the buffer already has the correct type, no cast is needed.
if (buffer.getType() == type)
return buffer;
+
+ assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType");
+ assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType");
// TODO: In case `type` has a layout map that is not the fully dynamic
// one, we may not be able to cast the buffer. In that case, the loop
// iter_arg's layout map must be changed (see uses of `castBuffer`).
@@ -102,7 +103,7 @@ struct ConditionOpInterface
SmallVector<Value> newArgs;
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Value value = it.value();
- if (isa<TensorType>(value.getType())) {
+ if (isa<TensorLikeType>(value.getType())) {
FailureOr<Value> maybeBuffer =
getBuffer(rewriter, value, options, state);
if (failed(maybeBuffer))
@@ -247,7 +248,7 @@ struct IfOpInterface
// Compute bufferized result types.
SmallVector<Type> newTypes;
for (Value result : ifOp.getResults()) {
- if (!isa<TensorType>(result.getType())) {
+ if (!isa<TensorLikeType>(result.getType())) {
newTypes.push_back(result.getType());
continue;
}
@@ -286,25 +287,23 @@ struct IfOpInterface
auto opResult = cast<OpResult>(value);
auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber());
auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber());
- BaseMemRefType thenBufferType, elseBufferType;
- if (isa<BaseMemRefType>(thenValue.getType())) {
+ BufferLikeType thenBufferType, elseBufferType;
+ if (isa<BufferLikeType>(thenValue.getType())) {
// True branch was already bufferized.
- thenBufferType = cast<BaseMemRefType>(thenValue.getType());
+ thenBufferType = cast<BufferLikeType>(thenValue.getType());
} else {
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- thenValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(
+ thenValue, options, state, invocationStack);
if (failed(maybeBufferType))
return failure();
thenBufferType = *maybeBufferType;
}
- if (isa<BaseMemRefType>(elseValue.getType())) {
+ if (isa<BufferLikeType>(elseValue.getType())) {
// False branch was already bufferized.
- elseBufferType = cast<BaseMemRefType>(elseValue.getType());
+ elseBufferType = cast<BufferLikeType>(elseValue.getType());
} else {
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- elseValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(
+ elseValue, options, state, invocationStack);
if (failed(maybeBufferType))
return failure();
elseBufferType = *maybeBufferType;
@@ -315,12 +314,17 @@ struct IfOpInterface
return cast<BufferLikeType>(thenBufferType);
// Memory space mismatch.
- if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
+ auto thenMemRefType = dyn_cast<BaseMemRefType>(thenBufferType);
+ auto elseMemRefType = dyn_cast<BaseMemRefType>(elseBufferType);
+ if (!thenMemRefType || !elseMemRefType) // distinct non-memref buffer types
+ return op->emitError("inconsistent buffer types on then/else branches");
+ if (thenMemRefType.getMemorySpace() != elseMemRefType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
+ cast<TensorType>(opResult.getType()),
+ thenMemRefType.getMemorySpace()));
}
};
@@ -444,7 +448,7 @@ struct IndexSwitchOpInterface
static DenseSet<int64_t> getTensorIndices(ValueRange values) {
DenseSet<int64_t> result;
for (const auto &it : llvm::enumerate(values))
- if (isa<TensorType>(it.value().getType()))
+ if (isa<TensorLikeType>(it.value().getType()))
result.insert(it.index());
return result;
}
@@ -457,8 +461,8 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size());
DenseSet<int64_t> result;
for (unsigned int i = 0; i < minSize; ++i) {
- if (!isa<TensorType>(bbArgs[i].getType()) ||
- !isa<TensorType>(yieldedValues[i].getType()))
+ if (!isa<TensorLikeType>(bbArgs[i].getType()) ||
+ !isa<TensorLikeType>(yieldedValues[i].getType()))
continue;
if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i]))
result.insert(i);
@@ -473,7 +477,7 @@ getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands,
const BufferizationOptions &options, BufferizationState &state) {
SmallVector<Value> result;
for (OpOperand &opOperand : operands) {
- if (isa<TensorType>(opOperand.get().getType())) {
+ if (isa<TensorLikeType>(opOperand.get().getType())) {
FailureOr<Value> resultBuffer =
getBuffer(rewriter, opOperand.get(), options, state);
if (failed(resultBuffer))
@@ -547,7 +551,7 @@ static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
// Compute the buffer type of the yielded value.
BufferLikeType yieldedValueBufferType;
- if (isa<BaseMemRefType>(yieldedValue.getType())) {
+ if (isa<BufferLikeType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
@@ -712,7 +716,7 @@ struct ForOpInterface
SmallVector<Value> &invocationStack) const {
auto forOp = cast<scf::ForOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(isa<TensorType>(value.getType()) && "expected tensor type");
+ assert(isa<TensorLikeType>(value.getType()) && "expected tensor type");
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
@@ -757,7 +761,7 @@ struct ForOpInterface
Value initArg = it.value();
Value result = forOp->getResult(it.index());
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!isa<TensorType>(result.getType())) {
+ if (!isa<TensorLikeType>(result.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -809,9 +813,8 @@ struct ForOpInterface
auto forOp = cast<scf::ForOp>(op);
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
for (OpResult opResult : op->getOpResults()) {
- if (!isa<TensorType>(opResult.getType()))
+ if (!isa<TensorLikeType>(opResult.getType()))
continue;
-
// Note: This is overly strict. We should check for aliasing bufferized
// values. But we don't have a "must-alias" analysis yet.
if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
@@ -938,7 +941,7 @@ struct WhileOpInterface
for (int64_t idx = 0;
idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) {
Value value = conditionOp.getArgs()[idx];
- if (!isa<TensorType>(value.getType()) ||
+ if (!isa<TensorLikeType>(value.getType()) ||
(equivalentYieldsAfter.contains(idx) &&
equivalentYieldsBefore.contains(idx))) {
beforeYieldValues.push_back(value);
@@ -982,7 +985,7 @@ struct WhileOpInterface
Value initArg = it.value();
Value beforeArg = whileOp.getBeforeArguments()[it.index()];
// If the type is not a tensor, bufferization doesn't need to touch it.
- if (!isa<TensorType>(beforeArg.getType())) {
+ if (!isa<TensorLikeType>(beforeArg.getType())) {
castedInitArgs.push_back(initArg);
continue;
}
@@ -995,7 +998,7 @@ struct WhileOpInterface
// The result types of a WhileOp are the same as the "after" bbArg types.
SmallVector<Type> argsTypesAfter = llvm::map_to_vector(
whileOp.getAfterArguments(), [&](BlockArgument bbArg) {
- if (!isa<TensorType>(bbArg.getType()))
+ if (!isa<TensorLikeType>(bbArg.getType()))
return bbArg.getType();
// TODO: error handling
return llvm::cast<Type>(
@@ -1048,7 +1051,7 @@ struct WhileOpInterface
SmallVector<Value> &invocationStack) const {
auto whileOp = cast<scf::WhileOp>(op);
assert(getOwnerOfValue(value) == op && "invalid value");
- assert(isa<TensorType>(value.getType()) && "expected tensor type");
+ assert(isa<TensorLikeType>(value.getType()) && "expected tensor type");
// Case 1: Block argument of the "before" region.
if (auto bbArg = dyn_cast<BlockArgument>(value)) {
@@ -1074,7 +1077,7 @@ struct WhileOpInterface
llvm_unreachable("invalid value");
}
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
- if (!isa<TensorType>(conditionYieldedVal.getType())) {
+ if (!isa<TensorLikeType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
return cast<BufferLikeType>(conditionYieldedVal.getType());
}
@@ -1103,7 +1106,7 @@ struct WhileOpInterface
auto conditionOp = whileOp.getConditionOp();
for (const auto &it : llvm::enumerate(conditionOp.getArgs())) {
Block *block = conditionOp->getBlock();
- if (!isa<TensorType>(it.value().getType()))
+ if (!isa<TensorLikeType>(it.value().getType()))
continue;
if (it.index() >= block->getNumArguments() ||
!state.areEquivalentBufferizedValues(it.value(),
@@ -1116,7 +1119,7 @@ struct WhileOpInterface
auto yieldOp = whileOp.getYieldOp();
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Block *block = yieldOp->getBlock();
- if (!isa<TensorType>(it.value().getType()))
+ if (!isa<TensorLikeType>(it.value().getType()))
continue;
if (it.index() >= block->getNumArguments() ||
!state.areEquivalentBufferizedValues(it.value(),
@@ -1176,7 +1179,7 @@ struct YieldOpInterface
SmallVector<Value> newResults;
for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
Value value = it.value();
- if (isa<TensorType>(value.getType())) {
+ if (isa<TensorLikeType>(value.getType())) {
FailureOr<Value> maybeBuffer =
getBuffer(rewriter, value, options, state);
if (failed(maybeBuffer))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
index d5cb7a0f14f5a..f8d5a1310ebdf 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -905,3 +905,179 @@ func.func @ranked_return_via_unranked_call(%arg0: tensor<64x20x40xf32>) -> tenso
return %b : tensor<64x20x40xf32>
}
func.func private @relu_unranked(tensor<*xf32>) -> tensor<*xf32>
+
+// -----
+
+// CHECK-LABEL: func.func @custom_types_scf_for_inplace(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>,
+// CHECK-SAME: %[[lb:.+]]: index, %[[ub:.+]]: index, %[[step:.+]]: index
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_for_inplace(
+ %arg: !test.test_tensor<[4, 4], f64>,
+ %lb: index, %ub: index, %step: index)
+ -> !test.test_tensor<[4, 4], f64> {
+ // CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]]
+ // CHECK-SAME: iter_args(%[[iter:.+]] = %[[arg]]) -> (!test.test_memref<[4, 4], f64>) {
+ // CHECK: %[[call:.+]] = "test.dummy_memref_op"(%[[iter]])
+ // CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64>
+ %loop = scf.for %i = %lb to %ub step %step
+ iter_args(%iter = %arg) -> (!test.test_tensor<[4, 4], f64>) {
+ // Inside the loop: operate on the iter_arg directly (in-place op).
+ %call = "test.dummy_tensor_op"(%iter) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ // Yield the in-place op's result so it stays equivalent to the iter_arg.
+ scf.yield %call : !test.test_tensor<[4, 4], f64>
+ }
+
+ // CHECK: return %[[loop]] : !test.test_memref<[4, 4], f64>
+ return %loop : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64> {
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// Same as @custom_types_scf_for_inplace, but with an inner call to test alias analysis
+// through function boundaries.
+// CHECK-LABEL: func.func @custom_types_scf_for_inplace_with_call(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: %[[lb:.+]]: index, %[[ub:.+]]: index, %[[step:.+]]: index
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+// CHECK: %[[loop:.+]] = scf.for %{{.*}} = %[[lb]] to %[[ub]] step %[[step]] iter_args(%[[iter:.+]] = %[[arg]]) -> (!test.test_memref<[4, 4], f64>) {
+// CHECK: %[[call:.+]] = func.call @custom_types_identity_2d(%[[iter]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64>
+// CHECK: return %[[loop]] : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_for_inplace_with_call(
+ %arg: !test.test_tensor<[4, 4], f64>,
+ %lb: index, %ub: index, %step: index)
+ -> !test.test_tensor<[4, 4], f64> {
+ %loop = scf.for %i = %lb to %ub step %step
+ iter_args(%iter = %arg) -> (!test.test_tensor<[4, 4], f64>) {
+ %call = func.call @custom_types_identity_2d(%iter)
+ : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64>
+ scf.yield %call : !test.test_tensor<[4, 4], f64>
+ }
+
+ return %loop : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @custom_types_scf_if_inplace(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: %[[cond:.+]]: i1
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+// CHECK: %[[res:.+]] = scf.if %[[cond]] -> (!test.test_memref<[4, 4], f64>) {
+// CHECK: %[[dummy:.+]] = "test.dummy_memref_op"(%[[arg]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield %[[dummy]] : !test.test_memref<[4, 4], f64>
+// CHECK: } else {
+// CHECK: scf.yield %[[arg]] : !test.test_memref<[4, 4], f64>
+// CHECK: }
+// CHECK: return %[[res]] : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_if_inplace(
+ %arg: !test.test_tensor<[4, 4], f64>,
+ %cond: i1)
+ -> !test.test_tensor<[4, 4], f64> {
+ %res = scf.if %cond -> (!test.test_tensor<[4, 4], f64>) {
+ %dummy = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ scf.yield %dummy : !test.test_tensor<[4, 4], f64>
+ } else {
+ scf.yield %arg : !test.test_tensor<[4, 4], f64>
+ }
+ return %res : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64> {
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK-LABEL: func.func @custom_types_scf_if_inplace_with_call(
+// CHECK-SAME: %[[arg:.+]]: !test.test_memref<[4, 4], f64>
+// CHECK-SAME: %[[cond:.+]]: i1
+// CHECK-SAME: ) -> !test.test_memref<[4, 4], f64>
+// CHECK: %[[res:.+]] = scf.if %[[cond]] -> (!test.test_memref<[4, 4], f64>) {
+// CHECK: %[[call:.+]] = func.call @custom_types_identity_2d(%[[arg]]) : (!test.test_memref<[4, 4], f64>) -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield %[[call]] : !test.test_memref<[4, 4], f64>
+// CHECK: } else {
+// CHECK: scf.yield %[[arg]] : !test.test_memref<[4, 4], f64>
+// CHECK: }
+// CHECK: return %[[res]] : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_if_inplace_with_call(
+ %arg: !test.test_tensor<[4, 4], f64>,
+ %cond: i1)
+ -> !test.test_tensor<[4, 4], f64> {
+ %res = scf.if %cond -> (!test.test_tensor<[4, 4], f64>) {
+ %call = func.call @custom_types_identity_2d(%arg)
+ : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64>
+ scf.yield %call : !test.test_tensor<[4, 4], f64>
+ } else {
+ scf.yield %arg : !test.test_tensor<[4, 4], f64>
+ }
+ return %res : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @custom_types_scf_while_inplace(
+// CHECK-SAME: !test.test_memref<[4, 4], f64>
+// CHECK: scf.while {{.*}} -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.condition({{.*}}) {{.*}} : !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield {{.*}} : !test.test_memref<[4, 4], f64>
+// CHECK: return {{.*}} : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_while_inplace(
+ %arg: !test.test_tensor<[4, 4], f64>,
+ %cond: i1)
+ -> !test.test_tensor<[4, 4], f64> {
+ %loop = scf.while (%iter = %arg)
+ : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> {
+ scf.condition(%cond) %iter : !test.test_tensor<[4, 4], f64>
+ } do {
+ ^bb0(%current: !test.test_tensor<[4, 4], f64>):
+ %dummy = "test.dummy_tensor_op"(%current) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ scf.yield %dummy : !test.test_tensor<[4, 4], f64>
+ }
+ return %loop : !test.test_tensor<[4, 4], f64>
+}
+
+// -----
+
+func.func private @custom_types_identity_2d(%arg: !test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64> {
+ %out = "test.dummy_tensor_op"(%arg) : (!test.test_tensor<[4, 4], f64>)
+ -> !test.test_tensor<[4, 4], f64>
+ return %out : !test.test_tensor<[4, 4], f64>
+}
+
+// CHECK-LABEL: func.func @custom_types_scf_while_inplace_with_call(
+// CHECK-SAME: !test.test_memref<[4, 4], f64>
+// CHECK: scf.while {{.*}} -> !test.test_memref<[4, 4], f64>
+// CHECK: scf.condition({{.*}}) {{.*}} : !test.test_memref<[4, 4], f64>
+// CHECK: scf.yield {{.*}} : !test.test_memref<[4, 4], f64>
+// CHECK: return {{.*}} : !test.test_memref<[4, 4], f64>
+func.func @custom_types_scf_while_inplace_with_call(
+ %arg: !test.test_tensor<[4, 4], f64>,
+ %cond: i1)
+ -> !test.test_tensor<[4, 4], f64> {
+ %loop = scf.while (%iter = %arg)
+ : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64> {
+ scf.condition(%cond) %iter : !test.test_tensor<[4, 4], f64>
+ } do {
+ ^bb0(%current: !test.test_tensor<[4, 4], f64>):
+ %call = func.call @custom_types_identity_2d(%current)
+ : (!test.test_tensor<[4, 4], f64>) -> !test.test_tensor<[4, 4], f64>
+ scf.yield %call : !test.test_tensor<[4, 4], f64>
+ }
+ return %loop : !test.test_tensor<[4, 4], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 5fee060689d24..340b44b14dd96 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1796,6 +1796,19 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
return mlir::success();
}
+/// Buffer type for values of `test.dummy_tensor_op`: a !test.test_tensor
+/// maps to a !test.test_memref with the same shape and element type and no
+/// memory space. Any other type is rejected with a diagnostic.
+mlir::FailureOr<mlir::bufferization::BufferLikeType>
+test::TestDummyTensorOp::getBufferType(
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    const mlir::bufferization::BufferizationState &,
+    llvm::SmallVector<::mlir::Value> &) {
+  auto type = dyn_cast<test::TestTensorType>(value.getType());
+  if (!type)
+    // Report the offending type instead of failing silently, so a
+    // bufferization failure on this op is debuggable.
+    return emitOpError("expected a !test.test_tensor value, got ")
+           << value.getType();
+
+  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
+      getContext(), type.getShape(), type.getElementType(), nullptr));
+}
+
::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
::mlir::RewriterBase &rewriter,
const ::mlir::bufferization::BufferizationOptions &options,
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 774329b9d2736..348ff5d7f4ea0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3936,10 +3936,13 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
// Test Ops bufferization
//===----------------------------------------------------------------------===//
-def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
- [DeclareOpInterfaceMethods<BufferizableOpInterface,
- ["bufferize", "bufferizesToMemoryRead",
- "bufferizesToMemoryWrite", "getAliasingValues"]>]> {
+def TestDummyTensorOp
+ : TEST_Op<"dummy_tensor_op",
+ [DeclareOpInterfaceMethods<
+ BufferizableOpInterface,
+ ["bufferize", "getBufferType", "bufferizesToMemoryRead",
+ "bufferizesToMemoryWrite", "getAliasingValues",
+ "getAliasingOpOperands"]>]> {
let arguments = (ins
Arg<Bufferization_TensorLikeTypeInterface>:$input
);
@@ -3959,7 +3962,23 @@ def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
::mlir::bufferization::AliasingValueList
test::TestDummyTensorOp::getAliasingValues(::mlir::OpOperand&,
const ::mlir::bufferization::AnalysisState&) {
- return {};
+    // The input operand always aliases the output result. The two buffers are
+    // equivalent when the op keeps the tensor type; otherwise unknown.
+    const bool keepsType = getInput().getType() == getOutput().getType();
+    return {{getOutput(),
+             keepsType ? ::mlir::bufferization::BufferRelation::Equivalent
+                       : ::mlir::bufferization::BufferRelation::Unknown,
+             /*isDefinite=*/true}};
+  }
+
+  ::mlir::bufferization::AliasingOpOperandList
+  test::TestDummyTensorOp::getAliasingOpOperands(::mlir::Value value,
+      const ::mlir::bufferization::AnalysisState&) {
+    // Inverse of getAliasingValues: only the output aliases the input operand.
+    if (value == getOutput()) {
+      const bool keepsType = getInput().getType() == getOutput().getType();
+      return {{&getInputMutable(),
+               keepsType ? ::mlir::bufferization::BufferRelation::Equivalent
+                         : ::mlir::bufferization::BufferRelation::Unknown,
+               /*isDefinite=*/true}};
+    }
+    return {};
 }
}];
}
@@ -3973,11 +3992,13 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
);
}
-def TestCreateTensorOp : TEST_Op<"create_tensor_op",
- [DeclareOpInterfaceMethods<BufferizableOpInterface,
- ["bufferize", "getBufferType", "bufferizesToMemoryRead",
- "bufferizesToMemoryWrite", "getAliasingValues",
- "bufferizesToAllocation"]>]> {
+def TestCreateTensorOp
+ : TEST_Op<"create_tensor_op",
+ [DeclareOpInterfaceMethods<
+ BufferizableOpInterface,
+ ["bufferize", "getBufferType", "bufferizesToMemoryRead",
+ "bufferizesToMemoryWrite", "getAliasingValues",
+ "getAliasingOpOperands", "bufferizesToAllocation"]>]> {
let arguments = (ins);
let results = (outs Arg<Bufferization_TensorLikeTypeInterface>:$output);
let extraClassDefinition = [{
@@ -3998,6 +4019,14 @@ def TestCreateTensorOp : TEST_Op<"create_tensor_op",
const ::mlir::bufferization::AnalysisState&) {
return {};
}
+
+  ::mlir::bufferization::AliasingOpOperandList
+  test::TestCreateTensorOp::getAliasingOpOperands(
+      ::mlir::Value /*value*/,
+      const ::mlir::bufferization::AnalysisState&) {
+    // create_tensor_op takes no operands, so nothing can alias its result.
+    return {};
+  }
}];
}
More information about the Mlir-commits
mailing list