[Mlir-commits] [mlir] f0a2fe7 - [mlir][Linalg] Rewrite SubTensors that take a slice out of a unit-extend dimension.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 29 09:19:46 PDT 2021


Author: MaheshRavishankar
Date: 2021-03-29T09:19:36-07:00
New Revision: f0a2fe7f79d79c757fca5bd1498a014f2f98bb72

URL: https://github.com/llvm/llvm-project/commit/f0a2fe7f79d79c757fca5bd1498a014f2f98bb72
DIFF: https://github.com/llvm/llvm-project/commit/f0a2fe7f79d79c757fca5bd1498a014f2f98bb72.diff

LOG: [mlir][Linalg] Rewrite SubTensors that take a slice out of a unit-extend dimension.

Subtensor operations that are taking a slice out of a tensor that is
unit-extent along a dimension can be rewritten to drop that dimension.

Differential Revision: https://reviews.llvm.org/D99226

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2d3e16fab960..d5f08056d551 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -171,8 +171,6 @@ LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
 
 namespace {
 /// Pattern to fold unit-trip count loops in GenericOps.
-// TODO: Generalize this to indexed-generic as well by modifying the region args
-// as well.
 template <typename GenericOpTy>
 struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
@@ -375,9 +373,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
     return success();
   }
 };
-} // namespace
 
-namespace {
 /// Pattern to fold pair of reshape ops where the intermediate has unit-dims for
 /// example:
 ///
@@ -428,12 +424,12 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
         parentSrcType.getRank() == dstType.getRank())
       return failure();
 
-    // Check if the result tensor_reshape after folding the reshapeOp and
-    // parentReshapeOp are combined.
-    // If the final tensor_reshape is folding, the parentReshapeOp is
-    // introducing unit-dims, and the reshapeOp does an actual reshape.
-    // If the final tensor_reshape op is expanding, the reshapeOp is
-    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
+    // Check whether the tensor_reshape that results from combining the
+    // reshapeOp and parentReshapeOp is folding or expanding.  If the final
+    // tensor_reshape is folding, the parentReshapeOp is introducing unit-dims,
+    // and the reshapeOp does an actual reshape.  If the final tensor_reshape op
+    // is expanding, the reshapeOp is introducing unit-dims, and the
+    // parentReshapeOp does an actual reshape.
     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
     ArrayRef<int64_t> expandedShape =
         isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
@@ -485,6 +481,77 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
     return success();
   }
 };
+
+/// Pattern to fold subtensors that are just taking a slice of unit-dimension
+/// tensor. For example
+///
+/// %1 = subtensor %0[0, %o1, 0] [1, %s1, 1] [1, 1, 1]
+///     : tensor<1x?x1xf32> to tensor<1x?x1xf32>
+///
+/// can be replaced with
+///
+/// %2 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+///     : tensor<1x?x1xf32> into tensor<?xf32>
+/// %3 = subtensor %2[%o1] [%s1] [1] : tensor<?xf32> to tensor<?xf32>
+/// %1 = linalg.tensor_reshape %3 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
+///     : tensor<?xf32> into tensor<1x?x1xf32>
+///
+/// The additional tensor_reshapes will hopefully get canonicalized away with
+/// other reshapes that drop unit dimensions. A dimension is folded only when
+/// its static offset is 0, its static size is 1 and its source extent is 1.
+/// All strides must be the static value 1.
+struct FoldUnitDimSubTensorOp : public OpRewritePattern<SubTensorOp> {
+  using OpRewritePattern<SubTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = subTensorOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = subTensorOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = subTensorOp.getMixedStrides();
+    // True only for a static (attribute) operand equal to `val`.
+    auto hasValue = [](OpFoldResult valueOrAttr, int64_t val) {
+      auto attr = valueOrAttr.dyn_cast<Attribute>();
+      return attr && attr.cast<IntegerAttr>().getInt() == val;
+    };
+
+    if (llvm::any_of(mixedStrides, [&](OpFoldResult valueOrAttr) {
+          return !hasValue(valueOrAttr, 1);
+        }))
+      return failure();
+
+    // Find the expanded unit dimensions.
+    SmallVector<ReassociationIndices> reassociation;
+    SmallVector<OpFoldResult> newOffsets, newSizes;
+    ArrayRef<int64_t> sourceShape = subTensorOp.getSourceType().getShape();
+    ReassociationIndices curr;
+    for (int64_t dim : llvm::seq<int64_t>(0, mixedOffsets.size())) {
+      curr.push_back(dim);
+      if (sourceShape[dim] == 1 && hasValue(mixedOffsets[dim], 0) &&
+          hasValue(mixedSizes[dim], 1))
+        continue;
+      newOffsets.push_back(mixedOffsets[dim]);
+      newSizes.push_back(mixedSizes[dim]);
+      reassociation.emplace_back(ReassociationIndices{});
+      std::swap(reassociation.back(), curr);
+    }
+    // Nothing to fold, or everything folded: in the latter case there is no
+    // reassociation group for the trailing unit dims to be appended to.
+    if (newOffsets.empty() || newOffsets.size() == mixedOffsets.size())
+      return failure();
+    reassociation.back().append(curr.begin(), curr.end());
+    SmallVector<OpFoldResult> newStrides(newOffsets.size(),
+                                         rewriter.getI64IntegerAttr(1));
+    Location loc = subTensorOp->getLoc();
+    auto srcReshape = rewriter.create<TensorReshapeOp>(
+        loc, subTensorOp.source(), reassociation);
+    auto newSubTensorOp = rewriter.create<SubTensorOp>(
+        loc, srcReshape, newOffsets, newSizes, newStrides);
+    rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+        subTensorOp, subTensorOp.getType(), newSubTensorOp, reassociation);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Patterns that are used to canonicalize the use of unit-extent dims for
@@ -493,7 +560,7 @@ void mlir::populateLinalgFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
   patterns.add<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
