[Mlir-commits] [mlir] [mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. (PR #127943)

Ian Wood llvmlistbot at llvm.org
Tue Mar 11 20:58:43 PDT 2025


================
@@ -910,31 +875,31 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");
 
   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    // Try to move the dynamic dimensions in output shape before the `linalgOp`
+    // to maintain SSA validity
+    if (failed(moveValueDefinitions(
+            rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
+      return std::nullopt;
+
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
----------------
IanWood1 wrote:

nit: use `cast`, since in the `else` branch `reshapeOp` must be a `tensor::CollapseShapeOp`
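
A minimal sketch of the suggested tweak (assuming, as the precondition check above implies, that `reshapeOp` can only be a `tensor::ExpandShapeOp` or a `tensor::CollapseShapeOp`):

  } else {
    // Reaching this branch means the op is not an ExpandShapeOp, so it must be
    // a CollapseShapeOp; `cast` asserts that instead of returning null.
    auto collapsingReshapeOp = cast<tensor::CollapseShapeOp>(reshapeOp);
    collapsedShape = collapsingReshapeOp.getMixedOutputShape();
    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
    src = collapsingReshapeOp.getSrc();
  }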

https://github.com/llvm/llvm-project/pull/127943

