[Mlir-commits] [mlir] df5c981 - [mlir][Linalg] Add DropUnitDims support for tensor::ParallelInsertSliceOp.

Nicolas Vasilache llvmlistbot at llvm.org
Tue Jul 5 01:36:18 PDT 2022


Author: Nicolas Vasilache
Date: 2022-07-05T01:36:13-07:00
New Revision: df5c981be35a3267a50bf382d179e48bb2242e0f

URL: https://github.com/llvm/llvm-project/commit/df5c981be35a3267a50bf382d179e48bb2242e0f
DIFF: https://github.com/llvm/llvm-project/commit/df5c981be35a3267a50bf382d179e48bb2242e0f.diff

LOG: [mlir][Linalg] Add DropUnitDims support for tensor::ParallelInsertSliceOp.

tensor::ParallelInsertSliceOp has the same rank-reducing properties as
tensor::InsertSliceOp.
This revision extends the rank-reducing insert_slice rewrite pattern to
ParallelInsertSliceOp, reusing most of the existing implementation.

Differential Revision: https://reviews.llvm.org/D129091

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
    mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 2ef7e074127da..b35558e8a0a6b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -457,7 +457,7 @@ struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
 
 namespace {
 /// Convert `extract_slice` operations to rank-reduced versions.
-struct UseRankReducedExtractSliceOp
+struct RankReducedExtractSliceOp
     : public OpRewritePattern<tensor::ExtractSliceOp> {
   using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
 
@@ -487,27 +487,37 @@ struct UseRankReducedExtractSliceOp
 };
 
 /// Convert `insert_slice` operations to rank-reduced versions.
-struct UseRankReducedInsertSliceOp
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
+template <typename InsertOpTy>
+struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
+  using OpRewritePattern<InsertOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
+  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
-    RankedTensorType sourceType = insertOp.getSourceType();
-    SmallVector<OpFoldResult> offsets = insertOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = insertOp.getMixedSizes();
-    SmallVector<OpFoldResult> strides = insertOp.getMixedStrides();
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
     auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
-    Location loc = insertOp.getLoc();
-    auto reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
-        loc, insertOp.getSource(), *reassociation);
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        insertOp, reshapedSource, insertOp.getDest(),
-        insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
-        insertOp.getMixedStrides());
+    Location loc = insertSliceOp.getLoc();
+    tensor::CollapseShapeOp reshapedSource;
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      // The only difference between InsertSliceOp and ParallelInsertSliceOp is
+      // that the insertion point is just before the ParallelCombiningOp in the
+      // parallel case.
+      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
+        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+      reshapedSource = rewriter.create<tensor::CollapseShapeOp>(
+          loc, insertSliceOp.getSource(), *reassociation);
+    }
+    rewriter.replaceOpWithNewOp<InsertOpTy>(
+        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
     return success();
   }
 };
@@ -518,8 +528,9 @@ struct UseRankReducedInsertSliceOp
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents,
-               UseRankReducedExtractSliceOp, UseRankReducedInsertSliceOp>(
+  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
+               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
+               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
       context);
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   linalg::InitTensorOp::getCanonicalizationPatterns(patterns, context);

diff  --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 8c424ffbea268..a272dea3fa232 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -827,3 +827,23 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te
 // CHECK-LABEL: func @sparse_case
 //  CHECK-NEXT:   linalg.init_tensor
 //  CHECK-NEXT:   linalg.generic
+
+// -----
+
+func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.init_tensor [4, 2] : tensor<4x2xf32>
+  %res = scf.foreach_thread (%arg0, %arg1) in (%c4, %c2) -> (tensor<4x2xf32>) {
+    %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    scf.foreach_thread.perform_concurrently {
+      //      CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
+      // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
+      tensor.parallel_insert_slice %2 into %0[%arg0, %arg1] [1, 1] [1, 1] :
+        tensor<1x1xf32> into tensor<4x2xf32>
+    }
+  }  
+  return %res: tensor<4x2xf32>
+}


        


More information about the Mlir-commits mailing list