[Mlir-commits] [mlir] [mlir][bufferization] Support custom types (1/N) (PR #142986)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 5 08:11:20 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-bufferization

Author: Andrei Golubev (andrey-golubev)

<details>
<summary>Changes</summary>

Following the introduction of TensorLike and BufferLike type interfaces (see 00eaff3e9c897c263a879416d0f151d7ca7eeaff), introduce minimal changes required to bufferize a custom tensor operation into a custom buffer operation.

To achieve this, a new conversion dialect interface is added that abstracts away the differences between existing (tensor -> memref) and custom conversions.

The scope of the changes is intentionally limited (for example, BufferizableOpInterface is untouched) in order to first understand the basics and reach consensus design-wise.

---

Patch is 49.87 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142986.diff


19 Files Affected:

- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+15-2) 
- (added) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h (+72) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+28-20) 
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+3-2) 
- (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+51-25) 
- (added) mlir/lib/Dialect/Bufferization/IR/BufferizationConversionInterface.cpp (+67) 
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+11-10) 
- (modified) mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+4-4) 
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+4-4) 
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-23) 
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+8-6) 
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+20-1) 
- (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+49) 
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+23) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.h (+1) 
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+56-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..8390da956444d 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -17,6 +17,7 @@
 #include <optional>
 
 #include "mlir/Dialect/Bufferization/IR/BufferizationEnums.h.inc"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
 
 namespace mlir {
 class OpBuilder;
@@ -615,7 +616,7 @@ FailureOr<Value> getBuffer(RewriterBase &rewriter, Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around BufferizableOpInterface::getBufferType.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state);
 
@@ -629,7 +630,7 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
 /// IR, this function can be used.
 ///
 /// This function is a wrapper around `BufferizableOpInterface::getBufferType`.
-FailureOr<BaseMemRefType> getBufferType(Value value,
+FailureOr<BufferLikeType> getBufferType(Value value,
                                         const BufferizationOptions &options,
                                         const BufferizationState &state,
                                         SmallVector<Value> &invocationStack);
@@ -738,6 +739,18 @@ AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
 /// This is the default implementation of
 /// BufferizableOpInterface::hasTensorSemantics
 bool defaultHasTensorSemantics(Operation *op);
+
+/// This is a helper function used when the buffer type is guaranteed to be a
+/// memref.
+FailureOr<BaseMemRefType> castToMemRef(FailureOr<BufferLikeType> bufferType);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to verify the types in tensor and buffer
+/// worlds match.
+bool typesMatchAfterBufferization(Operation &op, Value tensor, Value buffer);
+
+/// This function is a free-standing helper that relies on
+/// bufferization::ConversionInterface to perform the conversion.
+Type getTensorFromBuffer(Type buffer);
 } // namespace detail
 
 } // namespace bufferization
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
new file mode 100644
index 0000000000000..4164d1dcb9ea6
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h
@@ -0,0 +1,72 @@
+//===- BufferizationConversionInterface.h - Dialect Interface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+#define MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+namespace bufferization {
+
+/// This class defines a virtual interface for conversions between tensor-like
+/// and buffer-like types.
+struct ConversionDialectInterface
+    : DialectInterface::Base<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Hook to customize tensor-like -> buffer-like conversion within a given
+  /// dialect. Returns a buffer-like type for the specific tensor-like type.
+  virtual FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize type checking between tensor-like and buffer-like types.
+  /// Given tensor `T` and buffer `B = getBufferType(T, ...)`, the call to
+  /// `typesMatch(T, B)` must succeed.
+  virtual LogicalResult typesMatch(
+      TensorLikeType tensor, BufferLikeType buffer,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const = 0;
+
+  /// Hook to customize buffer-like -> tensor-like conversion, which is the
+  /// opposite of bufferization.
+  virtual TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const = 0;
+};
+
+/// Interface collection for conversion between tensor-like and buffer-like
+/// types, dispatches to a concrete interface implementation based on the
+/// dialect to which the given type belongs.
+struct ConversionInterface
+    : DialectInterfaceCollection<ConversionDialectInterface> {
+  using Base::Base;
+
+  /// Dispatches to ConversionDialectInterface::getBufferType() of the dialect
+  /// associated with the value type.
+  FailureOr<BufferLikeType> getBufferType(
+      Value value, const BufferizationOptions &options,
+      const BufferizationState &state,
+      function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::typesMatch() of the dialect
+  /// associated with the value type.
+  LogicalResult
+  typesMatch(TensorLikeType tensor, BufferLikeType buffer,
+             function_ref<InFlightDiagnostic(const Twine &)> emitError) const;
+
+  /// Dispatches to ConversionDialectInterface::getTensorFromBuffer() of the
+  /// dialect associated with the buffer type.
+  TensorLikeType getTensorFromBuffer(BufferLikeType buffer) const;
+};
+
+} // namespace bufferization
+} // namespace mlir
+
+#endif // MLIR_DIALECT_BUFFERIZATION_IR_BUFFERIZATIONCONVERSIONINTERFACE_H_
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 3d4dcdee2663b..277d56bc3f647 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferViewFlowOpInterface.td"
 include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
