[Mlir-commits] [mlir] [mlir] Add bubbling patterns for non intersecting reshapes (PR #94637)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 6 09:24:29 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (Max191)

<details>
<summary>Changes</summary>

This PR adds fusion-by-expansion patterns that push a `tensor.expand_shape` up through a producer `tensor.collapse_shape` when the two ops have non-intersecting reassociations. Parallel `collapse_shape` ops like this can sometimes block propagation of `expand_shape` ops, so this change allows the two reshapes to pass through each other.

---
Full diff: https://github.com/llvm/llvm-project/pull/94637.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h (+42-10) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+71) 
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+34) 
- (modified) mlir/test/Dialect/Tensor/canonicalize.mlir (+70-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..96f0f7bf1aa49 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,51 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
-
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
 
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
-
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
 
+  // Fold if the producer reshape source has the same shape with at most 1
+  // dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+
+  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  if (reassociations != reshapeSrcOp.getReassociationIndices())
+    return nullptr;
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  ArrayRef<int64_t> expandedSrcShape = srcType.getShape();
+  ArrayRef<int64_t> expandedResultShape = resultType.getShape();
+  if (llvm::all_of(reassociations, [&](auto reInd) {
+        ArrayRef<int64_t> srcSlice =
+            expandedSrcShape.slice(reInd.front(), reInd.size());
+        return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
   return nullptr;
 }
 
@@ -360,10 +390,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
 
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
+        if (srcSubShape == resultSubShape &&
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
           composedReassociation.push_back(srcIndices);
-        else
+        } else {
           return std::nullopt;
+        }
       }
 
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index ad313c2d5ce60..579116904aad2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1023,6 +1023,76 @@ struct FoldReshapeWithGenericOpByExpansion
 private:
   ControlFusionFn controlFoldingReshapes;
 };
+
+/// Pattern to bubble up a tensor.expand_shape op through a producer
+/// tensor.collapse_shape op that has non-intersecting reassociations.
+struct BubbleUpExpandThroughParallelCollapse
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+  using OpRewritePattern<tensor::ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto collapseOp =
+        expandOp.getSrc().getDefiningOp<tensor::CollapseShapeOp>();
+    if (!collapseOp || !collapseOp->hasOneUse())
+      return failure();
+    auto expandReInds = expandOp.getReassociationIndices();
+    auto collapseReInds = collapseOp.getReassociationIndices();
+
+    // Reshapes are parallel to each other if, at every position, at most one
+    // of the two reshapes has a reassociation group with more than 1 index.
+    for (auto [expandReassociation, collapseReassociation] :
+         llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
+        return failure();
+    }
+
+    // Compute new reassociation indices and expanded/collapsed shapes.
+    SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> collapseSizes =
+        tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
+    SmallVector<OpFoldResult> expandSizes(getMixedValues(
+        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> newExpandSizes;
+    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
+    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+      if (collapseReassociation.size() != 1) {
+        ReassociationIndices newCollapseReassociation;
+        for (size_t i = 0; i < collapseReassociation.size(); ++i) {
+          newCollapseReassociation.push_back(index);
+          newExpandReInds.push_back({index++});
+          newExpandSizes.push_back(collapseSizes[collapseIndex++]);
+        }
+        newCollapseReInds.push_back(newCollapseReassociation);
+        expandIndex++;
+        continue;
+      }
+      ReassociationIndices newExpandReassociation;
+      auto expandReassociation = expandReInds[idx];
+      for (size_t i = 0; i < expandReassociation.size(); ++i) {
+        newExpandReassociation.push_back(index);
+        newCollapseReInds.push_back({index++});
+        newExpandSizes.push_back(expandSizes[expandIndex++]);
+      }
+      newExpandReInds.push_back(newExpandReassociation);
+      collapseIndex++;
+    }
+
+    // Swap reshape order.
+    SmallVector<Value> dynamicSizes;
+    SmallVector<int64_t> staticSizes;
+    dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes);
+    auto expandResultType = expandOp.getResultType().clone(staticSizes);
+    auto newExpand = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandResultType, collapseOp.getSrc(), newExpandReInds,
+        newExpandSizes);
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        expandOp, newExpand.getResult(), newCollapseReInds);
+    return success();
+  }
+};
+
 } // namespace
 
 //===---------------------------------------------------------------------===//
@@ -1939,6 +2009,7 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
+  patterns.add<BubbleUpExpandThroughParallelCollapse>(patterns.getContext());
 }
 
 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index f42666f81bbad..1354b138983a0 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -826,3 +826,37 @@ func.func @linalg_add_reshape_producer_fusion(%arg0 : tensor<?x7x?x8xf32>,
 // CHECK-SAME:     [0, 1], [2, 3]
 // CHECK-SAME:     tensor<?x7x?x8xf32> into tensor<?x?xf32>
 //      CHECK:   return %[[T4]]
+
+// -----
+
+func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @bubble_parallel_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
+//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
+// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+//      CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_bubble_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
+  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
+              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %expand : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @no_bubble_intersecting_reshapes
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
+//      CHECK:   return %[[EXPAND]]
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index f7fbd3834288b..9a6b03986ccb6 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand
-//   CHECK-NOT:   linalg.{{.*}}shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 }
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-//   CHECK-NOT:   linalg.{{.*}}_shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+//       CHECK:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
+//       CHECK:   return %[[COLLAPSE]]
+
+// -----
+
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+
+// -----
+
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
 
 // -----
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/94637


More information about the Mlir-commits mailing list