[Mlir-commits] [mlir] [MLIR] Add shape propagation through tensor.pad (PR #136681)
llvmlistbot at llvm.org
Wed Jul 16 06:41:15 PDT 2025
================
@@ -1100,6 +1102,267 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Returns true if `value` is statically known to be the integer zero,
+/// either as an IntegerAttr or as the result of an arith.constant op.
+static bool isZero(OpFoldResult value) {
+ if (auto attr = dyn_cast<Attribute>(value)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 0;
+ }
+ if (auto val = dyn_cast<Value>(value)) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+ if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+ return attr.getInt() == 0;
+ }
+ }
+ return false;
+}
+
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
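+///
+/// For example (illustrative only):
+///
+///   %pad = tensor.pad %src low[0, 1] high[0, 1] { ... }
+///       : tensor<8x10xf32> to tensor<8x12xf32>
+///   %exp = tensor.expand_shape %pad [[0, 1], [2]] output_shape [2, 4, 12]
+///       : tensor<8x12xf32> into tensor<2x4x12xf32>
+///
+/// is rewritten so the expansion is applied to the unpadded source:
+///
+///   %exp = tensor.expand_shape %src [[0, 1], [2]] output_shape [2, 4, 10]
+///       : tensor<8x10xf32> into tensor<2x4x10xf32>
+///   %pad = tensor.pad %exp low[0, 0, 1] high[0, 0, 1] { ... }
+///       : tensor<2x4x10xf32> to tensor<2x4x12xf32>
+///
+/// Source dims that expand into multiple result dims must be unpadded.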
+struct FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
+ }
+
+ Value constantPaddingValue = padOp.getConstantPaddingValue();
+ if (!constantPaddingValue) {
+ return rewriter.notifyMatchFailure(
+ expandOp, "cannot fold with non-constant padding value");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
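+ // Bail out if a source dim that expands into more than one result dim
+ // carries any padding; such padding cannot be bubbled through the
+ // expansion.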
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[idx];
+ OpFoldResult h = high[idx];
+ if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
+ return failure();
+ }
+
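+ // Build the new pad amounts by repeating each source dim's low/high pad
+ // for every result dim in its reassociation group (non-trivial groups are
+ // unpadded, so only zeros get repeated).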
+ SmallVector<OpFoldResult> newLow, newHigh;
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (size_t i = 0; i < reInd.size(); ++i) {
+ newLow.push_back(low[idx]);
+ newHigh.push_back(high[idx]);
+ }
+ }
+
+ Location loc = expandOp.getLoc();
+ auto finalType = cast<RankedTensorType>(expandOp.getType());
+ ArrayRef<int64_t> finalShape = finalType.getShape();
+
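+ // Seed the expanded output shape from the expand_shape result type;
+ // dynamic sizes are left empty and resolved below.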
+ SmallVector<OpFoldResult> expandedShape;
+ for (int64_t dimSize : finalShape) {
+ if (dimSize == ShapedType::kDynamic) {
+ expandedShape.push_back(OpFoldResult{});
+ } else {
+ expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+ }
+ }
+
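+ // For padded dims, use the pre-padding source size, since the new pad is
+ // applied after the expansion.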
+ for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[inDimIdx];
+ OpFoldResult h = high[inDimIdx];
+
+ if (!isZero(l) || !isZero(h)) {
+ auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+ int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+ OpFoldResult originalSizeOFR;
+ if (originalSize == ShapedType::kDynamic) {
+ Value orgSizeVal =
+ rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+ originalSizeOFR = orgSizeVal;
+ } else {
+ originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+ }
+
+ for (auto outDimIdx : reInd) {
+ expandedShape[outDimIdx] = originalSizeOFR;
+ }
+ }
+ }
+
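+ // Resolve any still-unset dynamic sizes from the original expand_shape
+ // source.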
+ for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+ if (dimSize == ShapedType::kDynamic &&
+ !isa<Value>(expandedShape[outDimIdx]) &&
+ !isa<Attribute>(expandedShape[outDimIdx])) {
+ Value actualSize =
+ rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+ expandedShape[outDimIdx] = actualSize;
+ }
+ }
+
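+ // Compute the static counterpart of the expanded shape for the new
+ // expand_shape result type.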
+ SmallVector<int64_t> staticExpandedShape;
+ for (OpFoldResult dim : expandedShape) {
+ if (auto attr = dyn_cast<Attribute>(dim)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ staticExpandedShape.push_back(intAttr.getInt());
+ } else {
+ staticExpandedShape.push_back(ShapedType::kDynamic);
+ }
+ } else {
+ staticExpandedShape.push_back(ShapedType::kDynamic);
+ }
+ }
+
+ auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ RankedTensorType::get(staticExpandedShape,
+ padOp.getSource().getType().getElementType()),
+ padOp.getSource(), reassociations);
+
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
----------------
Max191 wrote:
nit: use `rewriter.replaceOpWithNewOp<tensor::PadOp>`?
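A minimal sketch of that suggestion, reusing the values already computed in the patch, might look like:

```cpp
// Creates the tensor.pad and replaces expandOp in a single step.
rewriter.replaceOpWithNewOp<tensor::PadOp>(
    expandOp, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
    padOp.getConstantPaddingValue(), padOp.getNofold());
```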
https://github.com/llvm/llvm-project/pull/136681