[Mlir-commits] [mlir] b62f9f4 - [mlir][Linalg] Add pattern to fold linalg.tensor_reshape that add unit extent dims.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 23 00:02:25 PDT 2020
Author: MaheshRavishankar
Date: 2020-09-23T00:01:58-07:00
New Revision: b62f9f4407a5ed6e5722e177e906efcebebce9eb
URL: https://github.com/llvm/llvm-project/commit/b62f9f4407a5ed6e5722e177e906efcebebce9eb
DIFF: https://github.com/llvm/llvm-project/commit/b62f9f4407a5ed6e5722e177e906efcebebce9eb.diff
LOG: [mlir][Linalg] Add pattern to fold linalg.tensor_reshape that add unit extent dims.
A sequence of two reshape ops, where one of them only adds unit-extent
dims, can be folded into a single reshape op.
Differential Revision: https://reviews.llvm.org/D88057
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e322a853daa7..d036fe5fdebd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -540,6 +540,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
reshapeOp.getResultType().hasStaticShape() &&
reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
return reshapeSrcOp.src();
+ // Reshape of a constant can be replaced with a new constant.
if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
return elements.reshape(
reshapeOp.getResult().getType().template cast<ShapedType>());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 136647750364..08e7e352d63e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -353,12 +353,126 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
};
} // namespace
+namespace {
+/// Pattern to fold pair of reshape ops where the intermediate has unit-dims for
+/// example:
+///
+/// %0 = linalg.tensor_reshape %arg0
+/// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+/// : tensor<2048xf32> into tensor<1x4x1x512xf32>
+/// %1 = linalg.tensor_reshape %0
+/// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+/// affine_map<(d0, d1, d2, d3) -> (d3)>]
+/// : tensor<1x4x1x512xf32> into tensor<4x512xf32>
+///
+/// can be replaced with
+///
+/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
+/// : tensor<2048xf32> into tensor<4x512xf32>
+///
+/// Similarly,
+///
+/// %0 = linalg.tensor_reshape %arg0
+/// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+/// affine_map<(d0, d1, d2, d3) -> (d3)>]
+/// : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+/// %1 = linalg.tensor_reshape %0
+/// [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+/// : tensor<1x4x1x512xf32> into tensor<2048xf32>
+///
+/// can be replaced with
+///
+/// %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
+/// : tensor<4x512xf32> into tensor<2048xf32>
+struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
+ using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ // Check that the source operand is created from a reshape as well.
+ TensorReshapeOp parentReshapeOp =
+ reshapeOp.src().getDefiningOp<TensorReshapeOp>();
+ if (!parentReshapeOp)
+ return failure();
+
+ RankedTensorType srcType = reshapeOp.getSrcType(),
+ dstType = reshapeOp.getResultType(),
+ parentSrcType = parentReshapeOp.getSrcType();
+ if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
+ !parentSrcType.hasStaticShape() ||
+ srcType.getRank() < dstType.getRank() ||
+ parentSrcType.getRank() == dstType.getRank())
+ return failure();
+ // Check if the result tensor_reshape after folding the reshapeOp and
+ // parentReshapeOp are combined.
+ // If the final tensor_reshape is folding, the parentReshapeOp is
+ // introducing unit-dims, and the reshapeOp does an actual reshape.
+ // If the final tensor_reshape op is expanding, the reshapeOp is introducing
+ // unit-dims, and the parentReshapeOp does an actual reshape.
+ bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
+ auto reassociationMaps = isFoldingPattern
+ ? reshapeOp.getReassociationMaps()
+ : parentReshapeOp.getReassociationMaps();
+ DenseSet<unsigned> conservedDimensions;
+ for (auto &map : reassociationMaps) {
+ if (map.getNumResults() == 1) {
+ conservedDimensions.insert(
+ map.getResult(0).cast<AffineDimExpr>().getPosition());
+ }
+ }
+
+ // Find positions at which the unit-dims exist.
+ int64_t nonUnitDimPos = 0;
+ DenseMap<unsigned, unsigned> nonUnitSrcDims;
+ ArrayRef<int64_t> nonUnitShape =
+ isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
+ for (auto shape : enumerate(srcType.getShape())) {
+ // Case 1 : It is a conserved dimension.
+ if (conservedDimensions.count(shape.index())) {
+ nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
+ continue;
+ }
+ // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
+ if (shape.value() == 1)
+ continue;
+ // Case 3 : Dimensions match, treat it as a non-unit src dim.
+ if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
+ nonUnitShape[nonUnitDimPos] == shape.value()) {
+ nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
+ continue;
+ }
+ return failure();
+ }
+
+ // Compute reassociation maps for the final operation. Use the reassociation
+ // maps that is actually doing a reshape (and not just introducing
+ // unit-dims). From these maps, prune the unit-extent dimensions.
+ for (AffineMap &map : reassociationMaps) {
+ SmallVector<AffineExpr, 4> exprs;
+ exprs.reserve(nonUnitSrcDims.size());
+ for (auto result : map.getResults()) {
+ unsigned dim = result.cast<AffineDimExpr>().getPosition();
+ if (nonUnitSrcDims.count(dim))
+ exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
+ }
+ map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
+ rewriter.getContext());
+ }
+ rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+ reshapeOp, dstType, parentReshapeOp.src(),
+ rewriter.getAffineMapArrayAttr(reassociationMaps));
+ return success();
+ }
+};
+} // namespace
+
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
MLIRContext *context, OwningRewritePatternList &patterns) {
patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+ patterns.insert<FoldReshapeOpWithUnitExtent>(context);
}
namespace {
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5e30b529ec7f..06e56c5cb7d2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -158,3 +158,85 @@ func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
// CHECK-SAME: %[[A]]
+
+// -----
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+// CHECK-SAME: tensor<2048xf32> into tensor<4x512xf32>
+func @fold_reshape(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+ : tensor<2048xf32> into tensor<1x4x1x512xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>]
+ : tensor<1x4x1x512xf32> into tensor<4x512xf32>
+ return %1 : tensor<4x512xf32>
+}
+
+// -----
+
+// CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+// CHECK-SAME: tensor<4x512xf32> into tensor<2048xf32>
+func @fold_reshape(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3) -> (d3)>]
+ : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+ : tensor<1x4x1x512xf32> into tensor<2048xf32>
+ return %1 : tensor<2048xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: tensor<2048x1xf32> into tensor<4x512x1xf32>
+func @fold_reshape(%arg0 : tensor<2048x1xf32>) -> tensor<4x512x1xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
+ : tensor<2048x1xf32> into tensor<1x4x1x512x1xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d3)>,
+ affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
+ : tensor<1x4x1x512x1xf32> into tensor<4x512x1xf32>
+ return %1 : tensor<4x512x1xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
+func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
+{
+ %0 = linalg.tensor_reshape %arg0
+ [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>]
+ : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
+ %1 = linalg.tensor_reshape %0
+ [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3, d4)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7)>,
+ affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d8)>]
+ : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
+ return %1 : tensor<4x512x1x512x4xf32>
+}
More information about the Mlir-commits
mailing list