[Mlir-commits] [mlir] [mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (PR #144867)
Andrei Golubev
llvmlistbot at llvm.org
Thu Jun 19 03:32:51 PDT 2025
https://github.com/andrey-golubev created https://github.com/llvm/llvm-project/pull/144867
Support custom types (2/N): allow value-owning operations (e.g. allocation ops) to bufferize custom tensors into custom buffers. This requires BufferizableOpInterface::getBufferType() to return BufferLikeType instead of BaseMemRefType.
Affected implementors of the interface are updated accordingly.
>From 57b09078fef7be7a9395d4e144a2dcd9dae49fb8 Mon Sep 17 00:00:00 2001
From: "Golubev, Andrey" <andrey.golubev at intel.com>
Date: Thu, 19 Jun 2025 10:29:41 +0000
Subject: [PATCH] [mlir][bufferization] Return BufferLikeType in
BufferizableOpInterface
Support custom types (2/N): allow value-owning operations (e.g.
allocation ops) to bufferize custom tensors into custom buffers. This requires
BufferizableOpInterface::getBufferType() to return BufferLikeType
instead of BaseMemRefType.
Affected implementors of the interface are updated accordingly.
---
.../IR/BufferizableOpInterface.h | 2 +-
.../IR/BufferizableOpInterface.td | 2 +-
.../Bufferization/IR/BufferizationOps.td | 6 +-
.../IR/BufferizationTypeInterfaces.h | 1 +
.../IR/UnstructuredControlFlow.h | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 8 +--
.../IR/BufferizableOpInterface.cpp | 15 +++--
.../Bufferization/IR/BufferizationOps.cpp | 5 +-
.../FuncBufferizableOpInterfaceImpl.cpp | 15 ++---
.../BufferizableOpInterfaceImpl.cpp | 61 +++++++++----------
.../BufferizableOpInterfaceImpl.cpp | 52 ++++++++--------
.../Transforms/one-shot-bufferize.mlir | 23 ++++++-
mlir/test/lib/Dialect/Test/TestOpDefs.cpp | 34 +++++++++++
mlir/test/lib/Dialect/Test/TestOps.td | 53 ++++++++++++++++
14 files changed, 196 insertions(+), 85 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+ return getBuffer().getType();
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (!bufferType)
return op->emitOpError("could not infer buffer type of block argument");
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
- return *trueType;
+ return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
return op->emitError("inconsistent memory space on true/false operands");
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
- memrefType.getMemorySpace());
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
- if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(tensorType, options);
+ if (llvm::isa<BlockArgument>(value)) {
+ return cast<BufferLikeType>(
+ bufferization::getMemRefType(tensorType, options));
+ }
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return asMemRefType(getBufferType(equivalentOperand, options,
- bufferizationState, invocationStack));
+ return getBufferType(equivalentOperand, options, bufferizationState,
+ invocationStack);
}
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+ return cast<BufferLikeType>(
+ getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
- return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
return result;
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return bufferizedType;
+ return cast<BufferLikeType>(bufferizedType);
// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
- return options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options));
}
/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
- options);
+ return cast<BufferLikeType>(
+ getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
// Best case: Both branches have the exact same buffer type.
if (thenBufferType == elseBufferType)
- return thenBufferType;
+ return cast<BufferLikeType>(thenBufferType);
// Memory space mismatch.
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
- return getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
}
};
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
}
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
};
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
- auto initArgBufferType = bufferization::detail::asMemRefType(
- bufferization::getBufferType(initArg, options, state, invocationStack));
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}
// Compute the buffer type of the yielded value.
- BaseMemRefType yieldedValueBufferType;
+ BufferLikeType yieldedValueBufferType;
if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+ yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- yieldedValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+ state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- iterTensorType, yieldedBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ iterTensorType, yieldedBufferType.getMemorySpace()));
}
/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- auto bufferType =
- bufferization::getBufferType(bbArg, options, state, invocationStack);
- if (failed(bufferType))
- return failure();
- assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
- return cast<BaseMemRefType>(*bufferType);
+ return bufferization::getBufferType(bbArg, options, state,
+ invocationStack);
}
// Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return cast<BaseMemRefType>(conditionYieldedVal.getType());
+ return cast<BufferLikeType>(conditionYieldedVal.getType());
}
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
- conditionYieldedVal, options, state, invocationStack));
+ return bufferization::getBufferType(conditionYieldedVal, options, state,
+ invocationStack);
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
- return bufferization::detail::asMemRefType(
- bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
- options, state, invocationStack));
+ return bufferization::getBufferType(
+ forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+ invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
+ return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- state, invocationStack));
+ state, invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
return {{op->getResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -68,20 +68,22 @@ struct CastOpInterface
if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
// When casting to a ranked tensor, we cannot infer any static offset or
// strides from the source. Assume fully dynamic.
- return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
}
// Case 2: Casting to an unranked tensor type
if (isa<UnrankedTensorType>(castOp.getType())) {
- return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
}
// Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
// change.
auto rankedResultType = cast<RankedTensorType>(castOp.getType());
- return MemRefType::get(
+ return cast<BufferLikeType>(MemRefType::get(
rankedResultType.getShape(), rankedResultType.getElementType(),
- llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
+ llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -141,7 +143,7 @@ struct CollapseShapeOpInterface
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -157,12 +159,13 @@ struct CollapseShapeOpInterface
if (!canBeCollapsed) {
// If dims cannot be collapsed, this op bufferizes to a new allocation.
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
- return bufferization::getMemRefTypeWithStaticIdentityLayout(
- tensorResultType, srcBufferType.getMemorySpace());
+ return cast<BufferLikeType>(
+ bufferization::getMemRefTypeWithStaticIdentityLayout(
+ tensorResultType, srcBufferType.getMemorySpace()));
}
- return memref::CollapseShapeOp::computeCollapsedType(
- srcBufferType, collapseShapeOp.getReassociationIndices());
+ return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
+ srcBufferType, collapseShapeOp.getReassociationIndices()));
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -319,7 +322,7 @@ struct ExpandShapeOpInterface
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -334,7 +337,7 @@ struct ExpandShapeOpInterface
expandShapeOp.getReassociationIndices());
if (failed(maybeResultType))
return failure();
- return *maybeResultType;
+ return cast<BufferLikeType>(*maybeResultType);
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -404,7 +407,7 @@ struct ExtractSliceOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -417,10 +420,10 @@ struct ExtractSliceOpInterface
SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
- return memref::SubViewOp::inferRankReducedResultType(
+ return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
extractSliceOp.getType().getShape(),
llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
- mixedStrides);
+ mixedStrides));
}
};
@@ -501,8 +504,8 @@ struct FromElementsOpInterface
/*copy=*/false);
if (failed(tensorAlloc))
return failure();
- FailureOr<BaseMemRefType> memrefType = bufferization::detail::asMemRefType(
- bufferization::getBufferType(*tensorAlloc, options, state));
+ FailureOr<BufferLikeType> memrefType =
+ bufferization::getBufferType(*tensorAlloc, options, state);
if (failed(memrefType))
return failure();
Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -753,7 +756,7 @@ struct PadOpInterface
return {};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -765,9 +768,10 @@ struct PadOpInterface
if (failed(maybeSrcBufferType))
return failure();
MemRefLayoutAttrInterface layout;
- return MemRefType::get(padOp.getResultType().getShape(),
- padOp.getResultType().getElementType(), layout,
- maybeSrcBufferType->getMemorySpace());
+ return cast<BufferLikeType>(
+ MemRefType::get(padOp.getResultType().getShape(),
+ padOp.getResultType().getElementType(), layout,
+ maybeSrcBufferType->getMemorySpace()));
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -927,7 +931,7 @@ struct ReshapeOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -937,9 +941,9 @@ struct ReshapeOpInterface
reshapeOp.getSource(), options, state, invocationStack);
if (failed(maybeSourceBufferType))
return failure();
- return getMemRefTypeWithStaticIdentityLayout(
+ return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
reshapeOp.getResult().getType(),
- cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
+ cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
}
};
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index da3c26ce36ba5..8031732011839 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -272,10 +272,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x
// -----
-// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK: func.func @custom_op(
// CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
// CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> {
-func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+func.func @custom_op(%arg: !test.test_tensor<[32, 64], f64>)
-> !test.test_tensor<[32, 128], f64> {
// CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
// CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
@@ -288,3 +288,22 @@ func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
// CHECK: return %[[OUT]]
return %out : !test.test_tensor<[32, 128], f64>
}
+
+// -----
+
+// CHECK: func.func @custom_origin_op()
+// CHECK-SAME: -> !test.test_tensor<[42], f64> {
+func.func @custom_origin_op() -> !test.test_tensor<[42], f64> {
+ // CHECK: %[[MEMREF:.*]] = "test.create_memref_op"() : ()
+ // CHECK-SAME: -> !test.test_memref<[21], f64>
+ // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+ // CHECK-SAME: : (!test.test_memref<[21], f64>)
+ // CHECK-SAME: -> !test.test_memref<[42], f64>
+ %in = "test.create_tensor_op"() : () -> !test.test_tensor<[21], f64>
+ %out = "test.dummy_tensor_op"(%in) : (!test.test_tensor<[21], f64>)
+ -> !test.test_tensor<[42], f64>
+
+ // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+ // CHECK: return %[[OUT]]
+ return %out : !test.test_tensor<[42], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 78e44c6ec7a9b..b64d3b7230b36 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1410,3 +1410,37 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(
return mlir::success();
}
+
+::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
+ ::mlir::RewriterBase &rewriter,
+ const ::mlir::bufferization::BufferizationOptions &options,
+ ::mlir::bufferization::BufferizationState &state) {
+ // Note: mlir::bufferization::getBufferType() would internally call
+ // TestCreateTensorOp::getBufferType()
+ const auto bufferizedOutType =
+ mlir::bufferization::getBufferType(getOutput(), options, state);
+ if (mlir::failed(bufferizedOutType))
+ return failure();
+
+ // replace op with memref analogy
+ auto createMemrefOp =
+ rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
+
+ mlir::bufferization::replaceOpWithBufferizedValues(
+ rewriter, getOperation(), createMemrefOp.getResult());
+
+ return mlir::success();
+}
+
+mlir::FailureOr<mlir::bufferization::BufferLikeType>
+test::TestCreateTensorOp::getBufferType(
+ mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+ const mlir::bufferization::BufferizationState &,
+ llvm::SmallVector<::mlir::Value> &) {
+ const auto type = dyn_cast<test::TestTensorType>(value.getType());
+ if (type == nullptr)
+ return failure();
+
+ return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
+ getContext(), type.getShape(), type.getElementType(), nullptr));
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 79bcd9c2e0a9a..2a4de535b0841 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3606,4 +3606,57 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
);
}
+def TestCreateTensorOp : TEST_Op<"create_tensor_op", [BufferizableOpInterface]> {
+ let arguments = (ins);
+ let results = (outs Arg<TestTensorType>:$output);
+ let extraClassDeclaration = [{
+ // BufferizableOpInterface
+ bool bufferizesToMemoryRead(mlir::OpOperand&,
+ const mlir::bufferization::AnalysisState&);
+
+ bool bufferizesToMemoryWrite(mlir::OpOperand&,
+ const mlir::bufferization::AnalysisState&);
+
+ bool bufferizesToAllocation(mlir::Value value);
+
+ mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
+ const mlir::bufferization::AnalysisState&);
+
+ mlir::LogicalResult bufferize(
+ mlir::RewriterBase& rewriter,
+ const mlir::bufferization::BufferizationOptions& options,
+ mlir::bufferization::BufferizationState &state);
+
+ mlir::FailureOr<mlir::bufferization::BufferLikeType> getBufferType(
+ mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+ const mlir::bufferization::BufferizationState &,
+ llvm::SmallVector<::mlir::Value> &);
+ }];
+
+ let extraClassDefinition = [{
+ bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+ const ::mlir::bufferization::AnalysisState&) {
+ return true;
+ }
+ bool test::TestCreateTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+ const ::mlir::bufferization::AnalysisState&) {
+ return true;
+ }
+ bool test::TestCreateTensorOp::bufferizesToAllocation(mlir::Value value) {
+ return false;
+ }
+
+ ::mlir::bufferization::AliasingValueList
+ test::TestCreateTensorOp::getAliasingValues(::mlir::OpOperand&,
+ const ::mlir::bufferization::AnalysisState&) {
+ return {};
+ }
+ }];
+}
+
+def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
+ let arguments = (ins);
+ let results = (outs Arg<TestMemrefType>:$output);
+}
+
#endif // TEST_OPS
More information about the Mlir-commits
mailing list