[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)

Ian Wood llvmlistbot at llvm.org
Thu Nov 20 11:24:39 PST 2025


================
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+                                   ControlFusionFn foldReshapes,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
----------------
IanWood1 wrote:

Should this `hasOneUse()` check be removed, so that users can decide through the control function whether to fold when the pad has multiple uses?

https://github.com/llvm/llvm-project/pull/168888


More information about the Mlir-commits mailing list