[Mlir-commits] [mlir] 6176d6a - [mlir][tensor] Support parallel_insert_slice in MergeConsecutiveInsertExtractSlicePatterns.cpp

Matthias Springer llvmlistbot at llvm.org
Fri Jan 6 03:34:55 PST 2023


Author: Matthias Springer
Date: 2023-01-06T12:33:45+01:00
New Revision: 6176d6a93e71e6d2bf89bd50e41c30e936ed05a9

URL: https://github.com/llvm/llvm-project/commit/6176d6a93e71e6d2bf89bd50e41c30e936ed05a9
DIFF: https://github.com/llvm/llvm-project/commit/6176d6a93e71e6d2bf89bd50e41c30e936ed05a9.diff

LOG: [mlir][tensor] Support parallel_insert_slice in MergeConsecutiveInsertExtractSlicePatterns.cpp

Differential Revision: https://reviews.llvm.org/D141116

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
    mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 262ef483d847..416988204655 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -41,12 +41,13 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
 };
 
 /// Merges consecutive tensor.insert_slice ops into one.
-struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern::OpRewritePattern;
+template <typename OpTy>
+struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp nextOp,
+  LogicalResult matchAndRewrite(OpTy nextOp,
                                 PatternRewriter &rewriter) const override {
-    auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
+    auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
     if (!prevOp)
       return failure();
 
@@ -67,7 +68,7 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
         !prevOp.getDestType().hasStaticShape())
       return failure();
 
-    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+    rewriter.replaceOpWithNewOp<OpTy>(
         nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
         nextOp.getMixedSizes(), nextOp.getMixedStrides());
     return success();
@@ -77,6 +78,8 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
 
 void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
+  patterns.add<MergeConsecutiveExtractSlice,
+               MergeConsecutiveInsertSlice<InsertSliceOp>,
+               MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
       patterns.getContext());
 }

diff  --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
index f5d77f63561c..a120b0f1a9ca 100644
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -81,3 +81,20 @@ func.func @insert_slice_rank_reducing_dynamic_shape(
 
 //   CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
 // CHECK-COUNT-2:   tensor.insert_slice
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice
+//   CHECK-NOT:   tensor.insert_slice
+//       CHECK:   tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
+func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %r = scf.foreach_thread (%arg2, %arg3) in (%c1, %c2) shared_outs(%arg4 = %t0) -> (tensor<1x2xf32>) {
+    %inserted_slice = tensor.insert_slice %t1 into %t2[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<1x1xf32>
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %inserted_slice into %arg4[%arg2, %arg3] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x2xf32>
+    }
+  }
+  return %r : tensor<1x2xf32>
+}


        


More information about the Mlir-commits mailing list