[Mlir-commits] [mlir] [mlir] Add bubbling patterns for non intersecting reshapes (PR #103401)

Ian Wood llvmlistbot at llvm.org
Tue Aug 13 15:45:36 PDT 2024


https://github.com/IanWood1 updated https://github.com/llvm/llvm-project/pull/103401

>From 351237c91e4f56b13171e3cf3ca453b86b79afa0 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 23 May 2024 17:24:08 -0400
Subject: [PATCH 1/3] [mlir] Add bubbling patterns for non intersecting
 reshapes

---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 71 +++++++++++++++++++
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 34 +++++++++
 2 files changed, 105 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..7aa8a0b37c219c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1086,6 +1086,76 @@ struct FoldReshapeWithGenericOpByExpansion
 private:
   ControlFusionFn controlFoldingReshapes;
 };
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp || !collapseOp->hasOneUse())
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+
+    // Reshapes are parallel to each other if none of the reassociation indices
+    // have greater than 1 index for both reshapes.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collaped shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(newCollapseReassociation);
+        expandIndex++;
+        continue;
+      }
+      ReassociationIndices newExpandReassociation;
+      auto expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(newExpandReassociation);
+      collapseIndex++;
+    }
+
+    // Swap reshape order.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -2083,6 +2153,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                         controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
+  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
 
 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index b8df5fc88e1999..86c2904218385c 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -887,3 +887,37 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
 //      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   return %[[EXPAND]]

>From 643ee8cd7a6ca805ddbd4fe482d85acd086d2782 Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 13 Aug 2024 17:43:44 +0000
Subject: [PATCH 2/3] Refactor logic and tests to Tensor

---
 .../Dialect/Tensor/Transforms/Transforms.h    |  4 +
 .../Linalg/Transforms/ElementwiseOpFusion.cpp | 71 ------------------
 .../Tensor/Transforms/ReshapePatterns.cpp     | 75 +++++++++++++++++++
 mlir/test/Dialect/Linalg/reshape_fusion.mlir  | 34 ---------
 mlir/test/Dialect/Tensor/bubble-reshapes.mlir | 47 ++++++++++++
 .../Dialect/Tensor/TestTensorTransforms.cpp   | 13 ++++
 6 files changed, 139 insertions(+), 105 deletions(-)
 create mode 100644 mlir/test/Dialect/Tensor/bubble-reshapes.mlir

diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 7f983b8b3cfd06..ae695e0326ca1a 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -67,6 +67,10 @@ void populateDropRedundantInsertSliceRankExpansionPatterns(
 /// `tensor.collapse_shape` into other ops.
 void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with patterns that bubble up `tensor.expand_shape`
+/// through `tensor.collapse_shape` ops.
+void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that fold tensor.empty with its
 /// consumers.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 7aa8a0b37c219c..e73df61c964341 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1086,76 +1086,6 @@ struct FoldReshapeWithGenericOpByExpansion
 private:
   ControlFusionFn controlFoldingReshapes;
 };
-
-/// Pattern to bubble up a tensor.expand_shape op through a producer
-/// tensor.collapse_shape op that has non intersecting reassociations.
-struct BubbleUpExpandThroughParallelCollapse
-    : public OpRewritePattern<tensor::ExpandShapeOp> {
-  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseOp =
-        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
-    if (!collapseOp || !collapseOp->hasOneUse())
-      return failure();
-    auto expandReInds = expandOp.getReassociationIndices();
-    auto collapseReInds = collapseOp.getReassociationIndices();
-
-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
-    for (auto [expandReassociation, collapseReassociation] :
-         llvm::zip_equal(expandReInds, collapseReInds)) {
-      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
-        return failure();
-    }
-
-    // Compute new reassociation indices and expanded/collaped shapes.
-    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
-    Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
-        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
-    SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
-      if (collapseReassociation.size() != 1) {
-        ReassociationIndices newCollapseReassociation;
-        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
-          newCollapseReassociation.push_back(index);
-          newExpandReInds.push_back({index++});
-          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
-        }
-        newCollapseReInds.push_back(newCollapseReassociation);
-        expandIndex++;
-        continue;
-      }
-      ReassociationIndices newExpandReassociation;
-      auto expandReassociation = expandReInds[idx];
-      for (size_t i = 0; i < expandReassociation.size(); ++i) {
-        newExpandReassociation.push_back(index);
-        newCollapseReInds.push_back({index++});
-        newExpandSizes.push_back(expandSizes[expandIndex++]);
-      }
-      newExpandReInds.push_back(newExpandReassociation);
-      collapseIndex++;
-    }
-
-    // Swap reshape order.
-    SmallVector<Value> dynamicSizes;
-    SmallVector<int64_t> staticSizes;
-    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
-    auto expandResultType = expandOp.getResultType().clone(staticSizes);
-    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
-        newExpandSizes);
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        expandOp, newExpand.getResult(), newCollapseReInds);
-    return success();
-  }
-};
-
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -2153,7 +2083,6 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                         controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
-  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
 
 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index be0d71866a095e..061817e41d181e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -140,6 +140,76 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
     return success();
   }
 };
