[Mlir-commits] [mlir] 12e41d9 - [mlir][bufferize] Infer memref types when possible
Matthias Springer
llvmlistbot at llvm.org
Sun May 15 17:02:24 PDT 2022
Author: Matthias Springer
Date: 2022-05-16T02:02:08+02:00
New Revision: 12e41d9264b6f84213be86aab75016fb82ebc1d1
URL: https://github.com/llvm/llvm-project/commit/12e41d9264b6f84213be86aab75016fb82ebc1d1
DIFF: https://github.com/llvm/llvm-project/commit/12e41d9264b6f84213be86aab75016fb82ebc1d1.diff
LOG: [mlir][bufferize] Infer memref types when possible
Instead of recomputing memref types from tensor types, try to infer them when possible. This results in more precise layout maps.
Differential Revision: https://reviews.llvm.org/D125614
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index bb9ec01380e4e..a6233295cb650 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -503,6 +503,9 @@ struct BufferizationState {
Optional<Operation *> customCopyInsertionPoint = None);
/// Return the buffer type for a given OpOperand (tensor) after bufferization.
+ ///
+ ///   Note: Op implementations should preferably call `getBuffer()->getType()`.
+ /// This function should only be used if `getBuffer` cannot be used.
BaseMemRefType getBufferType(OpOperand &opOperand) const;
/// Return a reference to the BufferizationOptions.
@@ -546,9 +549,18 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
-/// Return a MemRefType to which the `tensorType` can be bufferized in a
-/// composable fashion. The layout must be the most dynamic possible and
-/// canonicalize away once bufferization is finished.
+/// Return a MemRefType to which the `tensorType` can be bufferized.
+///
+/// If possible, op bufferization implementations should not use this function
+/// and instead infer precise memref types for tensor results by themselves.
+///
+/// Unless a layout map was specified, `options` flags determine what kind of
+/// layout map will be used. For best composability (without copies), the fully
+/// dynamic layout map is used by default.
+///
+/// Note: Canonicalization patterns could clean up layout maps and infer more
+/// precise layout maps after bufferization. However, many possible
+/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
index 4f1add5a899b9..3bdd3d92cfdc0 100644
--- a/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -82,17 +82,22 @@ struct IndexCastOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const {
auto castOp = cast<arith::IndexCastOp>(op);
+ auto resultTensorType = castOp.getType().cast<TensorType>();
Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/);
auto sourceType = source.getType().cast<BaseMemRefType>();
// Result type should have same layout and address space as the source type.
- MemRefLayoutAttrInterface layout = {};
- if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>())
- layout = rankedMemRefType.getLayout();
- Type resultType =
- getMemRefType(castOp.getType().cast<TensorType>(), state.getOptions(),
- layout, sourceType.getMemorySpace());
+ BaseMemRefType resultType;
+ if (auto rankedMemRefType = sourceType.dyn_cast<MemRefType>()) {
+ resultType = MemRefType::get(
+ rankedMemRefType.getShape(), resultTensorType.getElementType(),
+ rankedMemRefType.getLayout(), rankedMemRefType.getMemorySpace());
+ } else {
+ auto unrankedMemrefType = sourceType.cast<UnrankedMemRefType>();
+ resultType = UnrankedMemRefType::get(resultTensorType.getElementType(),
+ unrankedMemrefType.getMemorySpace());
+ }
replaceOpWithNewBufferizedOp<arith::IndexCastOp>(rewriter, op, resultType,
source);
@@ -146,15 +151,14 @@ struct SelectOpInterface
// both of them to the most dynamic MemRef type.
if (trueBuffer.getType() != falseBuffer.getType()) {
auto trueType = trueBuffer.getType().cast<MemRefType>();
- auto tensorType = selectOp.getTrueValue().getType().cast<TensorType>();
int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
- SmallVector<int64_t> dynamicStrides(tensorType.getRank(),
+ SmallVector<int64_t> dynamicStrides(trueType.getRank(),
ShapedType::kDynamicStrideOrOffset);
AffineMap stridedLayout = makeStridedLinearLayoutMap(
dynamicStrides, dynamicOffset, op->getContext());
- BaseMemRefType castedType = bufferization::getMemRefType(
- tensorType, state.getOptions(), AffineMapAttr::get(stridedLayout),
- trueType.getMemorySpace());
+ auto castedType =
+ MemRefType::get(trueType.getShape(), trueType.getElementType(),
+ stridedLayout, trueType.getMemorySpaceAsInt());
trueBuffer = rewriter.create<memref::CastOp>(loc, castedType, trueBuffer);
falseBuffer =
rewriter.create<memref::CastOp>(loc, castedType, falseBuffer);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 337b0aa57ea3e..638deb6a68c4b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -79,6 +79,7 @@ struct ExecuteRegionOpInterface
SmallVector<Type> newResultTypes;
for (Type type : executeRegionOp->getResultTypes()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
+ // TODO: Infer the result type instead of computing it.
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
} else {
newResultTypes.push_back(type);
@@ -188,6 +189,7 @@ struct IfOpInterface
SmallVector<Type> newTypes;
for (Type returnType : ifOp->getResultTypes()) {
if (auto tensorType = returnType.dyn_cast<TensorType>()) {
+ // TODO: Infer the result type instead of computing it.
newTypes.push_back(getMemRefType(tensorType, state.getOptions()));
} else {
newTypes.push_back(returnType);
diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
index ce3c85e6454e0..ac868a3723a78 100644
--- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -66,6 +66,7 @@ struct AssumingOpInterface
SmallVector<Type> newResultTypes;
for (Type type : assumingOp->getResultTypes()) {
if (auto tensorType = type.dyn_cast<TensorType>()) {
+ // TODO: Infer the result type instead of computing it.
newResultTypes.push_back(getMemRefType(tensorType, state.getOptions()));
} else {
newResultTypes.push_back(type);
More information about the Mlir-commits
mailing list