[Mlir-commits] [mlir] 892fdc9 - [mlir][Linalg] Generalize the logic to compute reassociation maps

llvmlistbot at llvm.org
Wed Sep 30 07:58:59 PDT 2020


Author: Mahesh Ravishankar
Date: 2020-09-30T07:58:06-07:00
New Revision: 892fdc923f06adbef507ebe594fa7b48224d93f0

URL: https://github.com/llvm/llvm-project/commit/892fdc923f06adbef507ebe594fa7b48224d93f0
DIFF: https://github.com/llvm/llvm-project/commit/892fdc923f06adbef507ebe594fa7b48224d93f0.diff

LOG: [mlir][Linalg] Generalize the logic to compute reassociation maps
while folding tensor_reshape op.

While folding reshapes that introduce unit-extent dims, the logic to
compute the reassociation maps can be generalized to handle some
corner cases, for example, when the folded shape still has unit-extent
dims, but those dims correspond to folded unit-extent dims of the
expanded shape.

Differential Revision: https://reviews.llvm.org/D88521
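
The grouping logic this patch introduces can be illustrated outside of
MLIR. Below is a minimal standalone sketch, not the patch itself: it
walks an expanded shape and a folded shape in lockstep, multiplying
expanded dims into the current group until their product reaches the
folded size, then absorbing trailing unit dims into that group. The
helper name computeReassociation and the plain std::vector types are
made up for illustration; the real pattern builds AffineMaps through a
rewriter.

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the walk in FoldReshapeOpWithUnitExtent:
// groups the dimensions of expandedShape into one group per dimension of
// foldedShape. Returns an empty result if no such grouping exists.
static std::vector<std::vector<unsigned>>
computeReassociation(const std::vector<int64_t> &expandedShape,
                     const std::vector<int64_t> &foldedShape) {
  std::vector<std::vector<unsigned>> groups(foldedShape.size());
  unsigned expandedDim = 0, foldedDim = 0;
  while (expandedDim < expandedShape.size() &&
         foldedDim < foldedShape.size()) {
    int64_t dstSize = foldedShape[foldedDim];
    int64_t srcSize = expandedShape[expandedDim];
    // Accumulate expanded dims until their product reaches the folded
    // size; the bound keeps the running product from reading past the end
    // of expandedShape.
    while (srcSize < dstSize && expandedDim + 1 < expandedShape.size()) {
      groups[foldedDim].push_back(expandedDim++);
      srcSize *= expandedShape[expandedDim];
    }
    if (srcSize != dstSize)
      return {};
    groups[foldedDim].push_back(expandedDim++);
    // Absorb trailing unit dims into this group, unless the next folded
    // dim is itself a unit dim and must claim them instead.
    if (foldedDim == foldedShape.size() - 1 ||
        foldedShape[foldedDim + 1] != 1) {
      while (expandedDim < expandedShape.size() &&
             expandedShape[expandedDim] == 1)
        groups[foldedDim].push_back(expandedDim++);
    }
    ++foldedDim;
  }
  // Every dimension on both sides must have been consumed.
  if (expandedDim != expandedShape.size() || foldedDim != foldedShape.size())
    return {};
  return groups;
}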

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 08e7e352d63e9..611c938ab542f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -403,61 +403,58 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
         srcType.getRank() < dstType.getRank() ||
         parentSrcType.getRank() == dstType.getRank())
       return failure();
+
    // Check whether the result tensor_reshape is folding or expanding after
    // the reshapeOp and the parentReshapeOp are combined.
     // If the final tensor_reshape is folding, the parentReshapeOp is
     // introducing unit-dims, and the reshapeOp does an actual reshape.
-    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
-    // unit-dims, and the parentReshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is
+    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
-    auto reassociationMaps = isFoldingPattern
-                                 ? reshapeOp.getReassociationMaps()
-                                 : parentReshapeOp.getReassociationMaps();
-    DenseSet<unsigned> conservedDimensions;
-    for (auto &map : reassociationMaps) {
-      if (map.getNumResults() == 1) {
-        conservedDimensions.insert(
-            map.getResult(0).cast<AffineDimExpr>().getPosition());
-      }
-    }
-
-    // Find positions at which the unit-dims exist.
-    int64_t nonUnitDimPos = 0;
-    DenseMap<unsigned, unsigned> nonUnitSrcDims;
-    ArrayRef<int64_t> nonUnitShape =
+    ArrayRef<int64_t> expandedShape =
         isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
-    for (auto shape : enumerate(srcType.getShape())) {
-      // Case 1 : It is a conserved dimension.
-      if (conservedDimensions.count(shape.index())) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+    ArrayRef<int64_t> foldedShape =
+        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
+
+    unsigned expandedDim = 0, foldedDim = 0;
+    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
+        foldedShape.size());
+    while (expandedDim < expandedShape.size() &&
+           foldedDim < foldedShape.size()) {
+      int64_t dstSize = foldedShape[foldedDim];
+      int64_t srcSize = expandedShape[expandedDim];
+      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        srcSize *= expandedShape[expandedDim];
       }
-      // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
-      if (shape.value() == 1)
-        continue;
-      // Case 3 : Dimensions match, treat it as a non-unit src dim.
-      if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
-          nonUnitShape[nonUnitDimPos] == shape.value()) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+      if (srcSize == dstSize) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        // If the next dim in foldedShape is not 1, collapse any subsequent
+        // unit dims in expandedShape into the current group.
+        if (foldedDim == foldedShape.size() - 1 ||
+            foldedShape[foldedDim + 1] != 1) {
+          while (expandedDim < expandedShape.size() &&
+                 expandedShape[expandedDim] == 1) {
+            reassociationExprs[foldedDim].push_back(
+                rewriter.getAffineDimExpr(expandedDim++));
+          }
+        }
+      } else {
+        return failure();
       }
-      return failure();
+      foldedDim++;
     }
+    if (expandedDim != expandedShape.size())
+      return failure();
 
-    // Compute reassociation maps for the final operation. Use the reassociation
-    // maps that is actually doing a reshape (and not just introducing
-    // unit-dims). From these maps, prune the unit-extent dimensions.
-    for (AffineMap &map : reassociationMaps) {
-      SmallVector<AffineExpr, 4> exprs;
-      exprs.reserve(nonUnitSrcDims.size());
-      for (auto result : map.getResults()) {
-        unsigned dim = result.cast<AffineDimExpr>().getPosition();
-        if (nonUnitSrcDims.count(dim))
-          exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
-      }
-      map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
-                           rewriter.getContext());
-    }
+    SmallVector<AffineMap, 4> reassociationMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
+              return AffineMap::get(expandedShape.size(), 0, exprs,
+                                    rewriter.getContext());
+            }));
     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
         reshapeOp, dstType, parentReshapeOp.src(),
         rewriter.getAffineMapArrayAttr(reassociationMaps));

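As a quick check, driving the sketch above with the shape pair from the
existing fold_reshape test in drop-unit-extent-dims.mlir produces a
valid grouping (a hypothetical driver, not part of the patch; it assumes
the computeReassociation sketch from the earlier snippet is in scope):

#include <cstdio>

int main() {
  // tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>.
  std::vector<int64_t> expanded = {1, 4, 1, 512, 1, 1, 512, 1, 4};
  std::vector<int64_t> folded = {4, 512, 1, 512, 4};
  // Prints: ( d0 d1 d2 )( d3 )( d4 d5 )( d6 d7 )( d8 )
  for (const std::vector<unsigned> &group :
       computeReassociation(expanded, folded)) {
    std::printf("(");
    for (unsigned d : group)
      std::printf(" d%u", d);
    std::printf(" )");
  }
  std::printf("\n");
  return 0;
}
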
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 06e56c5cb7d2a..1793d2b59b706 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -240,3 +240,19 @@ func @fold_reshape(%arg0 : tensor<2048x1x2048xf32>) -> tensor<4x512x1x512x4xf32>
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
+
+// -----
+
+//   CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//       CHECK: func @fold_reshape
+//       CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+//  CHECK-SAME:   tensor<2xf32> into tensor<2x1xf32>
+func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+  [affine_map<(d0, d1, d2) -> (d0)>,
+   affine_map<(d0, d1, d2) -> (d1, d2)>
+  ] : tensor<2x1x1xf32> into tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}
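
For the new test, parentSrcType is tensor<2xf32> and dstType is
tensor<2x1xf32>, so the pattern walks the expanded shape [2, 1] against
the folded shape [2]: d0 matches the folded size 2, and the trailing
unit dim d1 is absorbed into the same group, yielding the single map
(d0, d1) -> (d0, d1) that the CHECK lines expect.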

More information about the Mlir-commits mailing list