[Mlir-commits] [mlir] [Draft][MLIR] Add reshape propagation through tensor.pad (PR #136681)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 28 08:42:45 PDT 2025


================
@@ -1101,6 +1101,84 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    // return failure if padOp has *any* dynamic padding
+    if (!padOp.getLow().empty() || !padOp.getHigh().empty()) {
+      return failure();
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0))
+        return failure();
+    }
----------------
Max191 wrote:

This loop is already the only restriction this pattern needs. Any non-expanded dimension can have any amount of padding (static or dynamic), since the reshape does not touch it.

https://github.com/llvm/llvm-project/pull/136681


More information about the Mlir-commits mailing list