[Mlir-commits] [mlir] 2117677 - [mlir] Fix bugs in expand_shape patterns after semantics changes (#94631)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 7 06:09:55 PDT 2024

Author: Max191
Date: 2024-06-07T09:09:51-04:00
New Revision: 2117677e304d334326f6591f3c75fb2f34dc4bcb

URL: https://github.com/llvm/llvm-project/commit/2117677e304d334326f6591f3c75fb2f34dc4bcb
DIFF: https://github.com/llvm/llvm-project/commit/2117677e304d334326f6591f3c75fb2f34dc4bcb.diff

LOG: [mlir] Fix bugs in expand_shape patterns after semantics changes (#94631)

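After the `output_shape` field was added to `expand_shape` ops,
dynamically sized expand shapes (including reassociation groups with
more than one dynamic dimension) are now possible, but this was not
accounted for in the reshape folder (`foldReshapeOp`) or in the
`ComposeExpandOfCollapseOp` pattern. This PR tightens the folding
conditions in both places so that a reshape pair is only folded when
its dynamic dimensions are guaranteed to be preserved.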




diff  --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index e8f6edc3f133e..89bc57f09ec8b 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -85,21 +85,49 @@ bool isReassociationValid(ArrayRef<AffineMap> reassociation,
 template <typename ReshapeOpTy, typename InverseReshapeOpTy>
 static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp,
                                   ArrayRef<Attribute> operands) {
+  // Fold identity reshape.
   if (reshapeOp.getSrcType() == reshapeOp.getType())
     return reshapeOp.getSrc();
-  // Fold producer-consumer reshape ops where the operand type of the
-  // producer is same as the return type of the consumer.
-  auto reshapeSrcOp =
-      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
-  if (reshapeSrcOp && reshapeSrcOp.getSrcType() == reshapeOp.getResultType())
-    return reshapeSrcOp.getSrc();
   // Reshape of a constant can be replaced with a new constant.
   if (auto elements = dyn_cast_or_null<DenseElementsAttr>(operands.front()))
     return elements.reshape(cast<ShapedType>(reshapeOp.getResult().getType()));
+  // Fold if the producer's source type is identical to this op's result type
+  // and that type has at most 1 dynamic dimension.
+  auto reshapeSrcOp =
+      reshapeOp.getSrc().template getDefiningOp<InverseReshapeOpTy>();
+  if (!reshapeSrcOp)
+    return nullptr;
+  auto srcType = reshapeSrcOp.getSrcType();
+  auto resultType = reshapeOp.getResultType();
+  if (srcType != resultType)
+    return nullptr;
+  if (llvm::count_if(srcType.getShape(), ShapedType::isDynamic) < 2) {
+    return reshapeSrcOp.getSrc();
+  }
+  // Fold producer-consumer reshape ops when they are perfect inverses of each
+  // other:
+  //   1) Reassociation indices are equivalent.
+  //   2) Boundary types are equivalent.
+  //   3) No reassociations have more than 1 dynamic dimension, and reassociated
+  //      shapes are equal for each reassociation.
+  auto reassociations = reshapeOp.getReassociationIndices();
+  if (reassociations != reshapeSrcOp.getReassociationIndices())
+    return nullptr;
+  // If the reshapes are expanding and then collapsing, the ops can be folded
+  // despite multiple dynamic dimensions.
+  if (srcType.getRank() < reshapeSrcOp.getResultType().getRank())
+    return reshapeSrcOp.getSrc();
+  if (llvm::all_of(reassociations, [&](auto reInd) {
+        ArrayRef<int64_t> srcSlice =
+            srcType.getShape().slice(reInd.front(), reInd.size());
+        return llvm::count_if(srcSlice, ShapedType::isDynamic) < 2;
+      })) {
+    return reshapeSrcOp.getSrc();
+  }
   return nullptr;
@@ -360,10 +388,12 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
           resultShape.slice(resultIndices.front(), resultIndices.size());
       if (srcSubShape.size() == resultSubShape.size()) {
-        if (srcSubShape == resultSubShape)
+        if (srcSubShape == resultSubShape &&
+            llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
-        else
+        } else {
           return std::nullopt;
+        }
       // Find reassociation to collapse `srcSubShape` into `resultSubShape`.

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 6b51d0b294bcf..df2bea08577e2 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1139,7 +1139,7 @@ func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
   return %1 : tensor<12x4xf32>
 // CHECK-LABEL: @fold_collapse_of_expand
-//   CHECK-NOT:   linalg.{{.*}}shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
 // -----
@@ -1152,7 +1152,75 @@ func.func @fold_collapse_of_expand_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index
   return %1 : tensor<?x?xf32>
 // CHECK-LABEL: @fold_collapse_of_expand_dynamic
-//   CHECK-NOT:   linalg.{{.*}}_shape
+//   CHECK-NOT:   tensor.{{.*}}_shape
+// -----
+func.func @fold_collapse_of_expand_fully_dynamic(%arg0 : tensor<?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+// CHECK-LABEL: @fold_collapse_of_expand_fully_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+// -----
+func.func @no_fold_parallel_collapse_of_expand_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index, %arg4: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2], [3]] output_shape [%arg1, %arg2, %arg3, %arg4]
+      : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0], [1], [2, 3]]
+      : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+// CHECK-LABEL: @no_fold_parallel_collapse_of_expand_dynamic
+//       CHECK:   tensor.expand_shape
+//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
+//       CHECK:   return %[[COLLAPSE]]
+// -----
+func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<3x4x4xf32> into tensor<12x4xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [3, 4, 4]
+      : tensor<12x4xf32> into tensor<3x4x4xf32>
+  return %1 : tensor<3x4x4xf32>
+// CHECK-LABEL: @fold_expand_of_collapse
+//   CHECK-NOT:   tensor.{{.*}}_shape
+// -----
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x4x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+//   CHECK-NOT:   tensor.{{.*}}_shape
+// -----
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+    -> tensor<?x?x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+//       CHECK:   tensor.collapse_shape
+//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
+//       CHECK:   return %[[EXPAND]]
 // -----


More information about the Mlir-commits mailing list