-               ReplaceUnitExtentTensors<GenericOp>,
+               FoldUnitDimSubTensorOp, ReplaceUnitExtentTensors<GenericOp>,
                ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.add<FoldReshapeOpWithUnitExtent>(context);

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index cb5d1089eb85..2a6711018988 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -55,12 +55,12 @@ func @drop_one_trip_loops_indexed_generic
     outs(%shape: tensor<?x1x?x1x?xi32>) {
        ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
             %arg5 : index, %arg6 : i32, %arg7 : i32) :
-	 %1 = addi %arg1, %arg2 : index
-	 %2 = addi %1, %arg3 : index
-	 %3 = addi %2, %arg4 : index
-	 %4 = addi %3, %arg5 : index
-	 %5 = index_cast %4 : index to i32
-	 %6 = addi %5, %arg6 : i32
+         %1 = addi %arg1, %arg2 : index
+         %2 = addi %1, %arg3 : index
+         %3 = addi %2, %arg4 : index
+         %4 = addi %3, %arg5 : index
+         %5 = index_cast %4 : index to i32
+         %6 = addi %5, %arg6 : i32
          linalg.yield %6 : i32
        } -> tensor<?x1x?x1x?xi32>
   return %0 : tensor<?x1x?x1x?xi32>
@@ -120,8 +120,8 @@ func @drop_all_loops_indexed_generic
     outs(%arg0 : tensor<1x1xi32>) {
        ^bb0(%arg1 : index, %arg2 : index, %arg3: i32, %arg4: i32) :
          %1 = addi %arg1, %arg2 : index
-	 %2 = index_cast %1 : index to i32
-	 %3 = addi %2, %arg3 : i32
+         %2 = index_cast %1 : index to i32
+         %3 = addi %2, %arg3 : i32
          linalg.yield %3 : i32
        } -> tensor<1x1xi32>
   return %0 : tensor<1x1xi32>
@@ -390,3 +390,69 @@ func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32>
 //  CHECK-SAME:   outs(%[[FILL]] : tensor<f32>)
 //       CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
 //       CHECK: return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>
+
+
+// -----
+
+func @fold_subtensor(
+    %arg0 : tensor<1x?x?x1x?x1x1xf32>, %arg1 : index, %arg2 : index,
+    %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
+    -> tensor<1x?x?x1x?x1x1xf32> {
+  %0 = subtensor %arg0[0, %arg1, %arg2, 0, %arg3, 0, 0]
+                      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+      tensor<1x?x?x1x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
+  return %0 : tensor<1x?x?x1x?x1x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
+//      CHECK: func @fold_subtensor
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x1x?x1x1xf32>
+// CHECK-SAME:   %[[ARG1:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-z0-9]+]]: index
+//      CHECK:   %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//      CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
+// CHECK-SAME:       [%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK-SAME:       [%[[ARG4]], %[[ARG5]], %[[ARG6]]]
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//      CHECK:   return %[[RESULT_RESHAPE]]
+
+// -----
+
+func @no_fold_subtensor(
+    %arg0 : tensor<1x?x?x?x?x1x1xf32>, %arg1 : index, %arg2 : index,
+    %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index)
+    -> tensor<1x?x?x1x?x1x1xf32> {
+  %0 = subtensor %arg0[%arg1, 0, %arg2, 0, 0, %arg3, 0]
+                      [1, %arg4, %arg5, 1, %arg6, 1, 1] [1, 1, 1, 1, 1, 1, 1] :
+      tensor<1x?x?x?x?x1x1xf32> to tensor<1x?x?x1x?x1x1xf32>
+  return %0 : tensor<1x?x?x1x?x1x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4)>
+//  CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6)>
+//      CHECK: func @no_fold_subtensor
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x?x?x?x?x1x1xf32>
+// CHECK-SAME:   %[[ARG1:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG4:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG5:[a-z0-9]+]]: index
+// CHECK-SAME:   %[[ARG6:[a-z0-9]+]]: index
+//      CHECK:   %[[SRC_RESHAPE:.+]] = linalg.tensor_reshape %[[ARG0]]
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
+//      CHECK:   %[[SUBTENSOR:.+]] = subtensor %[[SRC_RESHAPE]]
+// CHECK-SAME:       [%[[ARG1]], 0, %[[ARG2]], 0, 0, %[[ARG3]]]
+// CHECK-SAME:       [1, %[[ARG4]], %[[ARG5]], 1, %[[ARG6]], 1]
+//      CHECK:   %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[SUBTENSOR]]
+// CHECK-SAME:       [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]], #[[MAP4]], #[[MAP5]]]
+//      CHECK:   return %[[RESULT_RESHAPE]]


        


More information about the Mlir-commits mailing list