[Mlir-commits] [mlir] [mlir][bufferization] Use TensorLike, BufferLike type interfaces (PR #136736)
llvmlistbot at llvm.org
Tue Apr 22 11:03:14 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir
Author: Andrei Golubev (andrey-golubev)
<details>
<summary>Changes</summary>
The general idea is to replace most places that rely on the builtin TensorType / BaseMemRefType with the newly added TensorLike / BufferLike type interfaces.
So far, this patch does the bare minimum: it refactors the API of the dialect and of the bufferization options (almost) "blindly", leaving most of the logic as is. The exceptions are the bufferization.{to_tensor, to_memref} ops, which act as "glue" when bufferizing neighbouring operations and the enclosing functions.
---
Patch is 75.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136736.diff
26 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+11-10)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+10-7)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td (+10-3)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+21-14)
- (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+7-6)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+54-40)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+5-1)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+16-14)
- (added) mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp (+21)
- (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+9-8)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+10-8)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+15-15)
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+60-49)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+2-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+5-4)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1)
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+4-4)
- (modified) mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (+3-1)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+24)
- (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+54-1)
- (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
#include <optional>
#include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
namespace mlir {
class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
/// Initializer function for analysis state.
using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
- /// Tensor -> MemRef type converter.
- /// Parameters: tensor type, memory space, func op, bufferization options
+ /// TensorLike -> BufferLike type converter.
+ /// Parameters: tensor-like type, memory space, func op, bufferization options
using FunctionArgTypeConverterFn =
- std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+ std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
- /// Tensor -> MemRef type converter.
+ /// TensorLike -> BufferLike type converter.
/// Parameters: Value, memory space, bufferization options
- using UnknownTypeConverterFn = std::function<BaseMemRefType(
+ using UnknownTypeConverterFn = std::function<BufferLikeType(
Value, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
- std::function<std::optional<Attribute>(TensorType t)>;
+ std::function<std::optional<Attribute>(TensorLikeType t)>;
BufferizationOptions();
@@ -360,7 +361,7 @@ struct BufferizationOptions {
// Returning std::nullopt will cause bufferization to fail (useful to indicate
// failure to determine memory space for a tensor type).
DefaultMemorySpaceFn defaultMemorySpaceFn =
- [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+ [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
/// If set to `true`, the analysis is skipped. A buffer is copied before every
/// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options);
/// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
/// IR, this function can be used.
///
/// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack);
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
away. However, such IR is no longer bufferizable with One-Shot Bufferize.
}];
- let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+ let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
"the reference to load from",
[MemReadAt<0, FullEffect>]>:$memref,
UnitAttr:$restrict, UnitAttr:$writable);
- let results = (outs AnyTensor:$result);
+ let results = (outs Bufferization_TensorLikeTypeInterface:$result);
let extraClassDeclaration = [{
/// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+ return ::llvm::cast<BufferLikeType>(getMemref().getType());
}
}];
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
// ToMemrefOp
//===----------------------------------------------------------------------===//
+// TODO: rename to "to_buffer"
def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
BufferizableOpInterface,
SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
the returned buffer) will not be written to.
}];
- let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
- let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+ let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+ UnitAttr:$read_only);
+ let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h" // mlir::Attribute
#include "mlir/IR/Types.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
let description = [{
Indicates that this type is a buffer type (similarly to a MLIR builtin
memref) for bufferization purposes.
-
- The interface currently has no methods as it is used by types to opt into
- being supported by the bufferization procedures.
}];
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the memory space in which data referred to by this buffer resides.
+ }],
+ /*retType=*/"::mlir::Attribute",
+ /*methodName=*/"getMemorySpace"
+ >,
+ ];
}
#endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
// Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// operand types of all forwarded values. If these are all the same type,
// take that type. Otherwise, take only the memory space and fall back to a
// buffer type with a fully dynamic layout map.
- BaseMemRefType bufferType;
+ BufferLikeType bufferType;
auto tensorType = cast<TensorType>(value.getType());
for (OpOperand *opOperand :
detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
continue;
// Compute the bufferized type of the forwarded operand.
- BaseMemRefType callerType;
- if (auto memrefType =
- dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+ BufferLikeType callerType;
+ if (auto bufferLikeType =
+ dyn_cast<BufferLikeType>(opOperand->get().getType())) {
// The operand was already bufferized. Take its type directly.
- callerType = memrefType;
+ callerType = bufferLikeType;
} else {
- FailureOr<BaseMemRefType> maybeCallerType =
+ FailureOr<BufferLikeType> maybeCallerType =
bufferization::getBufferType(opOperand->get(), options,
invocationStack);
if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
// of the earlier forwarded operands, fall back to a buffer type with a
// fully dynamic layout map.
#ifndef NDEBUG
+ assert(mlir::isa<BaseMemRefType>(bufferType) &&
+ mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+ auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+ auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
- assert(bufferType.hasRank() && callerType.hasRank() &&
+ assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
"expected ranked memrefs");
- assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
- rankedTensorType.getShape()}) &&
- "expected same shape");
+ assert(
+ llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+ rankedTensorType.getShape()}) &&
+ "expected same shape");
} else {
- assert(!bufferType.hasRank() && !callerType.hasRank() &&
+ assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
"expected unranked memrefs");
}
#endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
return op->emitOpError("incoming operands of block argument have "
"inconsistent memory spaces");
- bufferType = getMemRefTypeWithFullyDynamicLayout(
- tensorType, bufferType.getMemorySpace());
+ bufferType =
+ mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ tensorType, bufferType.getMemorySpace()));
}
if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto constantOp = cast<arith::ConstantOp>(op);
- auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+ auto type = dyn_cast<TensorLikeType>(constantOp.getType());
// Only ranked tensors are supported.
if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<bufferization::BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
SmallVector<Value> &invocationStack) const {
auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
- RankedTensorType::get(memrefType.getShape(),
- memrefType.getElementType()),
- memrefType.getMemorySpace());
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ RankedTensorType::get(memrefType.getShape(),
+ memrefType.getElementType()),
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
// Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
if (copy)
return allocTensorOp.getResult();
- FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+ FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
if (failed(copyBufferType))
return failure();
std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
if (!memorySpace)
- memorySpace = options.defaultMemorySpaceFn(tensorType);
+ memorySpace =
+ options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
if (memorySpace.has_value())
allocTensorOp.setMemorySpaceAttr(memorySpace.value());
return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
// Find all out-of-place OpOperands.
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
+ // Note: can only copy TensorType (any other TensorLikeType is rejected)
if (!llvm::isa<TensorType>(operandType))
continue;
if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
namespace {
/// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
- func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+ memorySpace));
}
/// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(
+ llvm::cast<TensorType>(value.getType()), memorySpace));
}
} // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
void BufferizationOptions::setFunctionBoundaryTypeConversion(
LayoutMapOption layoutMapOption) {
- functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
- func::FuncOp funcOp,
+ functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+ Attribute memorySpace, func::FuncOp funcOp,
const BufferizationOptions &options) {
if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
- return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
- memorySpace);
- return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
- memorySpace);
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
+ return mlir::cast<bufferization::BufferLikeType>(
+ bufferization::getMemRefTypeWithFullyDynamicLayout(
+ mlir::cast<TensorType>(tensorType), memorySpace));
};
inferFunctionResultLayout =
layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
/// read. Also takes into account ops that create an alias but do not read by
/// themselves (e.g., ExtractSliceOp).
bool AnalysisState::isValueRead(Value value) const {
- assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+ assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+ "expected TensorLikeType");
SmallVector<OpOperand *> workingSet;
DenseSet<OpOperand *> visited;
for (OpOperand &use : value.getUses())
@@ -66...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/136736
More information about the Mlir-commits mailing list