+include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
 include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
@@ -387,20 +388,28 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
 // ToTensorOp
 //===----------------------------------------------------------------------===//
 
+class Bufferization_TensorAndBufferMatch<string tensor, string buffer> : PredOpTrait<
+  "specified tensor and buffer types match",
+  CPred<
+    "::mlir::bufferization::detail::typesMatchAfterBufferization("
+        "$_op, $" # tensor # ", $" # buffer #")"
+  >
+>;
+
 def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     BufferizableOpInterface,
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
-    AllElementTypesMatch<["memref", "result"]>
+    Bufferization_TensorAndBufferMatch<"result", "buffer">
   ]> {
-  let summary = "create a tensor from a `memref`";
+  let summary = "create a tensor-like type from a buffer-like type";
   let description = [{
-    An operation that creates a tensor from a `memref`. The result value is a
-    tensor whose shape and element type match the memref operand.
+    An operation that creates a tensor from a buffer. The result value has a
+    tensor-like type whose shape and element type match the buffer-like operand.
 
     The opposite of this op is `to_buffer`. Together, these two ops are
     useful for source/target materializations when doing type conversions
-    involving tensors and memrefs.
+    involving tensors and buffers.
 
     Example:
 
@@ -442,11 +451,11 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     away. However, such IR is no longer bufferizable with One-Shot Bufferize.
   }];
 
-  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef,
+  let arguments = (ins Arg<Bufferization_BufferLikeTypeInterface,
                            "the reference to load from",
-                           [MemReadAt<0, FullEffect>]>:$memref,
+                           [MemReadAt<0, FullEffect>]>:$buffer,
                        UnitAttr:$restrict, UnitAttr:$writable);
-  let results = (outs AnyTensor:$result);
+  let results = (outs Bufferization_TensorLikeTypeInterface:$result);
 
   let extraClassDeclaration = [{
     /// The result of a to_tensor is always a tensor.
@@ -473,19 +482,19 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
     FailureOr<BaseMemRefType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getMemref().getType());
+      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
     }
   }];
 
   let assemblyFormat = [{
-    $memref (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
-      `:` type($memref) `to` type($result)
+    $buffer (`restrict` $restrict^)? (`writable` $writable^)? attr-dict
+      `:` type($buffer) `to` type($result)
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$memref, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
-      auto rtt = memref::getTensorTypeFromMemRefType(memref.getType());
-      build($_builder, $_state, rtt, memref, restrict, writeable);
+    OpBuilder<(ins "Value":$buffer, CArg<"bool", "false">:$restrict, CArg<"bool", "false">:$writeable), [{
+      auto rtt = bufferization::detail::getTensorFromBuffer(buffer.getType());
+      build($_builder, $_state, rtt, buffer, restrict, writeable);
     }]>
   ];
 
@@ -503,10 +512,9 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     SameOperandsAndResultShape,
     SameOperandsAndResultElementType,
     Pure,
-    AllShapesMatch<["memref", "tensor"]>,
-    AllElementTypesMatch<["memref", "tensor"]>
+    Bufferization_TensorAndBufferMatch<"tensor", "buffer">
   ]> {
-  let summary = "cast a tensor to memref";
+  let summary = "cast a tensor-like type to a buffer-like type";
   let description = [{
     An operation that returns the future buffer of a `tensor`.
 
@@ -524,8 +532,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
     the returned buffer) will not be written to.
   }];
 
-  let arguments = (ins AnyTensor:$tensor, UnitAttr:$read_only);
-  let results = (outs AnyRankedOrUnrankedMemRef:$memref);
+  let arguments = (ins Bufferization_TensorLikeTypeInterface:$tensor, UnitAttr:$read_only);
+  let results = (outs Bufferization_BufferLikeTypeInterface:$buffer);
 
   let extraClassDeclaration = [{
     //===------------------------------------------------------------------===//
@@ -560,7 +568,7 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
   }];
 
   let assemblyFormat = [{
-    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($memref)
+    $tensor (`read_only` $read_only^)? attr-dict `:` type($tensor) `to` type($buffer)
   }];
 
   let hasFolder = 1;
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index a441b8b66659e..f56c10555f02c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -65,12 +65,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
         // The operand was already bufferized. Take its type directly.
         callerType = memrefType;
       } else {
-        FailureOr<BaseMemRefType> maybeCallerType =
+        FailureOr<BufferLikeType> maybeCallerType =
             bufferization::getBufferType(opOperand->get(), options, state,
                                          invocationStack);
         if (failed(maybeCallerType))
           return failure();
-        callerType = *maybeCallerType;
+        assert(isa<BaseMemRefType>(*maybeCallerType) && "expected memref type");
+        callerType = cast<BaseMemRefType>(*maybeCallerType);
       }
 
       if (!bufferType) {
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index a57d58ab28d28..021a557f68b4b 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -164,8 +164,8 @@ struct SelectOpInterface
     // buffers have different types, they differ only in their layout map. Cast
     // both of them to the most dynamic MemRef type.
     if (trueBuffer.getType() != falseBuffer.getType()) {
-      auto targetType =
-          bufferization::getBufferType(selectOp.getResult(), options, state);
+      auto targetType = bufferization::detail::castToMemRef(
+          bufferization::getBufferType(selectOp.getResult(), options, state));
       if (failed(targetType))
         return failure();
       if (trueBuffer.getType() != *targetType)
@@ -187,10 +187,12 @@ struct SelectOpInterface
                 SmallVector<Value> &invocationStack) const {
     auto selectOp = cast<arith::SelectOp>(op);
     assert(value == selectOp.getResult() && "invalid value");
-    auto trueType = bufferization::getBufferType(
-        selectOp.getTrueValue(), options, state, invocationStack);
-    auto falseType = bufferization::getBufferType(
-        selectOp.getFalseValue(), options, state, invocationStack);
+    auto trueType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getTrueValue(), options, state, invocationStack));
+    auto falseType =
+        bufferization::detail::castToMemRef(bufferization::getBufferType(
+            selectOp.getFalseValue(), options, state, invocationStack));
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..d00605a7b9a17 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/BufferizationConversionInterface.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -211,8 +212,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue(
   // Add 'memory_space' attribute. Not needed if 'copy' operand is specified.
   if (copy)
     return allocTensorOp.getResult();
-  FailureOr<BaseMemRefType> copyBufferType =
-      getBufferType(tensor, options, state);
+  auto copyBufferType =
+      detail::castToMemRef(getBufferType(tensor, options, state));
   if (failed(copyBufferType))
     return failure();
   std::optional<Attribute> memorySpace = copyBufferType->getMemorySpace();
@@ -673,28 +674,28 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
                                           const BufferizationOptions &options,
                                           const BufferizationState &state) {
 #ifndef NDEBUG
-  auto tensorType = llvm::dyn_cast<TensorType>(value.getType());
+  auto tensorType = llvm::dyn_cast<TensorLikeType>(value.getType());
   assert(tensorType && "unexpected non-tensor type");
 #endif // NDEBUG
 
   // Replace "%t = to_tensor %m" with %m.
   if (auto toTensorOp = value.getDefiningOp<bufferization::ToTensorOp>())
-    return toTensorOp.getMemref();
+    return toTensorOp.getBuffer();
 
   // Insert to_buffer op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, value);
-  FailureOr<BaseMemRefType> memrefType = getBufferType(value, options, state);
-  if (failed(memrefType))
+  FailureOr<BufferLikeType> bufferType = getBufferType(value, options, state);
+  if (failed(bufferType))
     return failure();
-  ensureToBufferOpIsValid(value, *memrefType);
+  ensureToBufferOpIsValid(value, *bufferType);
   return rewriter
-      .create<bufferization::ToBufferOp>(value.getLoc(), *memrefType, value)
+      .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value)
       .getResult();
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state) {
   SmallVector<Value> invocationStack;
@@ -702,11 +703,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
 }
 
 /// Return the buffer type for a given Value (tensor) after bufferization.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
