[Mlir-commits] [mlir] [mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (PR #144867)

Andrei Golubev llvmlistbot at llvm.org
Thu Jun 19 03:32:51 PDT 2025


https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/144867

Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.

>From 57b09078fef7be7a9395d4e144a2dcd9dae49fb8 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Thu, 19 Jun 2025 10:29:41 +0000
Subject: [PATCH] [mlir][bufferization] Return BufferLikeType in
 BufferizableOpInterface

Support custom types (2/N): allow value-owning operations (e.g.
allocation ops) to bufferize into custom types. This requires
BufferizableOpInterface::getBufferType() to return BufferLikeType
instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.
---
 .../IR/BufferizableOpInterface.h              |  2 +-
 .../IR/BufferizableOpInterface.td             |  2 +-
 .../Bufferization/IR/BufferizationOps.td      |  6 +-
 .../IR/BufferizationTypeInterfaces.h          |  1 +
 .../IR/UnstructuredControlFlow.h              |  4 +-
 .../BufferizableOpInterfaceImpl.cpp           |  8 +--
 .../IR/BufferizableOpInterface.cpp            | 15 +++--
 .../Bufferization/IR/BufferizationOps.cpp     |  5 +-
 .../FuncBufferizableOpInterfaceImpl.cpp       | 15 ++---
 .../BufferizableOpInterfaceImpl.cpp           | 61 +++++++++----------
 .../BufferizableOpInterfaceImpl.cpp           | 52 ++++++++--------
 .../Transforms/one-shot-bufferize.mlir        | 23 ++++++-
 mlir/test/lib/Dialect/Test/TestOpDefs.cpp     | 34 +++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         | 53 ++++++++++++++++
 14 files changed, 196 insertions(+), 85 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"
 
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -68,20 +68,22 @@ struct CastOpInterface
     if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
       // When casting to a ranked tensor, we cannot infer any static offset or
       // strides from the source. Assume fully dynamic.
-      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+      return cast<BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
     }
 
     // Case 2: Casting to an unranked tensor type
     if (isa<UnrankedTensorType>(castOp.getType())) {
-      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+      return cast<BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
     }
 
     // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
     // change.
     auto rankedResultType = cast<RankedTensorType>(castOp.getType());
-    return MemRefType::get(
+    return cast<BufferLikeType>(MemRefType::get(
         rankedResultType.getShape(), rankedResultType.getElementType(),
-        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
+        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -141,7 +143,7 @@ struct CollapseShapeOpInterface
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -157,12 +159,13 @@ struct CollapseShapeOpInterface
     if (!canBeCollapsed) {
       // If dims cannot be collapsed, this op bufferizes to a new allocation.
       RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(
-          tensorResultType, srcBufferType.getMemorySpace());
+      return cast<BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              tensorResultType, srcBufferType.getMemorySpace()));
     }
 
-    return memref::CollapseShapeOp::computeCollapsedType(
-        srcBufferType, collapseShapeOp.getReassociationIndices());
+    return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
+        srcBufferType, collapseShapeOp.getReassociationIndices()));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -319,7 +322,7 @@ struct ExpandShapeOpInterface
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -334,7 +337,7 @@ struct ExpandShapeOpInterface
         expandShapeOp.getReassociationIndices());
     if (failed(maybeResultType))
       return failure();
-    return *maybeResultType;
+    return cast<BufferLikeType>(*maybeResultType);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -404,7 +407,7 @@ struct ExtractSliceOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -417,10 +420,10 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    return memref::SubViewOp::inferRankReducedResultType(
+    return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
         extractSliceOp.getType().getShape(),
         llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
-        mixedStrides);
+        mixedStrides));
   }
 };
 
@@ -501,8 +504,8 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    FailureOr<BaseMemRefType> memrefType = bufferization::detail::asMemRefType(
-        bufferization::getBufferType(*tensorAlloc, options, state));
+    FailureOr<BufferLikeType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options, state);
     if (failed(memrefType))
       return failure();
     Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -753,7 +756,7 @@ struct PadOpInterface
     return {};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -765,9 +768,10 @@ struct PadOpInterface
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
-    return MemRefType::get(padOp.getResultType().getShape(),
-                           padOp.getResultType().getElementType(), layout,
-                           maybeSrcBufferType->getMemorySpace());
+    return cast<BufferLikeType>(
+        MemRefType::get(padOp.getResultType().getShape(),
+                        padOp.getResultType().getElementType(), layout,
+                        maybeSrcBufferType->getMemorySpace()));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -927,7 +931,7 @@ struct ReshapeOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -937,9 +941,9 @@ struct ReshapeOpInterface
         reshapeOp.getSource(), options, state, invocationStack);
     if (failed(maybeSourceBufferType))
       return failure();
-    return getMemRefTypeWithStaticIdentityLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
         reshapeOp.getResult().getType(),
-        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
+        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
   }
 };
 
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index da3c26ce36ba5..8031732011839 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -272,10 +272,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
 
 // -----
 
-// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK:       func.func @custom_op(
 // CHECK-SAME:    %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
 // CHECK-SAME:  ) -> !test.test_tensor<[32, 128], f64> {
-func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+func.func @custom_op(%arg: !test.test_tensor<[32, 64], f64>)
     -> !test.test_tensor<[32, 128], f64> {
   // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
   // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
@@ -288,3 +288,22 @@ func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
   // CHECK: return %[[OUT]]
   return %out : !test.test_tensor<[32, 128], f64>
 }
+
+// -----
+
+// CHECK:       func.func @custom_origin_op()
+// CHECK-SAME:  -> !test.test_tensor<[42], f64> {
+func.func @custom_origin_op() -> !test.test_tensor<[42], f64> {
+  // CHECK: %[[MEMREF:.*]] = "test.create_memref_op"() : ()
+  // CHECK-SAME: -> !test.test_memref<[21], f64>
+  // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+  // CHECK-SAME: : (!test.test_memref<[21], f64>)
+  // CHECK-SAME: -> !test.test_memref<[42], f64>
+  %in = "test.create_tensor_op"() : () -> !test.test_tensor<[21], f64>
+  %out = "test.dummy_tensor_op"(%in) : (!test.test_tensor<[21], f64>)
+    -> !test.test_tensor<[42], f64>
+
+  // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+  // CHECK: return %[[OUT]]
+  return %out : !test.test_tensor<[42], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 78e44c6ec7a9b..b64d3b7230b36 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1410,3 +1410,37 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
 
   return mlir::success();
 }
+
+::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
+    ::mlir::RewriterBase &rewriter,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::mlir::bufferization::BufferizationState &state) {
+  // Compute the result buffer type; bufferization::getBufferType() dispatches
+  // to TestCreateTensorOp::getBufferType() for this op's result.
+  const auto bufferizedOutType =
+      mlir::bufferization::getBufferType(getOutput(), options, state);
+  if (mlir::failed(bufferizedOutType))
+    return failure();
+
+  // Replace this op with its memref counterpart (test.create_memref_op).
+  auto createMemrefOp =
+      rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
+
+  mlir::bufferization::replaceOpWithBufferizedValues(
+      rewriter, getOperation(), createMemrefOp.getResult());
+
+  return mlir::success();
+}
+
+mlir::FailureOr<mlir::bufferization::BufferLikeType>
+test::TestCreateTensorOp::getBufferType(
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    const mlir::bufferization::BufferizationState &,
+    llvm::SmallVector<::mlir::Value> &) {
+  const auto type = dyn_cast<test::TestTensorType>(value.getType());
+  if (type == nullptr)
+    return failure();
+
+  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
+      getContext(), type.getShape(), type.getElementType(), nullptr));
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 79bcd9c2e0a9a..2a4de535b0841 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3606,4 +3606,57 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
   );
 }
 
+def TestCreateTensorOp : TEST_Op<"create_tensor_op", [BufferizableOpInterface]> {
+  let arguments = (ins);
+  let results = (outs Arg<TestTensorType>:$output);
+  let extraClassDeclaration = [{
+    // BufferizableOpInterface
+    bool bufferizesToMemoryRead(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToMemoryWrite(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToAllocation(mlir::Value value);
+
+    mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
+      const mlir::bufferization::AnalysisState&);
+
+    mlir::LogicalResult bufferize(
+      mlir::RewriterBase& rewriter,
+      const mlir::bufferization::BufferizationOptions& options,
+      mlir::bufferization::BufferizationState &state);
+
+    mlir::FailureOr<mlir::bufferization::BufferLikeType> getBufferType(
+      mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+      const mlir::bufferization::BufferizationState &,
+      llvm::SmallVector<::mlir::Value> &);
+  }];
+
+  let extraClassDefinition = [{
+    bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestCreateTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestCreateTensorOp::bufferizesToAllocation(mlir::Value value) {
+      return false;
+    }
+
+    ::mlir::bufferization::AliasingValueList
+    test::TestCreateTensorOp::getAliasingValues(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
+  }];
+}
+
+def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
+  let arguments = (ins);
+  let results = (outs Arg<TestMemrefType>:$output);
+}
+
 #endif // TEST_OPS



More information about the Mlir-commits mailing list