[Mlir-commits] [mlir] [MLIR] Add shape propagation through tensor.pad (PR #136681)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jul 28 06:48:18 PDT 2025


================
@@ -1100,6 +1100,174 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad, i.e. rewriting
+///   expand_shape(pad(x)) -> pad(expand_shape(x)).
+///
+/// The rewrite only applies when:
+///   - the tensor.pad has the expand_shape as its single use,
+///   - the control function accepts the expand_shape source operand,
+///   - the pad has a constant padding value, and
+///   - every reassociation group that splits one source dimension into
+///     several result dimensions has zero low and high padding on it.
+/// A padded dimension (necessarily in a size-1 group) keeps its padding on
+/// the single expanded dimension it maps to. The nofold attribute of the
+/// original pad is preserved on the new one.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "source is not a tensor.pad op");
+
+    // The original pad is replaced by a new pad placed after the reshape;
+    // presumably this is restricted to a single use so the old pad does not
+    // stay alive alongside the new one.
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "tensor.pad op has multiple uses");
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    // First pass: validate every reassociation group and compute the
+    // low/high padding of the new (expanded-rank) pad. This pass must not
+    // create any IR, because the pattern can still fail here. It is kept
+    // separate from the second pass for that reason.
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      bool isPadded = !isConstantIntValue(low[idx], 0) ||
+                      !isConstantIntValue(high[idx], 0);
+      if (isPadded && reInd.size() > 1)
+        return rewriter.notifyMatchFailure(
+            expandOp, "fusion blocked by non-zero padding");
+
+      // A group with more than one dimension has zero padding here, so
+      // replicating its (zero) padding across the group is exact. A size-1
+      // group forwards its padding to the single expanded dimension.
+      newLow.append(reInd.size(), low[idx]);
+      newHigh.append(reInd.size(), high[idx]);
+    }
+
+    // Second pass (IR may be created from here on). Each padded source
+    // dimension maps to exactly one expanded dimension, whose pre-padding
+    // size is the corresponding size of the pad source. Unpadded dimensions
+    // keep the sizes from the original expand_shape output shape.
+    Location loc = expandOp.getLoc();
+    SmallVector<OpFoldResult> expandedShape = expandOp.getMixedOutputShape();
+    for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+      if (isConstantIntValue(low[inDimIdx], 0) &&
+          isConstantIntValue(high[inDimIdx], 0))
+        continue;
+      assert(reInd.size() == 1 && "expected single dimension");
+      expandedShape[reInd[0]] =
+          tensor::getMixedSize(rewriter, loc, padOp.getSource(), inDimIdx);
+    }
+
+    SmallVector<int64_t> staticExpandedShape;
+    std::tie(staticExpandedShape, std::ignore) =
+        decomposeMixedValues(expandedShape);
+
+    // NOTE(review): the intermediate type is built from the shape and element
+    // type only, so any encoding on the pad source is not carried over.
+    // Confirm this is intended for encoded tensors.
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations, expandedShape);
+
+    // Reuse the padding value validated above instead of querying the pad
+    // op a second time.
+    rewriter.replaceOpWithNewOp<tensor::PadOp>(
+        expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        constantPaddingValue, padOp.getNofold());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad.
+struct FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
+
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(collapseOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          collapseOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        collapseOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      if (reInd.size() > 1) {
+        for (auto dimIdx : reInd) {
+          if (!isConstantIntValue(low[dimIdx], 0) ||
+              !isConstantIntValue(high[dimIdx], 0)) {
+            return failure();
+          }
+        }
+      }
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      newLow.push_back(low[reInd[0]]);
+      newHigh.push_back(high[reInd[0]]);
+    }
+
----------------
Max191 wrote:

nit: combine this loop with the loop above.

https://github.com/llvm/llvm-project/pull/136681


More information about the Mlir-commits mailing list