[Mlir-commits] [mlir] [mlir][bufferization] Use TensorLike, BufferLike type interfaces (PR #136736)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 22 11:03:15 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-arith

Author: Andrei Golubev (andrey-golubev)

<details>
<summary>Changes</summary>

The general idea is to replace most uses of the builtin TensorType / BaseMemRefType with the newly added type interfaces (TensorLikeType / BufferLikeType).

For now, this patch does the bare minimum: it refactors the API of the dialect and of the options (almost) "blindly", leaving most of the logic "as is". The exceptions are the bufferization.{to_tensor, to_memref} ops, which act as "glue" when bufferizing neighbouring operations and the enclosing functions.

---

Patch is 75.29 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136736.diff


26 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+11-10) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+10-7) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td (+10-3) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+21-14) 
- (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+7-6) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+54-40) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp (+5-1) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+16-14) 
- (added) mlir/lib/Dialect/Bufferization/IR/BufferizationTypeInterfaces.cpp (+21) 
- (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp (+9-8) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+10-8) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+15-15) 
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+60-49) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+5-4) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1) 
- (modified) mlir/test/Dialect/Bufferization/invalid.mlir (+4-4) 
- (modified) mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp (+3-1) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+24) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+54-1) 
- (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index ada9539e87121..70092908d961f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -259,18 +260,18 @@ struct BufferizationOptions {
       std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
   /// Initializer function for analysis state.
   using AnalysisStateInitFn = std::function<void(AnalysisState &)>;
-  /// Tensor -> MemRef type converter.
-  /// Parameters: tensor type, memory space, func op, bufferization options
+  /// TensorLike -> BufferLike type converter.
+  /// Parameters: tensor like type, memory space, func op, bufferization options
   using FunctionArgTypeConverterFn =
-      std::function<BaseMemRefType(TensorType, Attribute memorySpace,
+      std::function<BufferLikeType(TensorLikeType, Attribute memorySpace,
                                    func::FuncOp, const BufferizationOptions &)>;
-  /// Tensor -> MemRef type converter.
+  /// TensorLike -> BufferLike type converter.
   /// Parameters: Value, memory space, bufferization options
-  using UnknownTypeConverterFn = std::function<BaseMemRefType(
+  using UnknownTypeConverterFn = std::function<BufferLikeType(
       Value, Attribute memorySpace, const BufferizationOptions &)>;
   // Produce a MemorySpace attribute from a tensor type
   using DefaultMemorySpaceFn =
-      std::function<std::optional<Attribute>(TensorType t)>;
+      std::function<std::optional<Attribute>(TensorLikeType t)>;
 
   BufferizationOptions();
 
@@ -360,7 +361,7 @@ struct BufferizationOptions {
   // Returning std::nullopt will cause bufferization to fail (useful to indicate
   // failure to determine memory space for a tensor type).
   DefaultMemorySpaceFn defaultMemorySpaceFn =
-      [](TensorType t) -> std::optional<Attribute> { return Attribute(); };
+      [](TensorLikeType t) -> std::optional<Attribute> { return Attribute(); };
 
   /// If set to `true`, the analysis is skipped. A buffer is copied before every
   /// write. This flag cannot be used together with `testAnalysisOnly = true`.
@@ -600,7 +601,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options);
 
 /// Return the buffer type for a given Value (tensor) after bufferization
@@ -613,7 +614,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         SmallVector<Value> &invocationStack);
 
@@ -693,7 +694,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      SmallVector<Value> &invocationStack);
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index 95022d7d665d2..1de1742fab81a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -518,7 +518,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index fad78a63444b9..81ce0f3fb650b 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -13,6 +13,7 @@ include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -109,7 +110,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack);
 
@@ -438,11 +439,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
                            [MemReadAt<0, FullEffect>]>:$memref,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -465,10 +466,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BufferLikeType>(getMemref().getType());
     }
   }];
 
@@ -493,6 +494,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 // ToMemrefOp
 //===----------------------------------------------------------------------===//
 
+// TODO: rename to "to_buffer"
 def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
