[Mlir-commits] [mlir] [MLIR] Add shape propagation through tensor.pad (PR #136681)

Hyunsung Lee llvmlistbot at llvm.org
Wed Jul 16 14:49:47 PDT 2025


================
@@ -1100,6 +1102,267 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Returns true if `value` is statically known to be the integer zero.
+///
+/// `value` counts as zero when it is either an `IntegerAttr` holding 0, or an
+/// SSA value produced directly by an `arith.constant` whose attribute is an
+/// `IntegerAttr` holding 0. Anything else (a non-integer attribute, a value
+/// not defined by `arith.constant`, or an empty `OpFoldResult`) yields false.
+bool isZero(OpFoldResult value) {
+  // Find the constant attribute behind `value`, if there is one.
+  Attribute constant;
+  if (auto attr = dyn_cast<Attribute>(value)) {
+    constant = attr;
+  } else if (auto val = dyn_cast<Value>(value)) {
+    if (auto constOp = val.getDefiningOp<arith::ConstantOp>())
+      constant = constOp.getValue();
+  }
+  if (!constant)
+    return false;
+
+  // Test the APInt directly instead of going through IntegerAttr::getInt():
+  // getInt() asserts on signed/unsigned integer types and on values that do
+  // not fit in 64 bits, whereas APInt::isZero() is valid for every width.
+  auto intAttr = dyn_cast<IntegerAttr>(constant);
+  return intAttr && intAttr.getValue().isZero();
+}
+
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    Value constantPaddingValue = padOp.getConstantPaddingValue();
+    if (!constantPaddingValue) {
+      return rewriter.notifyMatchFailure(
+          expandOp, "cannot fold with non-constant padding value");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[idx];
+      OpFoldResult h = high[idx];
+      if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(low[idx]);
+        newHigh.push_back(high[idx]);
+      }
+    }
+
+    Location loc = expandOp.getLoc();
+    auto finalType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> expandedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        expandedShape.push_back(OpFoldResult{});
+      } else {
+        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
----------------
ita9naiwa wrote:

```
/Users/ita/src/llvm-project/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp:1257:59: error: no member named 'getMixedOutputShape' in 'mlir::tensor::CollapseShapeOp'
 1257 |     SmallVector<OpFoldResult> collapsedShape = collapseOp.getMixedOutputShape();
 ```

https://github.com/llvm/llvm-project/pull/136681


More information about the Mlir-commits mailing list