[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)
Ian Wood
llvmlistbot at llvm.org
Thu Nov 20 11:24:39 PST 2025
================
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
- for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && (l != 0 || h != 0))
- return failure();
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+ expandedPadding.lowPad, expandedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+ FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
----------------
IanWood1 wrote:
Should this `padOp->hasOneUse()` early-exit be removed, so that users can decide whether to fold through the `controlFoldingReshapes` control function instead?
https://github.com/llvm/llvm-project/pull/168888
More information about the Mlir-commits
mailing list