[Mlir-commits] [mlir] 4317a3d - [mlir][Linalg] Disable fusion of reshape with elementwise ops for purely dynamic cases.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 6 10:32:35 PST 2022
Author: MaheshRavishankar
Date: 2022-01-06T10:32:24-08:00
New Revision: 4317a3dfad52ee2830d686c6f3a9ef8011f7e6ad
URL: https://github.com/llvm/llvm-project/commit/4317a3dfad52ee2830d686c6f3a9ef8011f7e6ad
DIFF: https://github.com/llvm/llvm-project/commit/4317a3dfad52ee2830d686c6f3a9ef8011f7e6ad.diff
LOG: [mlir][Linalg] Disable fusion of reshape with elementwise ops for purely dynamic cases.
A `tensor.collapse_shape` op, when fused with a consumer elementwise
`linalg.generic` operation, results in the creation of `tensor.expand_shape`
ops. In purely dynamic cases this can end up with a dynamic dimension
being expanded to more than one dynamic dimension. This is disallowed
by the semantics of the `tensor.expand_shape` operation. (While the
transformation is itself correct, it's a gap in the specification of
`tensor.expand_shape` that is the issue.) So disallow fusions which
result in such a pattern.
Differential Revision: https://reviews.llvm.org/D116703
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/test/Dialect/Linalg/reshape_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 6fd3927c80cac..8f7c331597bce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -524,6 +524,7 @@ class ExpansionInfo {
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
+ ArrayRef<int64_t> collapsedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
@@ -533,6 +534,7 @@ class ExpansionInfo {
ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
+ ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
@@ -541,6 +543,8 @@ class ExpansionInfo {
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
SmallVector<SmallVector<int64_t>> expandedShapeMap;
+ /// Extent of the loop in the original operation.
+ SmallVector<int64_t> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
@@ -549,6 +553,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
ArrayRef<int64_t> expandedShape,
+ ArrayRef<int64_t> collapsedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
@@ -558,6 +563,8 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
linalgOp.getStaticLoopRanges();
if (!originalLoopRange)
return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
+ originalLoopExtent.assign(originalLoopRange->begin(),
+ originalLoopRange->end());
reassociation.clear();
expandedShapeMap.clear();
@@ -576,7 +583,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
// The remaining dimensions remain the same.
for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
if (expandedShapeMap[i].empty())
- expandedShapeMap[i] = {(*originalLoopRange)[i]};
+ expandedShapeMap[i] = {originalLoopExtent[i]};
// Compute reassociation map from the original op to the expanded op.
unsigned sum = 0;
@@ -601,6 +608,30 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
LogicalResult isGenericOpExpandable(GenericOp genericOp,
const ExpansionInfo &expansionInfo,
PatternRewriter &rewriter) {
+ // Current reshape only supports expansion of a dynamic dim when only one of
+ // the expanded dims are dynamic.
+ for (auto originalShape : llvm::enumerate(expansionInfo.getOriginalShape()))
+ if (ShapedType::isDynamic(originalShape.value())) {
+ // All but one of the expanded dims must be static.
+ bool foundDynamicExpandedDim = false;
+ for (auto expandedShape :
+ expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
+ if (ShapedType::isDynamic(expandedShape)) {
+ if (foundDynamicExpandedDim) {
+ return rewriter.notifyMatchFailure(
+ genericOp,
+ "cannot expanded dynamic dims into multiple dynamic dims");
+ }
+ foundDynamicExpandedDim = true;
+ }
+ }
+ if (!foundDynamicExpandedDim) {
+ return rewriter.notifyMatchFailure(
+ genericOp, "dynamic dim expansion needs at least one dynamic dim "
+ "in result shape");
+ }
+ }
+
if (!genericOp.hasIndexSemantics())
return success();
for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -731,13 +762,16 @@ fuseWithReshapeByExpansion(GenericOp genericOp, Operation *reshapeOp,
RankedTensorType expandedType = isExpanding
? expandingReshapeOp.getResultType()
: collapsingReshapeOp.getSrcType();
+ RankedTensorType collapsedType = isExpanding
+ ? expandingReshapeOp.getSrcType()
+ : collapsingReshapeOp.getResultType();
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
genericOp, fusableOpOperand,
isExpanding ? expandingReshapeOp.getReassociationMaps()
: collapsingReshapeOp.getReassociationMaps(),
- expandedType.getShape(), rewriter)))
+ expandedType.getShape(), collapsedType.getShape(), rewriter)))
return llvm::None;
if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 324aa2809ce37..9582a4bbafc43 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -507,3 +507,26 @@ func @unit_dim_reshape_expansion_full
// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
+// -----
+
+func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+ %1 = tensor.dim %0, %c0 : tensor<?xf32>
+ %2 = linalg.init_tensor [%1] : tensor<?xf32>
+ %3 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]}
+ ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
+ ^bb0(%arg1 : f32, %arg2: f32):
+ %4 = arith.addf %arg1, %arg1 : f32
+ linalg.yield %4 : f32
+ } -> tensor<?xf32>
+ return %3 : tensor<?xf32>
+}
+// CHECK: func @no_fuse_dynamic_dims
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
+// CHECK: return %[[GENERIC]]
More information about the Mlir-commits
mailing list