[Mlir-commits] [mlir] 6176d6a - [mlir][tensor] Support parallel_insert_slice in MergeConsecutiveInsertExtractSlicePatterns.cpp
Matthias Springer
llvmlistbot at llvm.org
Fri Jan 6 03:34:55 PST 2023
Author: Matthias Springer
Date: 2023-01-06T12:33:45+01:00
New Revision: 6176d6a93e71e6d2bf89bd50e41c30e936ed05a9
URL: https://github.com/llvm/llvm-project/commit/6176d6a93e71e6d2bf89bd50e41c30e936ed05a9
DIFF: https://github.com/llvm/llvm-project/commit/6176d6a93e71e6d2bf89bd50e41c30e936ed05a9.diff
LOG: [mlir][tensor] Support parallel_insert_slice in MergeConsecutiveInsertExtractSlicePatterns.cpp
Differential Revision: https://reviews.llvm.org/D141116
Added:
Modified:
mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
index 262ef483d847..416988204655 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/MergeConsecutiveInsertExtractSlicePatterns.cpp
@@ -41,12 +41,13 @@ struct MergeConsecutiveExtractSlice : public OpRewritePattern<ExtractSliceOp> {
};
/// Merges consecutive tensor.insert_slice ops into one.
-struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+template <typename OpTy>
+struct MergeConsecutiveInsertSlice : public OpRewritePattern<OpTy> {
+ using OpRewritePattern<OpTy>::OpRewritePattern;
- LogicalResult matchAndRewrite(InsertSliceOp nextOp,
+ LogicalResult matchAndRewrite(OpTy nextOp,
PatternRewriter &rewriter) const override {
- auto prevOp = nextOp.getSource().getDefiningOp<InsertSliceOp>();
+ auto prevOp = nextOp.getSource().template getDefiningOp<InsertSliceOp>();
if (!prevOp)
return failure();
@@ -67,7 +68,7 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
!prevOp.getDestType().hasStaticShape())
return failure();
- rewriter.replaceOpWithNewOp<InsertSliceOp>(
+ rewriter.replaceOpWithNewOp<OpTy>(
nextOp, prevOp.getSource(), nextOp.getDest(), nextOp.getMixedOffsets(),
nextOp.getMixedSizes(), nextOp.getMixedStrides());
return success();
@@ -77,6 +78,8 @@ struct MergeConsecutiveInsertSlice : public OpRewritePattern<InsertSliceOp> {
void mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(
RewritePatternSet &patterns) {
- patterns.add<MergeConsecutiveExtractSlice, MergeConsecutiveInsertSlice>(
+ patterns.add<MergeConsecutiveExtractSlice,
+ MergeConsecutiveInsertSlice<InsertSliceOp>,
+ MergeConsecutiveInsertSlice<ParallelInsertSliceOp>>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
index f5d77f63561c..a120b0f1a9ca 100644
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -81,3 +81,20 @@ func.func @insert_slice_rank_reducing_dynamic_shape(
// CHECK-LABEL: func.func @insert_slice_rank_reducing_dynamic_shape
// CHECK-COUNT-2: tensor.insert_slice
+
+// -----
+
+// CHECK-LABEL: func.func @parallel_insert_slice
+// CHECK-NOT: tensor.insert_slice
+// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
+func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %r = scf.foreach_thread (%arg2, %arg3) in (%c1, %c2) shared_outs(%arg4 = %t0) -> (tensor<1x2xf32>) {
+ %inserted_slice = tensor.insert_slice %t1 into %t2[0, 0] [1, 1] [1, 1] : tensor<f32> into tensor<1x1xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %inserted_slice into %arg4[%arg2, %arg3] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x2xf32>
+ }
+ }
+ return %r : tensor<1x2xf32>
+}
More information about the Mlir-commits mailing list