[Mlir-commits] [mlir] [Draft][MLIR] Add reshape propagation through tensor.pad (PR #136681)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 28 08:42:45 PDT 2025


================
@@ -1101,6 +1101,84 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    // Return failure if padOp has *any* dynamic padding.
+    if (!padOp.getLow().empty() || !padOp.getHigh().empty()) {
+      return failure();
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    ArrayRef<int64_t> low = padOp.getStaticLow();
+    ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+      if (reInd.size() != 1 && (l != 0 || h != 0))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    // Calculate expanded shape manually
+    auto reshapeType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = reshapeType.getShape();
+    SmallVector<int64_t> expandedShape;
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      for (auto outDimIdx : outGroup) {
+        int64_t sz = finalShape[outDimIdx] - low[inDimIdx] - high[inDimIdx];
+        expandedShape.push_back(sz);
----------------
Max191 wrote:

To support dynamic cases, this can be simplified a bit by just taking the corresponding dimension size from the input of the pad. For example:
```
%pad = tensor.pad %in low = [0, 1] high = [0, 2] ... tensor<4x3xf32> to tensor<4x6xf32>
%expand = tensor.expand_shape %pad [[0, 1], [2]] ... tensor<4x6xf32> to tensor<2x2x6xf32>
```
You can initialize the expanded shape with the result shape of the expand (`[2, 2, 6]`) and then replace any padded dim with the corresponding dim size of the pad input (shape `[4, 3]`). The final shape would then become `[2, 2, 3]`, with the `6` replaced by `3`.
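
As a rough, untested sketch (reusing the variable names from the diff above, and assuming the earlier check that padded dims only map to unit reassociation groups stays in place), that could look like:
```
// Start from the expand's result shape and substitute the pad source size
// for every padded dimension group.
auto reshapeType = cast<RankedTensorType>(expandOp.getType());
SmallVector<int64_t> expandedShape = llvm::to_vector(reshapeType.getShape());
ArrayRef<int64_t> padSourceShape = padOp.getSourceType().getShape();
for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
  if (low[inDimIdx] == 0 && high[inDimIdx] == 0)
    continue;
  // A padded group has exactly one output dim (checked earlier), so take the
  // pad source size directly; this also works when that dim is dynamic.
  expandedShape[reInd[0]] = padSourceShape[inDimIdx];
}
```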

https://github.com/llvm/llvm-project/pull/136681

