[Mlir-commits] [mlir] [mlir][bufferization] Return BufferLikeType in BufferizableOpInterface (PR #144867)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jun 19 03:33:23 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Andrei Golubev (andrey-golubev)
<details>
<summary>Changes</summary>
Support custom types (part 2/N): allow value-owning operations (e.g., allocation ops) to bufferize custom tensor types into custom buffer types. This requires `BufferizableOpInterface::getBufferType()` to return `BufferLikeType` instead of `BaseMemRefType`.
All affected implementations of the interface are updated accordingly.
---
Patch is 32.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144867.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+1-1)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td (+1-1)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td (+3-3)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h (+1)
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h (+2-2)
- (modified) mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp (+4-4)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-6)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-2)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp (+8-7)
- (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+28-33)
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+28-24)
- (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir (+21-2)
- (modified) mlir/test/lib/Dialect/Test/TestOpDefs.cpp (+34)
- (modified) mlir/test/lib/Dialect/Test/TestOps.td (+53)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
/// This is the default implementation of
/// BufferizableOpInterface::getBufferType. Should not be called from other
/// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
defaultGetBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
Note: This interface method should never be called directly from user
code. Always use `bufferization::getBufferType`.
}],
- /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+ /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
/*methodName=*/"getBufferType",
/*args=*/(ins "::mlir::Value":$value,
"const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
AliasingValueList getAliasingValues(
OpOperand &opOperand, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
bool isWritable(Value value, const AnalysisState &state);
- FailureOr<BaseMemRefType> getBufferType(
+ FailureOr<BufferLikeType> getBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &state, SmallVector<Value> &invocationStack) {
- return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+ return getBuffer().getType();
}
}];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
// Bufferization Type Interfaces
//===----------------------------------------------------------------------===//
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Types.h"
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
: public BufferizableOpInterface::ExternalModel<ConcreteModel, ConcreteOp> {
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
if (!bufferType)
return op->emitOpError("could not infer buffer type of block argument");
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
if (failed(trueType) || failed(falseType))
return failure();
if (*trueType == *falseType)
- return *trueType;
+ return cast<BufferLikeType>(*trueType);
if (trueType->getMemorySpace() != falseType->getMemorySpace())
return op->emitError("inconsistent memory space on true/false operands");
// If the buffers have different types, they differ only in their layout
// map.
auto memrefType = llvm::cast<MemRefType>(*trueType);
- return getMemRefTypeWithFullyDynamicLayout(
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
RankedTensorType::get(memrefType.getShape(),
memrefType.getElementType()),
- memrefType.getMemorySpace());
+ memrefType.getMemorySpace()));
}
};
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
return AliasingOpOperandList(std::move(result));
}
-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
Value value, const BufferizationOptions &options,
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
- if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(tensorType, options);
+ if (llvm::isa<BlockArgument>(value)) {
+ return cast<BufferLikeType>(
+ bufferization::getMemRefType(tensorType, options));
+ }
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
// If the OpResult has an equivalent OpOperand, both OpResult and
// OpOperand bufferize to the exact same buffer type.
Value equivalentOperand = aliases.getAliases().front().opOperand->get();
- return asMemRefType(getBufferType(equivalentOperand, options,
- bufferizationState, invocationStack));
+ return getBufferType(equivalentOperand, options, bufferizationState,
+ invocationStack);
}
// If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+ return cast<BufferLikeType>(
+ getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
return {};
}
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
return getOperation()->emitError("could not infer memory space");
}
- return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+ return cast<BufferLikeType>(
+ getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
}
LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
return result;
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
Type resultType =
funcType.getResult(cast<OpResult>(value).getResultNumber());
if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
- return bufferizedType;
+ return cast<BufferLikeType>(bufferizedType);
// Otherwise, call the type converter to compute the bufferized type.
auto tensorType = cast<TensorType>(resultType);
- return options.functionArgTypeConverterFn(
- tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+ return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+ tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+ options));
}
/// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface
// Function arguments are special.
if (bbArg.getOwner() == &funcOp.getBody().front())
- return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
- options);
+ return cast<BufferLikeType>(
+ getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));
return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface
// Best case: Both branches have the exact same buffer type.
if (thenBufferType == elseBufferType)
- return thenBufferType;
+ return cast<BufferLikeType>(thenBufferType);
// Memory space mismatch.
if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
return op->emitError("inconsistent memory space on then/else branches");
// Layout maps are different: Promote to fully dynamic layout map.
- return getMemRefTypeWithFullyDynamicLayout(
- cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
}
};
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
cast<TensorType>(value.getType()), bufferType.getMemorySpace());
}
- return bufferType;
+ return cast<BufferLikeType>(bufferType);
}
};
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
/// If both buffer types are equal, no casts are needed the computed buffer type
/// can be used directly. Otherwise, the buffer types can only differ in their
/// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
const BufferizationOptions &options, const BufferizationState &state,
SmallVector<Value> &invocationStack) {
// Determine the buffer type of the init_arg.
- auto initArgBufferType = bufferization::detail::asMemRefType(
- bufferization::getBufferType(initArg, options, state, invocationStack));
+ auto initArgBufferType =
+ bufferization::getBufferType(initArg, options, state, invocationStack);
if (failed(initArgBufferType))
return failure();
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
}
// Compute the buffer type of the yielded value.
- BaseMemRefType yieldedValueBufferType;
+ BufferLikeType yieldedValueBufferType;
if (isa<BaseMemRefType>(yieldedValue.getType())) {
// scf.yield was already bufferized.
- yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+ yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
} else {
// Note: This typically triggers a recursive call for the buffer type of
// the iter_arg.
- auto maybeBufferType =
- bufferization::detail::asMemRefType(bufferization::getBufferType(
- yieldedValue, options, state, invocationStack));
+ auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+ state, invocationStack);
if (failed(maybeBufferType))
return failure();
yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
"expected same shape");
}
#endif // NDEBUG
- return getMemRefTypeWithFullyDynamicLayout(
- iterTensorType, yieldedBufferType.getMemorySpace());
+ return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+ iterTensorType, yieldedBufferType.getMemorySpace()));
}
/// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
if (auto opResult = dyn_cast<OpResult>(value)) {
// The type of an OpResult must match the corresponding iter_arg type.
BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
- auto bufferType =
- bufferization::getBufferType(bbArg, options, state, invocationStack);
- if (failed(bufferType))
- return failure();
- assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
- return cast<BaseMemRefType>(*bufferType);
+ return bufferization::getBufferType(bbArg, options, state,
+ invocationStack);
}
// Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
if (!isa<TensorType>(conditionYieldedVal.getType())) {
// scf.condition was already bufferized.
- return cast<BaseMemRefType>(conditionYieldedVal.getType());
+ return cast<BufferLikeType>(conditionYieldedVal.getType());
}
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
- conditionYieldedVal, options, state, invocationStack));
+ return bufferization::getBufferType(conditionYieldedVal, options, state,
+ invocationStack);
}
/// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
return success();
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Operation *op, Value value, const BufferizationOptions &options,
const BufferizationState &state,
SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
if (auto bbArg = dyn_cast<BlockArgument>(value))
// A tensor block argument has the same bufferized type as the
// corresponding output operand.
- return bufferization::detail::asMemRefType(
- bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
- options, state, invocationStack));
+ return bufferization::getBufferType(
+ forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+ invocationStack);
// The bufferized result type is the same as the bufferized type of the
// corresponding output operand.
- return bufferization::detail::asMemRefType(bufferization::getBufferType(
+ return bufferization::getBufferType(
forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
- state, invocationStack));
+ state, invocationStack);
}
bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
return {{op->getResult(0), BufferRelation::Equivalent}};
}
- FailureOr<BaseMemRefType>
+ FailureOr<BufferLikeType>
getBufferType(Opera...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/144867
More information about the Mlir-commits mailing list