[Mlir-commits] [mlir] 9cdf6b6 - [mlir][tensor] Support parallel_insert_slice in reassociative reshape folder

Matthias Springer llvmlistbot at llvm.org
Wed Dec 7 07:29:22 PST 2022


Author: Matthias Springer
Date: 2022-12-07T16:25:10+01:00
New Revision: 9cdf6b641da1a7ba0145b224460c64efd65017e0

URL: https://github.com/llvm/llvm-project/commit/9cdf6b641da1a7ba0145b224460c64efd65017e0
DIFF: https://github.com/llvm/llvm-project/commit/9cdf6b641da1a7ba0145b224460c64efd65017e0.diff

LOG: [mlir][tensor] Support parallel_insert_slice in reassociative reshape folder

Differential Revision: https://reviews.llvm.org/D139540
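
In essence, the folder now also rewrites a rank-reducing
tensor.parallel_insert_slice whose source is produced by a unit-dim-dropping
tensor.collapse_shape into a direct insert of the collapse_shape's source.
As a rough sketch (value names and shapes taken from the new test case in
the diff below), IR of the form

  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
  tensor.parallel_insert_slice %0 into %o[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
      : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>

is folded to

  tensor.parallel_insert_slice %t into %o[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
      : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>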

Added: 
    

Modified: 
    mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
    mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index b655df3c2cc48..d40e5f33d2a73 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -11,8 +11,6 @@
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 
-#define DEBUG_TYPE "mlir-tensor-split-padding"
-
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -51,13 +49,14 @@ struct FoldExpandOfRankReducingExtract
 };
 
 /// Fold insert_slice(collapse_shape) ops that cancel itself out.
-struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
-  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+template <typename OpTy>
+struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+  LogicalResult matchAndRewrite(OpTy insertSliceOp,
                                 PatternRewriter &rewriter) const override {
     auto collapseShapeOp =
-        insertSliceOp.getSource().getDefiningOp<CollapseShapeOp>();
+        insertSliceOp.getSource().template getDefiningOp<CollapseShapeOp>();
     if (!collapseShapeOp)
       return failure();
     RankedTensorType srcType = collapseShapeOp.getSrcType();
@@ -67,16 +66,16 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
     // has no rank-reduction anymore are supported at the moment.
     RankedTensorType nonReducingInsertType =
         RankedTensorType::get(insertSliceOp.getStaticSizes(),
-                              insertSliceOp.getType().getElementType());
+                              insertSliceOp.getDestType().getElementType());
     if (nonReducingInsertType != srcType)
       return failure();
 
     SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
-    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
-        insertSliceOp, collapseShapeOp.getSrc(), insertSliceOp.getDest(),
-        mixedOffsets, mixedSizes, mixedStrides);
+    rewriter.replaceOpWithNewOp<OpTy>(insertSliceOp, collapseShapeOp.getSrc(),
+                                      insertSliceOp.getDest(), mixedOffsets,
+                                      mixedSizes, mixedStrides);
     return success();
   }
 };
@@ -84,6 +83,8 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<InsertSliceOp> {
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract, FoldInsertOfRankReducingInsert>(
+  patterns.add<FoldExpandOfRankReducingExtract,
+               FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
       patterns.getContext());
 }

diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 15a00a58c0f5a..e6256a9b9ea8d 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -33,3 +33,22 @@ func.func @rank_reducing_insert_of_collapse_shape(
       : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
+//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+func.func @rank_reducing_parallel_insert_of_collapse_shape(
+    %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
+  -> tensor<?x?x?x?xf32> {
+  %0 = tensor.collapse_shape %t [[0, 1], [2], [3]]
+      : tensor<?x1x1x5xf32> into tensor<?x1x5xf32>
+  %1 = scf.foreach_thread (%iv) in (%thr) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %0 into %o[0, 0, 0, 0][%sz, 1, 1, 5][1, 1, 1, 1]
+          : tensor<?x1x5xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}
