[Mlir-commits] [mlir] a63f572 - [mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (#144867)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 2 11:27:38 PDT 2025


Author: Andrei Golubev
Date: 2025-07-02T11:27:35-07:00
New Revision: a63f57262898588b576d66e5fd79c0aa64b35f2d

URL: https://github.com/llvm/llvm-project/commit/a63f57262898588b576d66e5fd79c0aa64b35f2d
DIFF: https://github.com/llvm/llvm-project/commit/a63f57262898588b576d66e5fd79c0aa64b35f2d.diff

LOG: [mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (#144867)

Support custom types (2/N): allow value-owning operations (e.g.
allocation ops) to bufferize custom tensors into custom buffers. This
requires BufferizableOpInterface::getBufferType() to return
BufferLikeType instead of BaseMemRefType.

Affected implementors of the interface are updated accordingly.

Relates to ee070d08163ac09842d9bf0c1315f311df39faf1.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
    mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
    mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
    mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
    mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
    mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
    mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
    mlir/test/lib/Dialect/Test/TestOpDefs.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from user
           code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
         /*methodName=*/"getBufferType",
         /*args=*/(ins "::mlir::Value":$value,
                       "const ::mlir::bufferization::BufferizationOptions &":$options,

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
         OpOperand &opOperand, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state,
         SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
 
     bool isWritable(Value value, const AnalysisState &state);
 
-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
         Value value, const BufferizationOptions &options,
         const BufferizationState &state, SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
     }
   }];
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..a2bfcb7ed2b75 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -18,7 +18,6 @@
 
 namespace mlir::bufferization {
 struct BufferizationOptions;
-class BufferizationState;
 class BufferLikeType;
 } // namespace mlir::bufferization
 

diff  --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 
 protected:

diff  --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");
 
     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
         RankedTensorType::get(memrefType.getShape(),
                               memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
 

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 5e69fa386ecbc..8f17a82fabe03 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -947,7 +947,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }
 
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -955,8 +955,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());
 
   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }
 
   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -968,8 +970,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }
 
   // If we do not know the memory space and there is no default memory space,
@@ -979,7 +981,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }
 
 bool bufferization::detail::defaultIsRepetitiveRegion(

diff  --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }
 
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
     return getOperation()->emitError("could not infer memory space");
   }
 
-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }
 
 LogicalResult AllocTensorOp::verify() {

diff  --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
         funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);
 
     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }
 
   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
 
     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
 
     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
         getBufferType(op, value, options, state, invocationStack);

diff  --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
 
     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);
 
     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");
 
     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
 
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
           cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }
 
-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
 
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
 
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }
 
   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
         "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }
 
 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }
 
     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }
 
   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);
 
     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
         forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }
 
   bool isRepetitiveRegion(Operation *op, unsigned index) const {

diff  --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -68,20 +68,22 @@ struct CastOpInterface
     if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
       // When casting to a ranked tensor, we cannot infer any static offset or
       // strides from the source. Assume fully dynamic.
-      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+      return cast<BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
     }
 
     // Case 2: Casting to an unranked tensor type
     if (isa<UnrankedTensorType>(castOp.getType())) {
-      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+      return cast<BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
     }
 
     // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
     // change.
     auto rankedResultType = cast<RankedTensorType>(castOp.getType());
-    return MemRefType::get(
+    return cast<BufferLikeType>(MemRefType::get(
         rankedResultType.getShape(), rankedResultType.getElementType(),
-        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
+        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -141,7 +143,7 @@ struct CollapseShapeOpInterface
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -157,12 +159,13 @@ struct CollapseShapeOpInterface
     if (!canBeCollapsed) {
       // If dims cannot be collapsed, this op bufferizes to a new allocation.
       RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(
-          tensorResultType, srcBufferType.getMemorySpace());
+      return cast<BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              tensorResultType, srcBufferType.getMemorySpace()));
     }
 
-    return memref::CollapseShapeOp::computeCollapsedType(
-        srcBufferType, collapseShapeOp.getReassociationIndices());
+    return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
+        srcBufferType, collapseShapeOp.getReassociationIndices()));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -319,7 +322,7 @@ struct ExpandShapeOpInterface
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -334,7 +337,7 @@ struct ExpandShapeOpInterface
         expandShapeOp.getReassociationIndices());
     if (failed(maybeResultType))
       return failure();
-    return *maybeResultType;
+    return cast<BufferLikeType>(*maybeResultType);
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -404,7 +407,7 @@ struct ExtractSliceOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -417,10 +420,10 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    return memref::SubViewOp::inferRankReducedResultType(
+    return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
         extractSliceOp.getType().getShape(),
         llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
-        mixedStrides);
+        mixedStrides));
   }
 };
 
@@ -501,8 +504,8 @@ struct FromElementsOpInterface
         /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    FailureOr<BaseMemRefType> memrefType = bufferization::detail::asMemRefType(
-        bufferization::getBufferType(*tensorAlloc, options, state));
+    FailureOr<BufferLikeType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options, state);
     if (failed(memrefType))
       return failure();
     Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -753,7 +756,7 @@ struct PadOpInterface
     return {};
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -765,9 +768,10 @@ struct PadOpInterface
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
-    return MemRefType::get(padOp.getResultType().getShape(),
-                           padOp.getResultType().getElementType(), layout,
-                           maybeSrcBufferType->getMemorySpace());
+    return cast<BufferLikeType>(
+        MemRefType::get(padOp.getResultType().getShape(),
+                        padOp.getResultType().getElementType(), layout,
+                        maybeSrcBufferType->getMemorySpace()));
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -927,7 +931,7 @@ struct ReshapeOpInterface
     return success();
   }
 
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -937,9 +941,9 @@ struct ReshapeOpInterface
         reshapeOp.getSource(), options, state, invocationStack);
     if (failed(maybeSourceBufferType))
       return failure();
-    return getMemRefTypeWithStaticIdentityLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
         reshapeOp.getResult().getType(),
-        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
+        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
   }
 };
 

diff  --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index da3c26ce36ba5..8031732011839 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -272,10 +272,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
 
 // -----
 
-// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK:       func.func @custom_op(
 // CHECK-SAME:    %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
 // CHECK-SAME:  ) -> !test.test_tensor<[32, 128], f64> {
-func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+func.func @custom_op(%arg: !test.test_tensor<[32, 64], f64>)
     -> !test.test_tensor<[32, 128], f64> {
   // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
   // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
@@ -288,3 +288,22 @@ func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
   // CHECK: return %[[OUT]]
   return %out : !test.test_tensor<[32, 128], f64>
 }
+
+// -----
+
+// CHECK:       func.func @custom_origin_op()
+// CHECK-SAME:  -> !test.test_tensor<[42], f64> {
+func.func @custom_origin_op() -> !test.test_tensor<[42], f64> {
+  // CHECK: %[[MEMREF:.*]] = "test.create_memref_op"() : ()
+  // CHECK-SAME: -> !test.test_memref<[21], f64>
+  // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+  // CHECK-SAME: : (!test.test_memref<[21], f64>)
+  // CHECK-SAME: -> !test.test_memref<[42], f64>
+  %in = "test.create_tensor_op"() : () -> !test.test_tensor<[21], f64>
+  %out = "test.dummy_tensor_op"(%in) : (!test.test_tensor<[21], f64>)
+    -> !test.test_tensor<[42], f64>
+
+  // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+  // CHECK: return %[[OUT]]
+  return %out : !test.test_tensor<[42], f64>
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 6c1a5d3441530..3ab4ef2680978 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1420,3 +1420,37 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
 
   return mlir::success();
 }
+
+::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
+    ::mlir::RewriterBase &rewriter,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::mlir::bufferization::BufferizationState &state) {
+  // Note: mlir::bufferization::getBufferType() would internally call
+  // TestCreateTensorOp::getBufferType()
+  const auto bufferizedOutType =
+      mlir::bufferization::getBufferType(getOutput(), options, state);
+  if (mlir::failed(bufferizedOutType))
+    return failure();
+
+  // replace op with memref analogy
+  auto createMemrefOp =
+      rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
+
+  mlir::bufferization::replaceOpWithBufferizedValues(
+      rewriter, getOperation(), createMemrefOp.getResult());
+
+  return mlir::success();
+}
+
+mlir::FailureOr<mlir::bufferization::BufferLikeType>
+test::TestCreateTensorOp::getBufferType(
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    const mlir::bufferization::BufferizationState &,
+    llvm::SmallVector<::mlir::Value> &) {
+  const auto type = dyn_cast<test::TestTensorType>(value.getType());
+  if (type == nullptr)
+    return failure();
+
+  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
+      getContext(), type.getShape(), type.getElementType(), nullptr));
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0ad5bfa9a58ab..5320ba4ea3829 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3615,29 +3615,16 @@ def TestAllocWithMultipleResults : TEST_Op<"alloc_with_multiple_results"> {
 // Test Ops bufferization
 //===----------------------------------------------------------------------===//
 
-def TestDummyTensorOp : TEST_Op<"dummy_tensor_op", [BufferizableOpInterface]> {
+def TestDummyTensorOp : TEST_Op<"dummy_tensor_op",
+    [DeclareOpInterfaceMethods<BufferizableOpInterface,
+        ["bufferize", "bufferizesToMemoryRead",
+         "bufferizesToMemoryWrite", "getAliasingValues"]>]> {
   let arguments = (ins
     Arg<TestTensorType>:$input
   );
   let results = (outs
     Arg<TestTensorType>:$output
   );
-  let extraClassDeclaration = [{
-    // BufferizableOpInterface
-    bool bufferizesToMemoryRead(mlir::OpOperand&,
-      const mlir::bufferization::AnalysisState&);
-
-    bool bufferizesToMemoryWrite(mlir::OpOperand&,
-      const mlir::bufferization::AnalysisState&);
-
-    mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
-      const mlir::bufferization::AnalysisState&);
-
-    mlir::LogicalResult bufferize(
-      mlir::RewriterBase& rewriter,
-      const mlir::bufferization::BufferizationOptions& options,
-      mlir::bufferization::BufferizationState &state);
-  }];
 
   let extraClassDefinition = [{
     bool test::TestDummyTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
@@ -3665,6 +3652,39 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
   );
 }
 
+def TestCreateTensorOp : TEST_Op<"create_tensor_op",
+    [DeclareOpInterfaceMethods<BufferizableOpInterface,
+        ["bufferize", "getBufferType", "bufferizesToMemoryRead",
+         "bufferizesToMemoryWrite", "getAliasingValues",
+         "bufferizesToAllocation"]>]> {
+  let arguments = (ins);
+  let results = (outs Arg<TestTensorType>:$output);
+  let extraClassDefinition = [{
+    bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestCreateTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestCreateTensorOp::bufferizesToAllocation(mlir::Value value) {
+      return false;
+    }
+
+    ::mlir::bufferization::AliasingValueList
+    test::TestCreateTensorOp::getAliasingValues(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
+  }];
+}
+
+def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
+  let arguments = (ins);
+  let results = (outs Arg<TestMemrefType>:$output);
+}
+
 //===----------------------------------------------------------------------===//
 // Test assembly format references
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list