[Mlir-commits] [mlir] b62f9f4 - [mlir][Linalg] Add pattern to fold linalg.tensor_reshape that add unit extent dims.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 23 00:02:25 PDT 2020


Author: MaheshRavishankar
Date: 2020-09-23T00:01:58-07:00
New Revision: b62f9f4407a5ed6e5722e177e906efcebebce9eb

URL: https://github.com/llvm/llvm-project/commit/b62f9f4407a5ed6e5722e177e906efcebebce9eb
DIFF: https://github.com/llvm/llvm-project/commit/b62f9f4407a5ed6e5722e177e906efcebebce9eb.diff

LOG: [mlir][Linalg] Add pattern to fold linalg.tensor_reshape that add unit extent dims.

A sequence of two reshapes such that one of them is just adding unit
extent dims can be folded to a single reshape.

Differential Revision: https://reviews.llvm.org/D88057

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e322a853daa7..d036fe5fdebd 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -540,6 +540,7 @@ static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
       reshapeOp.getResultType().hasStaticShape() &&
       reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
     return reshapeSrcOp.src();
+  // Reshape of a constant can be replaced with a new constant.
   if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) {
     return elements.reshape(
         reshapeOp.getResult().getType().template cast<ShapedType>());

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 136647750364..08e7e352d63e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -353,12 +353,126 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
 };
 } // namespace
 
+namespace {
+/// Pattern to fold pair of reshape ops where the intermediate has unit-dims for
+/// example:
+///
+///  %0 = linalg.tensor_reshape %arg0
+///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+///    : tensor<2048xf32> into tensor<1x4x1x512xf32>
+///  %1 = linalg.tensor_reshape %0
+///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3)>]
+///    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
+///
+/// can be replaced with
+///
+///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
+///    : tensor<2048xf32> into tensor<4x512xf32>
+///
+/// Similarly,
+///
+///  %0 = linalg.tensor_reshape %arg0
+///    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3)>]
+///    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+///  %1 = linalg.tensor_reshape %0
+///   [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+///    : tensor<1x4x1x512xf32> into tensor<2048xf32>
+///
+/// can be replaced with
+///
+///  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>]
+///    : tensor<4x512xf32> into tensor<2048xf32>
+struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Check that the source operand is created from a reshape as well.
+    TensorReshapeOp parentReshapeOp =
+        reshapeOp.src().getDefiningOp<TensorReshapeOp>();
+    if (!parentReshapeOp)
+      return failure();
+
+    RankedTensorType srcType = reshapeOp.getSrcType(),
+                     dstType = reshapeOp.getResultType(),
+                     parentSrcType = parentReshapeOp.getSrcType();
+    if (!srcType.hasStaticShape() || !dstType.hasStaticShape() ||
+        !parentSrcType.hasStaticShape() ||
+        srcType.getRank() < dstType.getRank() ||
+        parentSrcType.getRank() == dstType.getRank())
+      return failure();
+    // Check if the result tensor_reshape after folding the reshapeOp and
+    // parentReshapeOp are combined.
+    // If the final tensor_reshape is folding, the parentReshapeOp is
+    // introducing unit-dims, and the reshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
+    // unit-dims, and the parentReshapeOp does an actual reshape.
+    bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
+    auto reassociationMaps = isFoldingPattern
+                                 ? reshapeOp.getReassociationMaps()
+                                 : parentReshapeOp.getReassociationMaps();
+    DenseSet<unsigned> conservedDimensions;
+    for (auto &map : reassociationMaps) {
+      if (map.getNumResults() == 1) {
+        conservedDimensions.insert(
+            map.getResult(0).cast<AffineDimExpr>().getPosition());
+      }
+    }
+
+    // Find positions at which the unit-dims exist.
+    int64_t nonUnitDimPos = 0;
+    DenseMap<unsigned, unsigned> nonUnitSrcDims;
+    ArrayRef<int64_t> nonUnitShape =
+        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
+    for (auto shape : enumerate(srcType.getShape())) {
+      // Case 1 : It is a conserved dimension.
+      if (conservedDimensions.count(shape.index())) {
+        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
+        continue;
+      }
+      // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
+      if (shape.value() == 1)
+        continue;
+      // Case 3 : Dimensions match, treat it as a non-unit src dim.
+      if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
+          nonUnitShape[nonUnitDimPos] == shape.value()) {
+        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
+        continue;
+      }
+      return failure();
+    }
+
+    // Compute reassociation maps for the final operation. Use the reassociation
+    // maps that is actually doing a reshape (and not just introducing
+    // unit-dims). From these maps, prune the unit-extent dimensions.
+    for (AffineMap &map : reassociationMaps) {
+      SmallVector<AffineExpr, 4> exprs;
+      exprs.reserve(nonUnitSrcDims.size());
+      for (auto result : map.getResults()) {
+        unsigned dim = result.cast<AffineDimExpr>().getPosition();
+        if (nonUnitSrcDims.count(dim))
+          exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
+      }
+      map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
+                           rewriter.getContext());
+    }
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+        reshapeOp, dstType, parentReshapeOp.src(),
+        rewriter.getAffineMapArrayAttr(reassociationMaps));
+    return success();
+  }
+};
+} // namespace
+
 /// Patterns that are used to canonicalize the use of unit-extent dims for
 /// broadcasting.
 void mlir::populateLinalgFoldUnitExtentDimsPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
   patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
