[Mlir-commits] [mlir] [mlir][Tensor] Generalize the pattern to swap `tensor.collapse_shape` -> `tensor.expand_shape`. (PR #133819)
llvmlistbot at llvm.org
Tue Apr 15 14:10:13 PDT 2025
================
@@ -166,56 +166,127 @@ struct BubbleUpExpandThroughParallelCollapse
return failure();
}
- // Reshapes are parallel to each other if none of the reassociation indices
- // have greater than 1 index for both reshapes.
+ // Reshapes are parallel to each other (by construction the number of
+ // reassociations specified in the collapse and expand is the same) if at
+ // every position either
+ // 1. the reassociation indices are of the same size, or
+ // 2. the reassociation in the collapse or the expand is of size 1.
+ ArrayRef<int64_t> staticSourceSize = collapseOp.getSrcType().getShape();
+ ArrayRef<int64_t> staticResultSize = expandOp.getStaticOutputShape();
for (auto [expandReassociation, collapseReassociation] :
llvm::zip_equal(expandReInds, collapseReInds)) {
+ if (collapseReassociation.size() == expandReassociation.size()) {
+ // Even if the reassociations are the same size, the collapse/expand
+ // must result in the same dimensions, e.g., collapsing 4x8x2 into 64
+ // should only be expanded back into 4x8x2. In the presence of dynamic
+ // dimensions, "equality" can only be verified when each side has at
+ // most one dynamic dimension and all static dimensions are equal.
+ ArrayRef<int64_t> collapsedStaticShapes = staticSourceSize.slice(
+ collapseReassociation.front(), collapseReassociation.size());
+ int64_t numCollapsedDynamic =
+ llvm::count_if(collapsedStaticShapes,
+ [](int64_t d) { return ShapedType::isDynamic(d); });
+ ArrayRef<int64_t> expandedStaticShapes = staticResultSize.slice(
+ expandReassociation.front(), expandReassociation.size());
+ int64_t numExpandedDynamic =
+ llvm::count_if(expandedStaticShapes,
+ [](int64_t d) { return ShapedType::isDynamic(d); });
+ if (numCollapsedDynamic > 1 || numExpandedDynamic > 1 ||
+ collapsedStaticShapes != expandedStaticShapes) {
+ return failure();
+ }
+ continue;
+ }
+ // If the reassociations are not the same size, one or the other needs
+ // to be of size one.
if (collapseReassociation.size() != 1 && expandReassociation.size() != 1)
return failure();
}
// Compute new reassociation indices and expanded/collapsed shapes.
SmallVector<ReassociationIndices> newExpandReInds, newCollapseReInds;
Location loc = expandOp->getLoc();
- SmallVector<OpFoldResult> collapseSizes =
+ SmallVector<OpFoldResult> sourceSizes =
tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc());
- SmallVector<OpFoldResult> expandSizes(getMixedValues(
- expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> resultSizes = expandOp.getMixedOutputShape();
SmallVector<OpFoldResult> newExpandSizes;
- int64_t index = 0, expandIndex = 0, collapseIndex = 0;
- for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) {
+
+ int64_t newExpandIndex = 0, newCollapseIndex = 0, sourceSizeIndex = 0,
+ resultSizeIndex = 0;
+
+ for (size_t idx = 0, idx_end = collapseReInds.size(); idx < idx_end;
----------------
MaheshRavishankar wrote:
Fixed to camelCase. I tried to do
```
for (auto &[collapsedReassocation, expandReassocation] : llvm::zip_equal(collapseReInds, expandReInds))
```
and it didn't compile. Don't know why.
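A likely cause, sketched under the assumption that the binding itself was the problem: `zip_equal`'s iterator dereferences to a `std::tuple` of references returned by value, so an `auto &` binding cannot attach to that prvalue, while `auto` or `auto &&` can. A minimal sketch, assuming an LLVM build environment (names are illustrative):
```
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"

void zipBindingSketch() {
  llvm::SmallVector<int> collapseReInds = {0, 1, 2};
  llvm::SmallVector<int> expandReInds = {3, 4, 5};

  // Compiles: the proxy tuple is copied by value, but the references it
  // holds still alias the underlying elements.
  for (auto [c, e] : llvm::zip_equal(collapseReInds, expandReInds))
    e = c;

  // Also compiles: a forwarding reference binds to the prvalue tuple.
  for (auto &&[c, e] : llvm::zip_equal(collapseReInds, expandReInds))
    e = c;

  // Does not compile: `auto &` cannot bind to the prvalue proxy tuple.
  // for (auto &[c, e] : llvm::zip_equal(collapseReInds, expandReInds))
  //   e = c;
}
```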
https://github.com/llvm/llvm-project/pull/133819
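To make the swap concrete, here is a minimal IR sketch of what BubbleUpExpandThroughParallelCollapse does; the shapes and value names are illustrative rather than taken from the patch:
```
// Before: a collapse followed by a "parallel" expand. At position 0 the
// expand's reassociation group has size 1; at position 1 the collapse's does.
%collapsed = tensor.collapse_shape %src [[0, 1], [2]]
    : tensor<4x8x6xf32> into tensor<32x6xf32>
%expanded = tensor.expand_shape %collapsed [[0], [1, 2]] output_shape [32, 2, 3]
    : tensor<32x6xf32> into tensor<32x2x3xf32>

// After: the expand is bubbled up above the collapse.
%expanded2 = tensor.expand_shape %src [[0], [1], [2, 3]] output_shape [4, 8, 2, 3]
    : tensor<4x8x6xf32> into tensor<4x8x2x3xf32>
%collapsed2 = tensor.collapse_shape %expanded2 [[0, 1], [2], [3]]
    : tensor<4x8x2x3xf32> into tensor<32x2x3xf32>
```
With this change, positions where both reassociation groups have the same size are also accepted, provided the shapes at that position match and each side has at most one dynamic dimension; such groups are passed through unchanged.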