[Mlir-commits] [mlir] 04ff600 - [mlir][tensor][bufferize] Implement getBufferType for Expand/CollapseShapeOp
Matthias Springer
llvmlistbot at llvm.org
Wed Sep 21 02:32:14 PDT 2022
Author: Matthias Springer
Date: 2022-09-21T18:31:59+09:00
New Revision: 04ff6009fcaf2f920c2ccadfc882d51cd00eb928
URL: https://github.com/llvm/llvm-project/commit/04ff6009fcaf2f920c2ccadfc882d51cd00eb928
DIFF: https://github.com/llvm/llvm-project/commit/04ff6009fcaf2f920c2ccadfc882d51cd00eb928.diff
LOG: [mlir][tensor][bufferize] Implement getBufferType for Expand/CollapseShapeOp
This function must be implemented for all ops whose result memref type is different from the input memref type.
Differential Revision: https://reviews.llvm.org/D134331
Added:
Modified:
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/SCF/one-shot-bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 2a3b1ee924f49..2ea19bbb0216c 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,7 +1342,13 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
"ArrayRef<ReassociationIndices>":$reassociation)>
];
- let extraClassDeclaration = commonExtraClassDeclaration;
+
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ static FailureOr<MemRefType> computeExpandedType(
+ MemRefType srcType, ArrayRef<int64_t> resultShape,
+ ArrayRef<ReassociationIndices> reassociation);
+ }];
+
let hasVerifier = 1;
}
@@ -1389,6 +1395,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
Note: This op currently assumes that the inner strides are of the
source/result layout map are the faster-varying ones.
}];
+
let builders = [
// Builders for a contracting reshape whose result type is computed from
// `src` and `reassociation`.
@@ -1422,12 +1429,16 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
build($_builder, $_state, resultType, src, reassociationMaps, attrs);
}]>
];
+
let extraClassDeclaration = commonExtraClassDeclaration # [{
/// Return `true` if this source MemRef type is guaranteed to be collapsible
/// according to the given reassociation indices. In the presence of dynamic
/// strides this is usually not the case.
static bool isGuaranteedCollapsible(
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+
+ static MemRefType computeCollapsedType(
+ MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index f1c61c49473b5..227352b28de5a 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1801,9 +1801,9 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
-static FailureOr<MemRefType>
-computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
- ArrayRef<ReassociationIndices> reassociation) {
+FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
+ MemRefType srcType, ArrayRef<int64_t> resultShape,
+ ArrayRef<ReassociationIndices> reassociation) {
if (srcType.getLayout().isIdentity()) {
// If the source is contiguous (i.e., no layout map specified), so is the
// result.
@@ -1827,7 +1827,7 @@ void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
// Only ranked memref source values are supported.
auto srcType = src.getType().cast<MemRefType>();
FailureOr<MemRefType> resultType =
- computeExpandedType(srcType, resultShape, reassociation);
+ ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
// Failure of this assertion usually indicates a problem with the source
// type, e.g., could not get strides/offset.
assert(succeeded(resultType) && "could not compute layout");
@@ -1846,7 +1846,7 @@ LogicalResult ExpandShapeOp::verify() {
return failure();
// Compute expected result type (including layout map).
- FailureOr<MemRefType> expectedResultType = computeExpandedType(
+ FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
srcType, resultType.getShape(), getReassociationIndices());
if (failed(expectedResultType))
return emitOpError("invalid source layout map");
@@ -1943,9 +1943,8 @@ bool CollapseShapeOp::isGuaranteedCollapsible(
/*strict=*/true));
}
-static MemRefType
-computeCollapsedType(MemRefType srcType,
- ArrayRef<ReassociationIndices> reassociation) {
+MemRefType CollapseShapeOp::computeCollapsedType(
+ MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
SmallVector<int64_t> resultShape;
resultShape.reserve(reassociation.size());
for (const ReassociationIndices &group : reassociation) {
@@ -1979,7 +1978,8 @@ void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto srcType = src.getType().cast<MemRefType>();
- MemRefType resultType = computeCollapsedType(srcType, reassociation);
+ MemRefType resultType =
+ CollapseShapeOp::computeCollapsedType(srcType, reassociation);
build(b, result, resultType, src, attrs);
result.addAttribute(::mlir::getReassociationAttrName(),
getReassociationIndicesAttribute(b, reassociation));
@@ -2039,9 +2039,9 @@ struct CollapseShapeOpMemRefCastFolder
if (!CastOp::canFoldIntoConsumerOp(cast))
return failure();
- Type newResultType =
- computeCollapsedType(cast.getOperand().getType().cast<MemRefType>(),
- op.getReassociationIndices());
+ Type newResultType = CollapseShapeOp::computeCollapsedType(
+ cast.getOperand().getType().cast<MemRefType>(),
+ op.getReassociationIndices());
if (newResultType == op.getResultType()) {
rewriter.updateRootInPlace(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c6945ff0a067f..737e64258db55 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -109,6 +109,29 @@ struct CollapseShapeOpInterface
return BufferRelation::Equivalent;
}
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+ auto maybeSrcBufferType = bufferization::getBufferType(
+ collapseShapeOp.getSrc(), options, fixedTypes);
+ if (failed(maybeSrcBufferType))
+ return failure();
+ auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+ bool canBeCollapsed = memref::CollapseShapeOp::isGuaranteedCollapsible(
+ srcBufferType, collapseShapeOp.getReassociationIndices());
+
+ if (!canBeCollapsed) {
+ // If dims cannot be collapsed, this op bufferizes to a new allocation.
+ RankedTensorType tensorResultType = collapseShapeOp.getResultType();
+ return bufferization::getMemRefTypeWithStaticIdentityLayout(
+ tensorResultType, srcBufferType.getMemorySpaceAsInt());
+ }
+
+ return memref::CollapseShapeOp::computeCollapsedType(
+ srcBufferType, collapseShapeOp.getReassociationIndices());
+ }
+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
@@ -232,6 +255,23 @@ struct ExpandShapeOpInterface
return BufferRelation::Equivalent;
}
+ FailureOr<BaseMemRefType>
+ getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+ const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+ auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+ auto maybeSrcBufferType = bufferization::getBufferType(
+ expandShapeOp.getSrc(), options, fixedTypes);
+ if (failed(maybeSrcBufferType))
+ return failure();
+ auto srcBufferType = maybeSrcBufferType->cast<MemRefType>();
+ auto maybeResultType = memref::ExpandShapeOp::computeExpandedType(
+ srcBufferType, expandShapeOp.getResultType().getShape(),
+ expandShapeOp.getReassociationIndices());
+ if (failed(maybeResultType))
+ return failure();
+ return *maybeResultType;
+ }
+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
index ea2b8dbdc2bd9..7a15d9b43c58b 100644
--- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir
@@ -858,3 +858,21 @@ func.func @non_tensor_for_arg(%A : tensor<?xf32> {bufferization.writable = true}
}
return %r1#1 : tensor<?xf32>
}
+
+// -----
+
+// This is a regression test. Just check that the IR bufferizes.
+
+// CHECK-LABEL: func @buffer_type_of_collapse_shape
+func.func @buffer_type_of_collapse_shape(%arg0: tensor<f64>) {
+ %true = arith.constant true
+ %0 = scf.while (%arg1 = %arg0) : (tensor<f64>) -> (tensor<f64>) {
+ scf.condition(%true) %arg1 : tensor<f64>
+ } do {
+ ^bb0(%_: tensor<f64>):
+ %3 = bufferization.alloc_tensor() : tensor<1xf64>
+ %16 = tensor.collapse_shape %3 [] : tensor<1xf64> into tensor<f64>
+ scf.yield %16 : tensor<f64>
+ }
+ return
+}
More information about the Mlir-commits
mailing list