[Mlir-commits] [mlir] [mlir][bufferization] Use Type instead of Value in unknown conversion (PR #144658)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jun 18 02:21:33 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-sparse
Author: Andrei Golubev (andrey-golubev)
<details>
<summary>Changes</summary>
Generally, bufferization should be able to create a memref from a tensor without needing to know more than just an mlir::Type. Thus, change BufferizationOptions::UnknownTypeConverterFn to accept just a type (mlir::TensorType for now) instead of an mlir::Value. Additionally, apply the same rationale to the getMemRefType() helper function.
Both changes are prerequisites to enabling custom type support in one-shot bufferization.
---
Full diff: https://github.com/llvm/llvm-project/pull/144658.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+5-4)
- (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+9-10)
- (modified) mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp (+2-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index adccbef754ec5..2fb795f16ae2c 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -265,9 +265,9 @@ struct BufferizationOptions {
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
func::FuncOp, const BufferizationOptions &)>;
/// Tensor -> MemRef type converter.
- /// Parameters: Value, memory space, bufferization options
+ /// Parameters: tensor type, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
- Value, Attribute memorySpace, const BufferizationOptions &)>;
+ TensorType, Attribute memorySpace, const BufferizationOptions &)>;
// Produce a MemorySpace attribute from a tensor type
using DefaultMemorySpaceFn =
std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
-/// Return a MemRefType to which the type of the given value can be bufferized.
+/// Return a MemRefType to which the TensorType can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
@@ -667,7 +667,8 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
+BaseMemRefType getMemRefType(TensorType tensorType,
+ const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = nullptr);
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 1d6e1bdaf80f5..dd43647682ea2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -345,10 +345,9 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
}
/// Default unknown type converter: Use a fully dynamic layout map.
BaseMemRefType
-defaultUnknownTypeConverter(Value value, Attribute memorySpace,
+defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithFullyDynamicLayout(
- llvm::cast<TensorType>(value.getType()), memorySpace);
+ return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
}
} // namespace
@@ -724,7 +723,8 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return getMemRefType(cast<TensorType>(value.getType()), options,
+ /*layout=*/{}, *memSpace);
}
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,12 +797,10 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
// Bufferization-specific IRMapping support with debugging.
//===----------------------------------------------------------------------===//
-BaseMemRefType bufferization::getMemRefType(Value value,
+BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
Attribute memorySpace) {
- auto tensorType = llvm::cast<TensorType>(value.getType());
-
// Case 1: Unranked memref type.
if (auto unrankedTensorType =
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -819,7 +817,7 @@ BaseMemRefType bufferization::getMemRefType(Value value,
memorySpace);
}
- return options.unknownTypeConverterFn(value, memorySpace, options);
+ return options.unknownTypeConverterFn(tensorType, memorySpace, options);
}
BaseMemRefType
@@ -955,10 +953,11 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
const BufferizationState &bufferizationState,
SmallVector<Value> &invocationStack) {
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
+ auto tensorType = cast<TensorType>(value.getType());
// No further analysis is possible for a block argument.
if (llvm::isa<BlockArgument>(value))
- return bufferization::getMemRefType(value, options);
+ return bufferization::getMemRefType(tensorType, options);
// Value is an OpResult.
Operation *op = getOwnerOfValue(value);
@@ -981,7 +980,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
if (!memSpace.has_value())
return op->emitError("could not infer memory space");
- return getMemRefType(value, options, /*layout=*/{}, *memSpace);
+ return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
}
bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index c7681d309a4af..7e9b9119ce949 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
"'unknown-type-conversion'");
return signalPassFailure();
}
- opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
+ opt.unknownTypeConverterFn = [=](TensorType tensorType,
+ Attribute memorySpace,
const BufferizationOptions &options) {
- auto tensorType = cast<TensorType>(value.getType());
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index a3ab53d818115..15e5102462ad7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
OneShotBufferizationOptions options;
options.bufferizeFunctionBoundaries = true;
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
- options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
+ options.unknownTypeConverterFn = [](TensorType tensorType,
+ Attribute memorySpace,
const BufferizationOptions &options) {
- return getMemRefTypeWithStaticIdentityLayout(
- cast<TensorType>(value.getType()), memorySpace);
+ return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
};
if (analysisOnly) {
options.testAnalysisOnly = true;
``````````
</details>
https://github.com/llvm/llvm-project/pull/144658
More information about the Mlir-commits
mailing list