@@ -519,8 +521,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor,
+                       UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$memref);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index 5faa1479ee542..290f1298f2501 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h" // mlir::Attribute
 #include "mlir/IR/Types.h"
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
index f19224a295648..c053a6bdc1a91 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td
@@ -33,10 +33,17 @@ def Bufferization_BufferLikeTypeInterface
   let description = [{
     Indicates that this type is a buffer type (similarly to a MLIR builtin
     memref) for bufferization purposes.
-
-    The interface currently has no methods as it is used by types to opt into
-    being supported by the bufferization procedures.
   }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the memory space in which data referred to by this buffer resides.
+      }],
+      /*retType=*/"::mlir::Attribute",
+      /*methodName=*/"getMemorySpace"
+    >,
+  ];
 }
 
 #endif // BUFFERIZATION_TYPE_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index 78109770efab7..89eb65c4a0942 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     // Note: The user may want to override this function for OpResults in
@@ -46,7 +46,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     // operand types of all forwarded values. If these are all the same type,
     // take that type. Otherwise, take only the memory space and fall back to a
     // buffer type with a fully dynamic layout map.
-    BaseMemRefType bufferType;
+    BufferLikeType bufferType;
     auto tensorType = cast<TensorType>(value.getType());
     for (OpOperand *opOperand :
          detail::getCallerOpOperands(cast<BlockArgument>(value))) {
@@ -59,13 +59,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         continue;
 
       // Compute the bufferized type of the forwarded operand.
-      BaseMemRefType callerType;
-      if (auto memrefType =
-              dyn_cast<BaseMemRefType>(opOperand->get().getType())) {
+      BufferLikeType callerType;
+      if (auto bufferLikeType =
+              dyn_cast<BufferLikeType>(opOperand->get().getType())) {
         // The operand was already bufferized. Take its type directly.
-        callerType = memrefType;
+        callerType = bufferLikeType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options,
                                          invocationStack);
         if (failed(maybeCallerType))
@@ -86,14 +86,20 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // of the earlier forwarded operands, fall back to a buffer type with a
         // fully dynamic layout map.
 #ifndef NDEBUG
+      assert(mlir::isa<BaseMemRefType>(bufferType) &&
+             mlir::isa<BaseMemRefType>(callerType) && "expected memrefs");
+      auto memrefType = mlir::cast<BaseMemRefType>(bufferType);
+      auto callerMemrefType = mlir::cast<BaseMemRefType>(callerType);
+
       if (auto rankedTensorType = dyn_cast<RankedTensorType>(tensorType)) {
-        assert(bufferType.hasRank() && callerType.hasRank() &&
+        assert(memrefType.hasRank() && callerMemrefType.hasRank() &&
                "expected ranked memrefs");
-        assert(llvm::all_equal({bufferType.getShape(), callerType.getShape(),
-                                rankedTensorType.getShape()}) &&
-               "expected same shape");
+        assert(
+            llvm::all_equal({memrefType.getShape(), callerMemrefType.getShape(),
+                             rankedTensorType.getShape()}) &&
+            "expected same shape");
       } else {
-        assert(!bufferType.hasRank() && !callerType.hasRank() &&
+        assert(!memrefType.hasRank() && !callerMemrefType.hasRank() &&
                "expected unranked memrefs");
       }
 #endif // NDEBUG
@@ -102,8 +108,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         return op->emitOpError("incoming operands of block argument have "
                                "inconsistent memory spaces");
 
-      bufferType = getMemRefTypeWithFullyDynamicLayout(
-          tensorType, bufferType.getMemorySpace());
+      bufferType =
+          mlir::cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+              tensorType, bufferType.getMemorySpace()));
     }
 
     if (!bufferType)
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 5e69a98db8f1e..433757192bfd1 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,7 +26,7 @@ struct ConstantOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    auto type = dyn_cast<RankedTensorType>(constantOp.getType());
+    auto type = dyn_cast<TensorLikeType>(constantOp.getType());
 
     // Only ranked tensors are supported.
     if (!type)
@@ -176,7 +176,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<bufferization::BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
@@ -195,10 +195,11 @@ struct SelectOpInterface
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
-        RankedTensorType::get(memrefType.getShape(),
-                              memrefType.getElementType()),
-        memrefType.getMemorySpace());
+    return mlir::cast<bufferization::BufferLikeType>(
+        getMemRefTypeWithFullyDynamicLayout(
+            RankedTensorType::get(memrefType.getShape(),
+                                  memrefType.getElementType()),
+            memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 99ffa62c41a4d..82ff1bdfe5fd7 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -206,12 +206,13 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType = getBufferType(tensor, options);
+  FailureOr<BufferLikeType> copyBufferType = getBufferType(tensor, options);
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
   if (!memorySpace)
-    memorySpace = options.defaultMemorySpaceFn(tensorType);
+    memorySpace =
+        options.defaultMemorySpaceFn(mlir::cast<TensorLikeType>(tensorType));
   if (memorySpace.has_value())
     allocTensorOp.setMemorySpaceAttr(memorySpace.value());
   return allocTensorOp.getResult();
@@ -229,6 +230,7 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
   // Find all out-of-place OpOperands.
   for (OpOperand &opOperand : op->getOpOperands()) {
     Type operandType = opOperand.get().getType();
+    // Note: can only copy TensorType (any other TensorLikeType is rejected)
     if (!llvm::isa<TensorType>(operandType))
       continue;
     if (state.isInPlace(opOperand))
@@ -328,18 +330,21 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 namespace {
 
 /// Default function arg type converter: Use a fully dynamic layout map.
-BaseMemRefType
-defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
-                                func::FuncOp funcOp,
+bufferization::BufferLikeType
+defaultFunctionArgTypeConverter(bufferization::TensorLikeType type,
+                                Attribute memorySpace, func::FuncOp funcOp,
                                 const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(type, memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(mlir::cast<TensorType>(type),
+                                          memorySpace));
 }
 /// Default unknown type converter: Use a fully dynamic layout map.
-BaseMemRefType
+BufferLikeType
 defaultUnknownTypeConverter(Value value, Attribute memorySpace,
                             const BufferizationOptions &options) {
-  return getMemRefTypeWithFullyDynamicLayout(
-      llvm::cast<TensorType>(value.getType()), memorySpace);
+  return mlir::cast<bufferization::BufferLikeType>(
+      getMemRefTypeWithFullyDynamicLayout(
+          llvm::cast<TensorType>(value.getType()), memorySpace));
 }
 
 } // namespace
@@ -376,14 +381,16 @@ BufferizationOptions::dynCastBufferizableOp(Value value) const {
 
 void BufferizationOptions::setFunctionBoundaryTypeConversion(
     LayoutMapOption layoutMapOption) {
-  functionArgTypeConverterFn = [=](TensorType tensorType, Attribute memorySpace,
-                                   func::FuncOp funcOp,
+  functionArgTypeConverterFn = [=](TensorLikeType tensorType,
+                                   Attribute memorySpace, func::FuncOp funcOp,
                                    const BufferizationOptions &options) {
     if (layoutMapOption == LayoutMapOption::IdentityLayoutMap)
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(tensorType,
-                                                                  memorySpace);
-    return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
-                                                              memorySpace);
+      return mlir::cast<bufferization::BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              mlir::cast<TensorType>(tensorType), memorySpace));
+    return mlir::cast<bufferization::BufferLikeType>(
+        bufferization::getMemRefTypeWithFullyDynamicLayout(
+            mlir::cast<TensorType>(tensorType), memorySpace));
   };
   inferFunctionResultLayout =
       layoutMapOption == LayoutMapOption::InferLayoutMap;
@@ -473,7 +480,8 @@ bool AnalysisState::bufferizesToMemoryWrite(Value value) const {
 /// read. Also takes into account ops that create an alias but do not read by
 /// themselves (e.g., ExtractSliceOp).
 bool AnalysisState::isValueRead(Value value) const {
-  assert(llvm::isa<TensorType>(value.getType()) && "expected TensorType");
+  assert(llvm::isa<bufferization::TensorLikeType>(value.getType()) &&
+         "expected TensorLikeType");
   SmallVector<OpOperand *> workingSet;
   DenseSet<OpOperand *> visited;
   for (OpOperand &use : value.getUses())
@@ -66...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/136736


More information about the Mlir-commits mailing list