+
+/// Pattern to bubble up a tensor.expand_shape op through its producer
+/// tensor.collapse_shape op when their reassociations do not intersect.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp || !collapseOp->hasOneUse())
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+
+    // The two reshapes are parallel if, for every paired reassociation group,
+    // at least one of the collapse group and the expand group has size 1.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collapsed shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(newCollapseReassociation);
+        expandIndex++;
+        continue;
+      }
+      ReassociationIndices newExpandReassociation;
+      auto expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(newExpandReassociation);
+      collapseIndex++;
+    }
+
+    // Swap reshape order: expand the collapse source, then collapse it.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
@@ -152,3 +222,8 @@ void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
            FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
           patterns.getContext());
 }
+
+void mlir::tensor::populateBubbleUpExpandShapePatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 86c2904218385c..b8df5fc88e1999 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -887,37 +887,3 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
 //      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
 // CHECK-SAME:       : tensor<?x?x?x?x?x?xi32> into tensor<?x?x?x?xi32>
 //      CHECK:   return %[[COLLAPSE]]
-
-// -----
-
-func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
-  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
-  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
-              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
-  return %expand : tensor<?x?x?x?xf32>
-}
-//      CHECK: func @bubble_parallel_reshapes
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
-// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
-//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
-//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
-//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
-// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
-//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
-//      CHECK:   return %[[COLLAPSE]]
-
-// -----
-
-func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
-  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
-  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
-              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
-  return %expand : tensor<?x?x?x?xf32>
-}
-//      CHECK: func @no_bubble_intersecting_reshapes
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
-//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
-//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
-//      CHECK:   return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
new file mode 100644
index 00000000000000..cf6b12852bcd39
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_full_intersecting_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   return %[[EXPAND]]
+
+// -----
+
+func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0, 1], [2, 3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_partial_intersecting_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
+//      CHECK:   return %[[EXPAND]]
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index ae4f77f5873e2b..34de600132f5de 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -72,6 +72,11 @@ struct TestTensorTransforms
       llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
       llvm::cl::init(false)};
 
+  Option<bool> testBubbleUpExpandShapePatterns{
+      *this, "test-expand-shape-bubbling",
+      llvm::cl::desc("Test bubbling of expand_shape through collapse_shape"),
+      llvm::cl::init(false)};
+
   Option<bool> testFoldIntoPackAndUnpack{
       *this, "test-fold-into-pack-and-unpack",
       llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"),
@@ -102,6 +107,12 @@ static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyBubbleUpExpandShapePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateBubbleUpExpandShapePatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) {
   RewritePatternSet patterns(rootOp->getContext());
   tensor::populateFoldIntoPackAndUnpackPatterns(patterns);
@@ -386,6 +397,8 @@ void TestTensorTransforms::runOnOperation() {
     applyDropRedundantInsertSliceRankExpansionPatterns(rootOp);
   if (testReassociativeReshapeFolding)
     applyReassociativeReshapeFoldingPatterns(rootOp);
+  if (testBubbleUpExpandShapePatterns)
+    applyBubbleUpExpandShapePatterns(rootOp);
   if (testFoldIntoPackAndUnpack)
     applyFoldIntoPackAndUnpackPatterns(rootOp);
   if (testRewriteExtractSliceWithTiledCollapseShape) {

>From 87b1c6fb36f076dda711ddc8e42539e8a3d3359d Mon Sep 17 00:00:00 2001
From: Ian Wood <ianwood2024 at u.northwestern.edu>
Date: Tue, 13 Aug 2024 22:45:20 +0000
Subject: [PATCH 3/3] Add to fusion pass and remove single use check

---
 mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp | 2 ++
 mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp     | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index e73df61c964341..9f1b6fdc55df3b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
@@ -2144,6 +2145,7 @@ struct LinalgElementwiseOpFusionPass
     // Add elementwise op fusion patterns.
     populateElementwiseOpsFusionPatterns(patterns, defaultControlFn);
     populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn);
+    tensor::populateBubbleUpExpandShapePatterns(patterns);
 
     // General canonicalization patterns.
     affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 061817e41d181e..5edd7a02bc42b1 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -151,7 +151,7 @@ struct BubbleUpExpandThroughParallelCollapse
                                 PatternRewriter &rewriter) const override {
     auto collapseOp =
         expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
-    if (!collapseOp || !collapseOp->hasOneUse())
+    if (!collapseOp)
       return failure();
     auto expandReInds = expandOp.getReassociationIndices();
     auto collapseReInds = collapseOp.getReassociationIndices();



More information about the Mlir-commits mailing list