[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)
llvmlistbot at llvm.org
Thu Nov 20 13:51:27 PST 2025
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/168888
From bbda86d46996e713ee424d6acf5d71d07d4be611 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 20 Nov 2025 14:53:00 +0000
Subject: [PATCH 1/2] [mlir] Add missing pad reshape propagation patterns
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 283 +++++++++++++++---
.../fuse-with-reshape-by-collapsing.mlir | 39 +++
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 41 +++
3 files changed, 314 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..8c5a0c1474408 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+ // The resulting shape after padding each dimension.
+ SmallVector<int64_t> paddedShape;
+
+ // Low and high padding amounts for each dimension.
+ SmallVector<OpFoldResult> lowPad;
+ SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a list of
+/// PaddedDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+ // Expanded dimensions cannot have padding because the resulting padding may
+ // not be representable by a tensor.pad op. There are some special cases where
+ // it is possible (like expanding unit dims), but supporting these cases is
+ // NYI, so disallow it for now.
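+ // For example, padding a dimension of size 4 with low = 1 and high = 1 gives
+ // size 6; expanding that into 2x3 puts the two padding elements at the start
+ // of the first row and the end of the last row, with the source elements
+ // split unevenly across the rows, which no tensor.pad applied to an expanded
+ // form of the unpadded source can reproduce.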
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && (l != 0 || h != 0))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.paddedShape.assign(expandedShape);
+ padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
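+ // Dimensions that are not expanded keep their original padding; expanded
+ // groups were already checked to carry no padding.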
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+ padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+ padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+ }
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
- for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && (l != 0 || h != 0))
- return failure();
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+ expandedPadding.lowPad, expandedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
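+/// Pattern to swap a tensor.expand_shape op with its producer tensor.pad op,
+/// applying the expansion to the pad's source and expanding the padding
+/// amounts accordingly.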
+class FoldExpandShapeWithProducerPadOp
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+ FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
}
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType expandedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+ RankedTensorType expandedType = expandOp.getResultType();
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+ SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+ rewriter.setInsertionPointAfterValue(padOp.getSource());
+ SmallVector<OpFoldResult> padSrcSizes =
+ tensor::getMixedSizes(rewriter, loc, padOp.getSource());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ // We know that any reassociation with multiple dims is not padded because
+ // of the requirements of computeExpandedPadding.
if (reInd.size() == 1) {
- expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
- }
- for (size_t i = 0; i < reInd.size(); ++i) {
- newLow.push_back(padOp.getMixedLowPad()[idx]);
- newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+ newExpandedSizes[reInd[0]] = padSrcSizes[idx];
}
}
-
- Location loc = padOp->getLoc();
- RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+ RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+ auto newExpandOp = tensor::ExpandShapeOp::create(
+ rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+ newExpandedSizes);
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+ rewriter.setInsertionPoint(expandOp);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+ expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
return success();
}
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+ // Collapsed dimensions cannot have padding because this can produce strided
+ // padding that isn't representable by a tensor.pad op. There are some special
+ // cases where it it possible (like collapsing unit dims), but supporting
+ // these cases is NYI, so disallow it for now.
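+ // For example, padding the inner dimension of a 2x3 tensor with low = 1 and
+ // high = 1 gives a 2x5 tensor; collapsing both dimensions into one leaves
+ // padding at positions 0, 4, 5, and 9 of the size-10 result, which a single
+ // tensor.pad on the collapsed source cannot express.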
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (int64_t dim : reInd) {
+ if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+ return failure();
+ }
+ }
+
+ // Initialize padding values for collapsed tensors with zeros
+ ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+ // Update padding for dimensions that are not being collapsed, and compute
+ // the collapsed padded shape.
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+ padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+ }
+ SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+ for (int64_t dim : reInd) {
+ collapsedSize =
+ collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+ }
+ padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
- for (auto reInd : reassociations) {
- if (reInd.size() == 1)
- continue;
- if (llvm::any_of(reInd, [&](int64_t ind) {
- return low[ind] != 0 || high[ind] != 0;
- })) {
- return failure();
- }
- }
-
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType collapsedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
- SmallVector<OpFoldResult> expandedPaddedSizes(
- getMixedValues(reshapeOp.getStaticOutputShape(),
- reshapeOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> expandedPaddedSizes =
+ reshapeOp.getMixedOutputShape();
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
- for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
- OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
- OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+ for (auto [reInd, l, h] :
+ llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+ collapsedPadding.highPad)) {
if (reInd.size() == 1) {
- collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
- OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
- expandedPaddedSizes[reInd[0]] = paddedSize;
}
- newLow.push_back(l);
- newHigh.push_back(h);
}
RankedTensorType collapsedPaddedType =
- paddedType.clone(collapsedPaddedShape);
+ padOp.getType().clone(collapsedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
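+/// Pattern to swap a tensor.collapse_shape op with its producer tensor.pad op,
+/// applying the collapse to the pad's source and collapsing the padding
+/// amounts accordingly.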
+class FoldReshapeWithProducerPadOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(padOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+ RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+ Location loc = reshapeOp->getLoc();
+ auto newCollapseOp = tensor::CollapseShapeOp::create(
+ rewriter, loc, padOp.getSource(), reassociations);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..923bb2ca9c28a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -639,6 +639,45 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
// CHECK: return %[[EXPAND]]
+// -----
+
+func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
+ : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+ return %collapsed : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @collapse_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]]
+ : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %collapsed : tensor<?x?x?x?xf32>
+}
+// CHECK: func @collapse_shape_with_producer_pad_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: return %[[PAD]]
+
// -----
// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..f6572674d10e2 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -863,6 +863,47 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// -----
+func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14]
+ : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+ return %expanded : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @expand_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
+// CHECK: return %[[PAD]]
+
+// -----
+
+func.func @expand_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?xf32>,
+ %s0: index, %s1: index, %s2: index, %s3: index, %s4: index, %s5: index,
+ %l0: index, %l1: index, %h0: index, %h1: index) -> tensor<?x?x?x?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %padded = tensor.pad %arg0 low[%l0, 0, %l1, 0] high[%h0, 0, %h1, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
+ %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5]
+ : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
+ return %expanded : tensor<?x?x?x?x?x?xf32>
+}
+// CHECK: func @expand_shape_with_producer_pad_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>
+// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0:.+]] : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2:.+]] : tensor<?x?x?x?xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]] output_shape [%[[DIM0]], %[[S1]], %[[S2]], %[[DIM2]], %[[S4]], %[[S5]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[EXPAND]] low[%[[L0]], 0, 0, %[[L1]], 0, 0] high[%[[H0]], 0, 0, %[[H1]], 0, 0]
+// CHECK: return %[[PAD]]
+
+// -----
+
func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
%arg1 : tensor<4x?x32x128xf16>, %empty : tensor<4x?x32x128xf16>) -> tensor<4x?x32x8x16xf16> {
%c0 = arith.constant 0 : index
From f76287cc2b5a666dacb39a086a603108bde25a37 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 20 Nov 2025 21:51:08 +0000
Subject: [PATCH 2/2] address comments
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 56 ++++++++++---------
1 file changed, 30 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 8c5a0c1474408..1e110d1c6b113 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1050,19 +1050,24 @@ struct PadDimInfo {
/// Computes the expanded padding information for the given pad operation based
/// on the provided expanded shape and reassociation indices. Returns a list of
-/// PaddedDimInfo containing the low and high padding amounts and the padded
+/// PadDimInfo containing the low and high padding amounts and the padded
/// size for each dimension, or failure if the expansion is not possible.
static FailureOr<PadDimInfo>
computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
ArrayRef<ReassociationIndices> reassociations,
PatternRewriter &rewriter) {
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ // If the padding value depends on the index values of the pad operation,
+ // then it may not be valid to expand the dimensions, since it will change
+ // the index values on which the padding value depends.
+ if (!padOp.getConstantPaddingValue())
+ return failure();
// Expanded dimensions cannot have padding because the resulting padding may
// not be representable by a tensor.pad op. There are some special cases where
// it is possible (like expanding unit dims), but supporting these cases is
// NYI, so disallow it for now.
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
if (reInd.size() != 1 && (l != 0 || h != 0))
return failure();
@@ -1101,8 +1106,6 @@ class FoldPadWithProducerReshapeOpByExpansion
padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
if (!reshapeOp)
return failure();
- if (!reshapeOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
@@ -1116,7 +1119,7 @@ class FoldPadWithProducerReshapeOpByExpansion
padOp, expandedType.getShape(), reassociations, rewriter);
if (failed(maybeExpandedPadding))
return failure();
- PadDimInfo expandedPadding = maybeExpandedPadding.value();
+ PadDimInfo &expandedPadding = maybeExpandedPadding.value();
Location loc = padOp->getLoc();
RankedTensorType expandedPaddedType =
@@ -1137,12 +1140,12 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
-class FoldExpandShapeWithProducerPadOp
+class FoldReshapeWithProducerPadOpByExpansion
: public OpRewritePattern<tensor::ExpandShapeOp> {
public:
- FoldExpandShapeWithProducerPadOp(MLIRContext *context,
- ControlFusionFn foldReshapes,
- PatternBenefit benefit = 1)
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
: OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
controlFoldingReshapes(std::move(foldReshapes)) {}
@@ -1151,8 +1154,6 @@ class FoldExpandShapeWithProducerPadOp
tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();
- if (!padOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(expandOp,
@@ -1166,7 +1167,7 @@ class FoldExpandShapeWithProducerPadOp
padOp, expandedType.getShape(), reassociations, rewriter);
if (failed(maybeExpandedPadding))
return failure();
- PadDimInfo expandedPadding = maybeExpandedPadding.value();
+ PadDimInfo &expandedPadding = maybeExpandedPadding.value();
Location loc = expandOp->getLoc();
SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
@@ -2031,13 +2032,18 @@ static FailureOr<PadDimInfo>
computeCollapsedPadding(tensor::PadOp padOp,
ArrayRef<ReassociationIndices> reassociations,
PatternRewriter &rewriter) {
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ // If the padding value depends on the index values of the pad operation,
+ // then it may not be valid to collapse the dimensions, since it will change
+ // the index values on which the padding value depends.
+ if (!padOp.getConstantPaddingValue())
+ return failure();
// Collapsed dimensions cannot have padding because this can produce strided
// padding that isn't representable by a tensor.pad op. There are some special
- // cases where it it possible (like collapsing unit dims), but supporting
+ // cases where it is possible (like collapsing unit dims), but supporting
// these cases is NYI, so disallow it for now.
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
for (int64_t dim : reInd) {
if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
@@ -2053,10 +2059,12 @@ computeCollapsedPadding(tensor::PadOp padOp,
// Update padding for dimensions that are not being collapsed, and compute
// the collapsed padded shape.
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
if (reInd.size() == 1) {
- padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
- padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+ padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+ padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
}
SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
for (int64_t dim : reInd) {
@@ -2084,8 +2092,6 @@ class FoldPadWithProducerReshapeOpByCollapsing
padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
if (!reshapeOp)
return failure();
- if (!reshapeOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
return rewriter.notifyMatchFailure(padOp,
@@ -2098,7 +2104,7 @@ class FoldPadWithProducerReshapeOpByCollapsing
computeCollapsedPadding(padOp, reassociations, rewriter);
if (failed(maybeCollapsedPadding))
return failure();
- PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+ PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
SmallVector<OpFoldResult> expandedPaddedSizes =
reshapeOp.getMixedOutputShape();
@@ -2147,8 +2153,6 @@ class FoldReshapeWithProducerPadOpByCollapsing
tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
if (!padOp)
return failure();
- if (!padOp->hasOneUse())
- return failure();
if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(padOp,
@@ -2162,7 +2166,7 @@ class FoldReshapeWithProducerPadOpByCollapsing
computeCollapsedPadding(padOp, reassociations, rewriter);
if (failed(maybeCollapsedPadding))
return failure();
- PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+ PadDimInfo &collapsedPadding = maybeCollapsedPadding.value();
Location loc = reshapeOp->getLoc();
auto newCollapseOp = tensor::CollapseShapeOp::create(
@@ -2420,8 +2424,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
- patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
- controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}