[Mlir-commits] [mlir] [Draft][MLIR] Add reshape propagation through tensor.pad (PR #136681)
llvmlistbot at llvm.org
Mon Apr 28 08:42:45 PDT 2025
================
@@ -1101,6 +1101,84 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
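+///
+/// For example (illustrative; shapes chosen for exposition, pad body elided):
+///
+///   %padded = tensor.pad %src low[0, 1] high[0, 1] { ... }
+///       : tensor<6x10xf32> to tensor<6x12xf32>
+///   %expanded = tensor.expand_shape %padded [[0, 1], [2]] output_shape [2, 3, 12]
+///       : tensor<6x12xf32> into tensor<2x3x12xf32>
+///
+/// becomes
+///
+///   %expanded = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 3, 10]
+///       : tensor<6x10xf32> into tensor<2x3x10xf32>
+///   %padded = tensor.pad %expanded low[0, 0, 1] high[0, 0, 1] { ... }
+///       : tensor<2x3x10xf32> to tensor<2x3x12xf32>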
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    // Bail out if the pad op has any dynamic low/high padding values.
+    if (!padOp.getLow().empty() || !padOp.getHigh().empty()) {
+      return failure();
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+
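+    // Only reassociation groups of size one may carry non-zero padding;
+    // padding cannot be propagated through an expanded dimension.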
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0))
+        return failure();
+    }
+
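+    // Replicate each pad amount over every dimension of its reassociation
+    // group; the amounts are zero for groups of size greater than one.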
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    // Compute the expanded (pre-pad) shape by subtracting each group's
+    // static padding from the final result shape.
+    auto reshapeType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = reshapeType.getShape();
+    SmallVector<int64_t> expandedShape;
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      for (auto outDimIdx : outGroup) {
+        int64_t sz = finalShape[outDimIdx] - low[inDimIdx] - high[inDimIdx];
+        expandedShape.push_back(sz);
+      }
+    }
+
+    // Apply the reshape to the pad's source first.
+    Location loc = expandOp.getLoc();
+    Value expandedSrc = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(expandedShape,
+                              padOp.getSourceType().getElementType()),
+        padOp.getSource(), reassociations);
----------------
Max191 wrote:
For dynamic cases, you will need to pass a list of OpFoldResult for the output sizes of the expand_shape. To compute this, you can do something similar to what I suggested above for the shape: start from `expandOp.getMixedOutputShape()` and replace the padded entries with the pad source's mixed sizes (the `tensor::getMixedSizes()` utility will be helpful for this).
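Roughly, an untested sketch of what I mean (`expandedType` is a stand-in for the expanded result type, which would also need dynamic dims in this case):

```cpp
SmallVector<OpFoldResult> srcSizes =
    tensor::getMixedSizes(rewriter, loc, padOp.getSource());
SmallVector<OpFoldResult> outputShape = expandOp.getMixedOutputShape();
for (auto [srcDim, reInd] : llvm::enumerate(reassociations)) {
  // Groups of size > 1 carry no padding, so their output sizes already
  // match the un-padded source; only size-1 groups need replacing.
  if (reInd.size() == 1)
    outputShape[reInd.front()] = srcSizes[srcDim];
}
Value expandedSrc = rewriter.create<tensor::ExpandShapeOp>(
    loc, expandedType, padOp.getSource(), reassociations, outputShape);
```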
https://github.com/llvm/llvm-project/pull/136681