[Mlir-commits] [mlir] d7a9bf9 - [mlir][tensor] Fix verifier and bufferization of collapse_shape
Matthias Springer
llvmlistbot at llvm.org
Fri Apr 8 02:23:08 PDT 2022
Author: Matthias Springer
Date: 2022-04-08T18:20:40+09:00
New Revision: d7a9bf91431a08bf43cc5b7111a043de9defaee9
URL: https://github.com/llvm/llvm-project/commit/d7a9bf91431a08bf43cc5b7111a043de9defaee9
DIFF: https://github.com/llvm/llvm-project/commit/d7a9bf91431a08bf43cc5b7111a043de9defaee9.diff
LOG: [mlir][tensor] Fix verifier and bufferization of collapse_shape
Insert a buffer copy unless the dims are guaranteed to be collapsible. In the verifier, accept collapses unless they are guaranteed to be non-collapsible.
Differential Revision: https://reviews.llvm.org/D123316
Added:
Modified:
mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/test/Dialect/Tensor/bufferize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index cbd83b85a2787..7ccc2480f4be2 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -427,6 +427,8 @@ class AlwaysCopyAnalysisState : public AnalysisState {
/// BufferizationState provides helper functions for performing bufferization
/// rewrites and handling memref buffers.
struct BufferizationState {
+ enum ForceInPlacability { FORCE_INPLACE, FORCE_OUT_OF_PLACE };
+
BufferizationState(const AnalysisState &analysisState)
: analysisState(analysisState) {}
@@ -448,11 +450,19 @@ struct BufferizationState {
/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization was decided.
+ ///
+ /// Whether a buffer is in-place or out-of-place is queried from the analysis
+ /// state. Some analyses may always conservatively opt for out-of-place
+ /// bufferization. Inplacability decisions can be overridden with the optional
+ /// `overrideInPlace` parameter.
FailureOr<Value>
getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
- bool forceInPlace = false,
+ Optional<ForceInPlacability> overrideInPlace = None,
Optional<Operation *> customCopyInsertionPoint = None);
+ /// Return the buffer type for a given OpOperand (tensor) after bufferization.
+ BaseMemRefType getBufferType(OpOperand &opOperand) const;
+
/// Return a reference to the BufferizationOptions.
const BufferizationOptions &getOptions() const {
return analysisState.getOptions();
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 6b727f9183b28..5f7ec96162f88 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1295,7 +1295,14 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
"ArrayRef<ReassociationIndices>":$reassociation)>
];
- let extraClassDeclaration = commonExtraClassDeclaration;
+ let extraClassDeclaration = commonExtraClassDeclaration # [{
+ /// Return `true` if this source MemRef type is guaranteed to be collapsible
+ /// according to the given reassociation indices. In the presence of dynamic
+ /// strides this is usually not the case.
+ static bool isGuaranteedCollapsible(
+ MemRefType srcType, ArrayRef<ReassociationIndices> reassociation);
+ }];
+
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 091462f1ed73a..f2c67ed754a50 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -247,12 +247,12 @@ Value mlir::bufferization::lookupBuffer(RewriterBase &rewriter, Value tensor,
tensor);
}
-/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
+/// Return the buffer (memref) for a given OpOperand (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
-/// bufferization is necessary.
+/// bufferization was decided.
FailureOr<Value>
BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
- bool forceInPlace,
+ Optional<ForceInPlacability> overrideInPlace,
Optional<Operation *> customCopyInsertionPoint) {
const BufferizationOptions &options = analysisState.getOptions();
OpBuilder::InsertionGuard guard(rewriter);
@@ -263,7 +263,11 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
Value operand = opOperand.get();
Value operandBuffer = lookupBuffer(rewriter, operand, options);
- if (forceInPlace || analysisState.isInPlace(opOperand))
+ // Can `operandBuffer` be used directly or do we need a copy?
+ bool inplace =
+ overrideInPlace != FORCE_OUT_OF_PLACE &&
+ (overrideInPlace == FORCE_INPLACE || analysisState.isInPlace(opOperand));
+ if (inplace)
return operandBuffer;
// Bufferizing out-of-place: Allocate a new buffer.
@@ -317,6 +321,18 @@ BufferizationState::getBuffer(RewriterBase &rewriter, OpOperand &opOperand,
return resultBuffer;
}
+/// Return the buffer type for a given OpOperand (tensor) after bufferization.
+BaseMemRefType BufferizationState::getBufferType(OpOperand &opOperand) const {
+ Value tensor = opOperand.get();
+ auto tensorType = tensor.getType().dyn_cast<TensorType>();
+ assert(tensorType && "unexpected non-tensor type");
+
+ if (auto toTensorOp = tensor.getDefiningOp<bufferization::ToTensorOp>())
+ return toTensorOp.memref().getType().cast<BaseMemRefType>();
+
+ return getMemRefType(tensorType, getOptions());
+}
+
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
Operation *op,
ValueRange values) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 4ee0e6360e117..6146175debb7d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -48,8 +48,9 @@ static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
continue;
}
// Input operands are never written to.
- newInputBuffers.push_back(
- *state.getBuffer(rewriter, *opOperand, /*forceInPlace=*/true));
+ newInputBuffers.push_back(*state.getBuffer(
+ rewriter, *opOperand,
+ BufferizationState::ForceInPlacability::FORCE_INPLACE));
}
// New output operands for the cloned op.
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 86dc0199b0c39..9e1010d4896de 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1812,10 +1812,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
///
/// Note: All collapsed dims in a reassociation group must be contiguous. It is
/// not possible to check this by inspecting a MemRefType in the general case.
-/// But it is assumed. If this is not the case, the behavior is undefined.
+/// If non-contiguity cannot be checked statically, the collapse is assumed to
+/// be valid (and thus accepted by this function) unless `strict = true`.
static FailureOr<AffineMap>
computeCollapsedLayoutMap(MemRefType srcType,
- ArrayRef<ReassociationIndices> reassociation) {
+ ArrayRef<ReassociationIndices> reassociation,
+ bool strict = false) {
int64_t srcOffset;
SmallVector<int64_t> srcStrides;
auto srcShape = srcType.getShape();
@@ -1837,11 +1839,26 @@ computeCollapsedLayoutMap(MemRefType srcType,
auto stride = Wrapper::stride(resultStrides[resultStrideIndex--]);
for (int64_t idx : llvm::reverse(trailingReassocs)) {
stride = stride * Wrapper::size(srcShape[idx]);
- // Both are either static strides of the same value, or both are dynamic.
- // The dynamic case is best effort atm : we can't check it statically.
- // One exception to the dynamic check is when the srcShape is `1`, in
- // which case it can never produce a non-contiguity.
- if (stride != Wrapper::stride(srcStrides[idx - 1]) && srcShape[idx] != 1)
+
+ // Both source and result stride must have the same static value. In that
+ // case, we can be sure, that the dimensions are collapsible (because they
+ // are contiguous).
+ //
+ // One special case is when the srcShape is `1`, in which case it can
+ // never produce non-contiguity.
+ if (srcShape[idx] == 1)
+ continue;
+
+ // If `strict = false` (default during op verification), we accept cases
+ // where one or both strides are dynamic. This is best effort: We reject
+ // ops where obviously non-contiguous dims are collapsed, but accept ops
+ // where we cannot be sure statically. Such ops may fail at runtime. See
+ // the op documentation for details.
+ auto srcStride = Wrapper::stride(srcStrides[idx - 1]);
+ if (strict && (stride.saturated || srcStride.saturated))
+ return failure();
+
+ if (!stride.saturated && !srcStride.saturated && stride != srcStride)
return failure();
}
}
@@ -1849,6 +1866,16 @@ computeCollapsedLayoutMap(MemRefType srcType,
srcType.getContext());
}
+bool ExpandShapeOp::isGuaranteedCollapsible(
+ MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
+ // MemRefs with standard layout are always collapsible.
+ if (srcType.getLayout().isIdentity())
+ return true;
+
+ return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
+ /*strict=*/true));
+}
+
static MemRefType
computeCollapsedType(MemRefType srcType,
ArrayRef<ReassociationIndices> reassociation) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 01d8da85ce962..94df8c54d941e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -109,12 +109,12 @@ struct CollapseShapeOpInterface
BufferizationState &state) const {
auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
RankedTensorType tensorResultType = collapseShapeOp.getResultType();
- Value buffer =
- *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
+ OpOperand &srcOperand = collapseShapeOp->getOpOperand(0) /*src*/;
+ auto bufferType = state.getBufferType(srcOperand).cast<MemRefType>();
if (tensorResultType.getRank() == 0) {
// 0-d collapses must go through a different op builder.
- auto bufferType = buffer.getType().cast<MemRefType>();
+ Value buffer = *state.getBuffer(rewriter, srcOperand);
MemRefType resultType;
if (bufferType.getLayout().isIdentity()) {
@@ -141,6 +141,18 @@ struct CollapseShapeOpInterface
return success();
}
+ // If the dims are not collapsible (due to an incompatible source layout
+ // map), force an out-of-place bufferization, i.e., a buffer copy. This
+ // newly allocated buffer will have no layout map and thus be collapsible.
+ bool canBeCollapsed = memref::ExpandShapeOp::isGuaranteedCollapsible(
+ bufferType, collapseShapeOp.getReassociationIndices());
+ Optional<BufferizationState::ForceInPlacability> overrideInPlace =
+ canBeCollapsed
+ ? None
+ : Optional<BufferizationState::ForceInPlacability>(
+ BufferizationState::ForceInPlacability::FORCE_OUT_OF_PLACE);
+ Value buffer = *state.getBuffer(rewriter, srcOperand, overrideInPlace);
+
// Result type is inferred by the builder.
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
@@ -248,9 +260,12 @@ struct ExtractSliceOpInterface
BufferizationState &state) const {
auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
Location loc = extractSliceOp.getLoc();
+
+ // Even if this op was decided to bufferize out-of-place, do not insert the
+ // buffer copy yet. This is done later in this function.
Value srcMemref =
*state.getBuffer(rewriter, extractSliceOp->getOpOperand(0) /*source*/,
- /*forceInPlace=*/true);
+ BufferizationState::ForceInPlacability::FORCE_INPLACE);
auto srcMemrefType = srcMemref.getType().cast<MemRefType>();
auto dstTensorType =
extractSliceOp.result().getType().cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
index da27b9c80b6e5..204eaab203486 100644
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -384,3 +384,20 @@ func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
%1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
return %1 : tensor<i32>
}
+
+// CHECK-LABEL: func @tensor.collapse_shape_of_slice2(
+func @tensor.collapse_shape_of_slice2(
+ %arg0: tensor<?x?x?x?xi64>, %o1: index, %o2: index, %o3: index, %o4: index)
+ -> tensor<87x63648xi64> {
+ // CHECK: %[[subview:.*]] = memref.subview %{{.*}} : memref<?x?x?x?xi64> to memref<87x78x68x12xi64, #{{.*}}>
+ %0 = tensor.extract_slice %arg0[%o1, %o2, %o3, %o4] [87, 78, 68, 12] [1, 1, 1, 1] : tensor<?x?x?x?xi64> to tensor<87x78x68x12xi64>
+
+ // This memref may not be collapsible, so the buffer must be copied to get rid
+ // of the layout map.
+ // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<87x78x68x12xi64>
+ // CHECK: memref.copy %[[subview]], %[[alloc]]
+ // CHECK: memref.collapse_shape %[[alloc]] [
+ // CHECK-SAME: [0], [1, 2, 3]] : memref<87x78x68x12xi64> into memref<87x63648xi64>
+ %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : tensor<87x78x68x12xi64> into tensor<87x63648xi64>
+ return %1 : tensor<87x63648xi64>
+}
More information about the Mlir-commits
mailing list