-  assert(llvm::isa<TensorType>(value.getType()) &&
+  assert(llvm::isa<TensorLikeType>(value.getType()) &&
          "unexpected non-tensor type");
   invocationStack.push_back(value);
   auto popFromStack =
@@ -718,13 +719,11 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
   if (bufferizableOp)
     return bufferizableOp.getBufferType(value, options, state, invocationStack);
 
-  // Op is not bufferizable.
-  auto memSpace =
-      options.defaultMemorySpaceFn(cast<TensorType>(value.getType()));
-  if (!memSpace.has_value())
-    return op->emitError("could not infer memory space");
-
-  return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+  // Op is not bufferizable, use conversion interface.
+  bufferization::ConversionInterface iface(value.getContext());
+  return iface.getBufferType(value, options, state, [&](const Twine &message) {
+    return op->emitError(message);
+  });
 }
 
 bool bufferization::hasTensorSemantics(Operation *op) {
@@ -744,12 +743,11 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
   SmallVector<Value> replacements;
   for (OpResult opResult : op->getOpResults()) {
     Value replacement = values[opResult.getResultNumber()];
-    if (llvm::isa<TensorType>(opResult.getType())) {
+    if (llvm::isa<TensorLikeType>(opResult.getType())) {
       // The OpResult is a tensor. Such values are replaced with memrefs during
       // bufferization.
-      assert((llvm::isa<MemRefType>(replacement.getType()) ||
-              llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
-             "tensor op result should be replaced with a memref value");
+      assert(llvm::isa<BufferLikeType>(replacement.getType()) &&
+             "tensor op result should be replaced with a buffer value");
       // The existing uses of the OpResult still expect a tensor. Insert a
       // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
       // loose all of its users and eventually DCE away.
@@ -970,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return getBufferType(equivalentOperand, options, bufferizationState,
-                         invocationStack);
+    return castToMemRef(getBufferType(equivalentOperand, options,
+                                      bufferizationState, invocationStack));
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -1031,7 +1029,7 @@ bufferization::detail::unknownGe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/142986


More information about the Mlir-commits mailing list