[Mlir-commits] [mlir] [mlir][Tensor] Generalize the pattern to swap `tensor.collapse_shape` -> `tensor.expand_shape`. (PR #133819)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 15 14:10:13 PDT 2025


================
@@ -166,56 +166,127 @@ struct BubbleUpExpandThroughParallelCollapse
       return failure();
     }
 
-    // Reshapes are parallel to each other if none of the reassociation indices
-    // have greater than 1 index for both reshapes.
+    // Reshapes are parallel to each other (by construction the number of
+    // reassociations specified in the collapse and expand are the same), if at
+    // any position
+    // 1. either the reassociation indices are of the same size, or
+    // 2. either the reassociation in the collapse or the expand is of size 1.
+    ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+    ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
     for (auto [expandReassociation, collapseReassociation] :
          llvm::zip_equal(expandReInds, collapseReInds)) {
+      if (collapseReassociation.size() == expandReassociation.size()) {
+        // Even if the reassociations are the same, the collapse/expand should
+        // result in the same dimensions. i.e  4x8x2 into 64 should be expanded
+        // into 4x8x2 again. In presense of dynamic dimensions one can only
+        // verify "equality" when there is only one dynamic dimension present,
+        // and all other static dimensions are equal.
+        ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+            collapseReassociation.front(), collapseReassociation.size());
+        int64_t numCollapsedDynamic =
+            llvm::count_if(collapsedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+            expandReassociation.front(), expandReassociation.size());
+        int64_t numExpandedDynamic =
+            llvm::count_if(expandedStaticShapes,
+                           [](int64_t d) { return ShapedType::isDynamic(d); });
+        if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+            collapsedStaticShapes != expandedStaticShapes) {
+          return failure();
+        }
+        continue;
+      }
+      // If the reassociations are not same, one or the other needs to be of
+      // size one.
       if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
         return failure();
     }
 
     // Compute new reassociation indices and expanded/collaped shapes.
     SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
     Location loc = expandOp->getLoc();
-    SmallVector<OpFoldResult> collapseSizes =
+    SmallVector<OpFoldResult> sourceSizes =
         tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
-    SmallVector<OpFoldResult> expandSizes(getMixedValues(
-        expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
     SmallVector<OpFoldResult> newExpandSizes;
-    int64_t index = 0, expandIndex = 0, collapseIndex = 0;
-    for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+    int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+            resultSizeIndex = 0;
+
+    for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
----------------
MaheshRavishankar wrote:

Fixed to camelCase. I tried to do
```
for (auto &[collapsedReassocation, expandReassocation] : llvm::zip_equal(collapseReInds, expandReInds))
```

and it didn't compile. I don't know why.

https://github.com/llvm/llvm-project/pull/133819


More information about the Mlir-commits mailing list