[Mlir-commits] [mlir] [MLIR] Add shape propagation through tensor.pad (PR #136681)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 25 10:37:28 PDT 2025
================
@@ -1100,6 +1100,198 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
+ }
+
+ Value constantPaddingValue = padOp.getConstantPaddingValue();
+ if (!constantPaddingValue) {
+ return rewriter.notifyMatchFailure(
+ expandOp, "cannot fold with non-constant padding value");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() > 1 &&
+ (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
+ return rewriter.notifyMatchFailure(
+ expandOp, "fusion blocked by non-zero padding");
+ }
+
+ SmallVector<OpFoldResult> newLow, newHigh;
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ newLow.append(reInd.size(), low[idx]);
+ newHigh.append(reInd.size(), high[idx]);
+ }
+
+ Location loc = expandOp.getLoc();
+ ArrayRef<int64_t> finalShape = expandOp.getResultType().getShape();
+ SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
+
+ for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[inDimIdx];
+ OpFoldResult h = high[inDimIdx];
+
+ if (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)) {
+ assert(reInd.size() == 1 && "expected single dimension");
+ expandedShape[reInd[0]] =
+ tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
+ ;
+ }
+ }
+
+ for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+ if (dimSize == ShapedType::kDynamic &&
+ !isa<Value>(expandedShape[outDimIdx]) &&
+ !isa<Attribute>(expandedShape[outDimIdx])) {
+ expandedShape[outDimIdx] =
+ tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
+ }
+ }
----------------
Max191 wrote:
You can delete this loop because `expandOp.getMixedOutputShape()` will already populate `expandedShape` with the right dynamic values.
https://github.com/llvm/llvm-project/pull/136681
More information about the Mlir-commits mailing list