+  patterns.insert<FoldReshapeOpWithUnitExtent>(context);
 }
 
 namespace {

diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5e30b529ec7f..06e56c5cb7d2 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -158,3 +158,85 @@ func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
 //  CHECK-SAME:     indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
 //  CHECK-SAME:     iterator_types = ["parallel", "parallel"]
 //  CHECK-SAME:     %[[A]]
+
+// -----
+
+//       CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//       CHECK: func @fold_reshape
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+//  CHECK-SAME:   tensor<2048xf32> into tensor<4x512xf32>
+func @fold_reshape(%arg0 : tensor<2048xf32>) -> tensor<4x512xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<2048xf32> into tensor<1x4x1x512xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3) -> (d3)>]
+    : tensor<1x4x1x512xf32> into tensor<4x512xf32>
+  return %1 : tensor<4x512xf32>
+}
+
+// -----
+
+//       CHECK: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//       CHECK: func @fold_reshape
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]]
+//  CHECK-SAME:   tensor<4x512xf32> into tensor<2048xf32>
+func @fold_reshape(%arg0 : tensor<4x512xf32>) -> tensor<2048xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3) -> (d3)>]
+    : tensor<4x512xf32> into tensor<1x4x1x512xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>]
+    : tensor<1x4x1x512xf32> into tensor<2048xf32>
+  return %1 : tensor<2048xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)>
+//       CHECK: func @fold_reshape
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]]]
+//  CHECK-SAME:   tensor<2048x1xf32> into tensor<4x512x1xf32>
+func @fold_reshape(%arg0 : tensor<2048x1xf32>) -> tensor<4x512x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
+    : tensor<2048x1xf32> into tensor<1x4x1x512x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d3)>,
+     affine_map<(d0, d1, d2, d3, d4) -> (d4)>]
+    : tensor<1x4x1x512x1xf32> into tensor<4x512x1xf32>
+  return %1 : tensor<4x512x1xf32>
+}
+
+// -----
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>
+//   CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2)>
+//   CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>
+//       CHECK: func @fold_reshape
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//  CHECK-SAME:   tensor<2048x1x2048xf32> into tensor<4x512x1x512x4xf32>
+func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
+{
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7, d8)>]
+    : tensor<2048x1x2048xf32> into tensor<1x4x1x512x1x1x512x1x4xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d6, d7)>,
+     affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d8)>]
+    : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
+  return %1 : tensor<4x512x1x512x4xf32>
+}


        


More information about the Mlir-commits mailing list