[Mlir-commits] [mlir] [mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. (PR #127943)
Ian Wood
llvmlistbot at llvm.org
Tue Mar 11 20:58:43 PDT 2025
================
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");
Location loc = linalgOp.getLoc();
- // Check if reshape is expanding or collapsing.
- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
- bool isExpanding = (expandingReshapeOp != nullptr);
- RankedTensorType expandedType = isExpanding
- ? expandingReshapeOp.getResultType()
- : collapsingReshapeOp.getSrcType();
- RankedTensorType collapsedType = isExpanding
- ? expandingReshapeOp.getSrcType()
- : collapsingReshapeOp.getResultType();
+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
+ SmallVector<AffineMap, 4> reassociationIndices;
+ Value src;
+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+ // Try to move the dynamic dimensions in output shape before the `linalgOp`
+ // to maintain SSA validity
+ if (failed(moveValueDefinitions(
+ rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+ return std::nullopt;
+
+ expandedShape = expandingReshapeOp.getMixedOutputShape();
+ reassociationIndices = expandingReshapeOp.getReassociationMaps();
+ src = expandingReshapeOp.getSrc();
+ } else {
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
----------------
IanWood1 wrote:
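nit: use `cast` here. This branch assumes `reshapeOp` is a `tensor::CollapseShapeOp` and never checks the result of `dyn_cast` for null, so `cast` states that assumption directly and asserts on it.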
https://github.com/llvm/llvm-project/pull/127943