[Mlir-commits] [mlir] [mlir][tensor] Fold rank increasing expand_shape into insert_slice (PR #93018)

Adam Siemieniuk llvmlistbot at llvm.org
Thu May 23 03:40:30 PDT 2024


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/93018

>From 90fbffb0374642f146b6cb377f2647f4ffda1521 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 17 May 2024 16:38:55 +0200
Subject: [PATCH 1/2] [mlir][tensor] Fold rank increasing expand_shape into
 insert_slice

---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 38 ++++++++++++--
 .../Tensor/fold-reassociative-reshapes.mlir   | 49 +++++++++++++++++++
 2 files changed, 83 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index d40e5f33d2a73..824bae63f14c6 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -79,12 +79,42 @@ struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Fold an expand_shape that only adds static size-1 dims into its consuming
+/// (parallel_)insert_slice, which can take the rank-reduced source directly.
+template <typename OpTy>
+struct FoldRankIncreasingExpandIntoInsert : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandShapeOp = insertSliceOp.getSource()
+                             .template getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandShapeOp)
+      return failure();
+
+    // Bail unless every dimension added by the expansion has static size 1.
+    SliceVerificationResult res = isRankReducedType(
+        expandShapeOp.getResultType(), expandShapeOp.getSrcType());
+    if (res != SliceVerificationResult::Success)
+      return rewriter.notifyMatchFailure(insertSliceOp,
+                                         "expected rank increasing expansion");
+
+    rewriter.modifyOpInPlace(insertSliceOp, [&]() {
+      insertSliceOp.getSourceMutable().set(expandShapeOp.getSrc());
+    });
+    return success();
+  }
+};
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract,
-               FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
-               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>>(
-      patterns.getContext());
+  patterns
+      .add<FoldExpandOfRankReducingExtract,
+           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
+           FoldRankIncreasingExpandIntoInsert<tensor::InsertSliceOp>,
+           FoldRankIncreasingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
+          patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index d3ac6ce792f36..9e9c66f2d3123 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -54,3 +54,52 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
   }
   return %1 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_increasing_insert_of_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   return %[[insert]]
+func.func @rank_increasing_insert_of_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
+      : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @rank_increasing_parallel_insert_of_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+func.func @rank_increasing_parallel_insert_of_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
+          : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}

>From af5fecea53423dd81d45c376f1ec3a89427716f7 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 23 May 2024 12:40:19 +0200
Subject: [PATCH 2/2] Expand test cases

---
 .../Tensor/fold-reassociative-reshapes.mlir   | 81 +++++++++++++++----
 1 file changed, 67 insertions(+), 14 deletions(-)

diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 9e9c66f2d3123..d6d10df02c36c 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -57,48 +57,101 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
 
 // -----
 
-// CHECK-LABEL: func @rank_increasing_insert_of_expand_shape(
+// CHECK-LABEL: func @insert_of_rank_increasing_expand_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
 //       CHECK:   return %[[insert]]
-func.func @rank_increasing_insert_of_expand_shape(
+func.func @insert_of_rank_increasing_expand_shape(
     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
   -> tensor<?x?x?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
   %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
-  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
-      : tensor<?x?xf32> into tensor<1x?x?xf32>
-  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
-      : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+      : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
   return %1 : tensor<?x?x?x?xf32>
 }
 
 // -----
 
-// CHECK-LABEL: func @rank_increasing_parallel_insert_of_expand_shape(
+// CHECK-LABEL: func @insert_of_non_rank_increasing_expand_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
 //  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
-//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
-func.func @rank_increasing_parallel_insert_of_expand_shape(
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   return %[[insert]]
+func.func @insert_of_non_rank_increasing_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.insert_slice %0 into %d[%x, %y, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @parallel_insert_of_rank_increasing_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_rank_increasing_expand_shape(
     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
   -> tensor<?x?x?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
   %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
-  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [1, %sz0, %sz1]
-      : tensor<?x?xf32> into tensor<1x?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2, 3]] output_shape [1, %sz0, 1, %sz1]
+      : tensor<?x?xf32> into tensor<1x?x1x?xf32>
+  %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, %sz0, 1, %sz1][1, 1, 1, 1]
+          : tensor<1x?x1x?xf32> into tensor<?x?x?x?xf32>
+    }
+  }
+  return %1 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @parallel_insert_of_non_rank_increasing_expand_shape(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+func.func @parallel_insert_of_non_rank_increasing_expand_shape(
+    %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+  -> tensor<?x?x?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?xf32>
+  %0 = tensor.expand_shape %t [[0, 1], [2]] output_shape [%sz, %sz0, %sz1]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
   %1 = scf.forall (%i, %j) in (%x, %y) shared_outs(%o = %d) -> (tensor<?x?x?x?xf32>) {
     scf.forall.in_parallel {
-      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][1, 1, %sz0, %sz1][1, 1, 1, 1]
-          : tensor<1x?x?xf32> into tensor<?x?x?x?xf32>
+      tensor.parallel_insert_slice %0 into %o[%i, %j, 0, 0][%sz, 1, %sz0, %sz1][1, 1, 1, 1]
+          : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
     }
   }
   return %1 : tensor<?x?x?x?xf32>



More information about the Mlir-commits mailing list