[Mlir-commits] [mlir] [Draft][MLIR] Add reshape propagation through tensor.pad (PR #136681)
llvmlistbot at llvm.org
Sat Jul 12 02:31:43 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-linalg
Author: Hyunsung Lee (ita9naiwa)
https://github.com/iree-org/iree/issues/17492#issuecomment-2688799803
I've implemented fusion for tensor.expand_shape → tensor.pad, but two gaps remain:
1. Missing collapse-side pattern.
I haven't yet added the mirror case for tensor.collapse_shape → tensor.pad; a hypothetical example follows the Before/After IR below.
2. Static-only support.
The current pattern only handles fully static shapes and padding.
Before (pad, then expand):
```mlir
func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
  %pad = tensor.pad %producer low[0, 1, 1] high[0, 1, 1] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %c0 : f32
  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
  %reshape = tensor.expand_shape %pad [[0, 1], [2], [3]]
      output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
  return %reshape : tensor<32x16x258x258xf32>
}
```
After (expand, then pad):
```mlir
func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
  %c0 = arith.constant 0.0 : f32
  %producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
  %reshape = tensor.expand_shape %producer [[0, 1], [2], [3]]
      output_shape [32, 16, 256, 256] : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
  %pad = tensor.pad %reshape low[0, 0, 1, 1] high[0, 0, 1, 1] {
  ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
    tensor.yield %c0 : f32
  } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
  return %pad : tensor<32x16x258x258xf32>
}
```
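For the first gap, the unhandled collapse-side input would look something like this (a hypothetical example I put together for illustration, not one of the tests in this PR); the collapse_shape should bubble above the pad the same way, turning low[0, 0, 1] into low[0, 1] on the collapsed tensor:
```mlir
// Hypothetical collapse-side case (not yet handled): the [0, 1] group is
// unpadded, so the collapse_shape could move above the pad.
func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256xf32>) -> tensor<512x258xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, 0, 1] high[0, 0, 1] {
  ^bb0(%i: index, %j: index, %k: index):
    tensor.yield %cst : f32
  } : tensor<32x16x256xf32> to tensor<32x16x258xf32>
  %collapsed = tensor.collapse_shape %padded [[0, 1], [2]]
      : tensor<32x16x258xf32> into tensor<512x258xf32>
  return %collapsed : tensor<512x258xf32>
}
```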
Next steps
• Add a CollapseShapeOp→PadOp pattern to cover the missing collapse-side fusion; a rough sketch follows below.
• Lift the "static-only" guard so both patterns handle dynamic shapes and pads; an example dynamic case is sketched below as well.
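Here is a rough sketch of how the collapse-side pattern could look, mirroring the conservative checks of the expand-side pattern above. This is an untested outline under my assumptions (zero padding required on multi-dim groups, constant padding value only), not code from this PR:
```cpp
// Sketch only: mirrors FoldReshapeWithProducerPadOpByExpansion for the
// collapse direction. Name and structure are assumptions, not PR code.
struct FoldReshapeWithProducerPadOpByCollapsing
    : public OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
    // Bail out on missing/shared producers and non-constant padding values,
    // as the expand-side pattern does.
    if (!padOp || !padOp->hasOneUse() || !padOp.getConstantPaddingValue())
      return failure();

    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
    SmallVector<OpFoldResult> newLow, newHigh;
    for (ReassociationIndices group : collapseOp.getReassociationIndices()) {
      if (group.size() == 1) {
        // A group that collapses a single dim keeps its padding unchanged.
        newLow.push_back(low[group[0]]);
        newHigh.push_back(high[group[0]]);
        continue;
      }
      // Padding on the inner dims of a collapsed group would end up
      // interleaved in the result, which a single tensor.pad cannot
      // express, so conservatively require the whole group to be unpadded.
      for (int64_t dim : group)
        if (!isConstantIntValue(low[dim], 0) ||
            !isConstantIntValue(high[dim], 0))
          return failure();
      newLow.push_back(rewriter.getIndexAttr(0));
      newHigh.push_back(rewriter.getIndexAttr(0));
    }

    // Collapse the unpadded source, then re-apply the padding on top.
    auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
        collapseOp.getLoc(), padOp.getSource(),
        collapseOp.getReassociationIndices());
    rewriter.replaceOpWithNewOp<tensor::PadOp>(
        collapseOp, collapseOp.getType(), newCollapse, newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());
    return success();
  }
};
```
And for the dynamic-shape work, a case the lifted guard would need to handle might look like this (again hypothetical; the index operands stand in for whatever computed the shapes):
```mlir
// Hypothetical dynamic case: the leading dim is dynamic (but unpadded, so it
// may sit in a multi-dim group), and the trailing pad amounts are SSA values.
func.func @pad_then_expand_dynamic(%arg0: tensor<?x256xf32>,
    %l: index, %h: index, %d0: index, %sz: index) -> tensor<?x4x?xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %padded = tensor.pad %arg0 low[0, %l] high[0, %h] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x256xf32> to tensor<?x?xf32>
  %expanded = tensor.expand_shape %padded [[0, 1], [2]]
      output_shape [%d0, 4, %sz] : tensor<?x?xf32> into tensor<?x4x?xf32>
  return %expanded : tensor<?x4x?xf32>
}
```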
CC @Max191 for awareness; I'd love any pointers on the collapse-side implementation or dynamic-shape handling!
---
Full diff: https://github.com/llvm/llvm-project/pull/136681.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+142)
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+49-2)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 9c0f6e5d6469e..39eed6dd4cba4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1100,6 +1100,146 @@ class FoldPadWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+                                          ControlFusionFn foldReshapes,
+                                          PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+    auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+      if (auto attr = dyn_cast<Attribute>(padValue)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+          return intAttr.getInt() == 0;
+      }
+
+      if (auto val = dyn_cast<Value>(padValue)) {
+        if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+          if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+            return attr.getInt() == 0;
+        }
+      }
+
+      // When the padding is dynamic and not a known constant, we cannot
+      // prove it is zero, so conservatively treat it as non-zero.
+      return false;
+    };
+
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[idx];
+      OpFoldResult h = high[idx];
+      if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+        return failure();
+    }
+
+    SmallVector<OpFoldResult> newLow, newHigh;
+    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      for (size_t i = 0; i < reInd.size(); ++i) {
+        newLow.push_back(padOp.getMixedLowPad()[idx]);
+        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+      }
+    }
+
+    Location loc = expandOp.getLoc();
+    auto finalType = cast<RankedTensorType>(expandOp.getType());
+    ArrayRef<int64_t> finalShape = finalType.getShape();
+
+    SmallVector<OpFoldResult> expandedShape;
+    for (int64_t dimSize : finalShape) {
+      if (dimSize == ShapedType::kDynamic) {
+        expandedShape.push_back(OpFoldResult{});
+      } else {
+        expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+      }
+    }
+
+    for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+      OpFoldResult l = low[inDimIdx];
+      OpFoldResult h = high[inDimIdx];
+
+      if (!isZeroPadding(l) || !isZeroPadding(h)) {
+        auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+        int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+        OpFoldResult originalSizeOFR;
+        if (originalSize == ShapedType::kDynamic) {
+          Value orgSizeVal =
+              rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+          originalSizeOFR = orgSizeVal;
+        } else {
+          originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+        }
+
+        for (auto outDimIdx : outGroup) {
+          expandedShape[outDimIdx] = originalSizeOFR;
+        }
+      }
+    }
+
+    for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+      if (dimSize == ShapedType::kDynamic &&
+          !isa<Value>(expandedShape[outDimIdx]) &&
+          !isa<Attribute>(expandedShape[outDimIdx])) {
+        Value actualSize =
+            rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+        expandedShape[outDimIdx] = actualSize;
+      }
+    }
+
+    SmallVector<int64_t> staticExpandedShape;
+    for (OpFoldResult dim : expandedShape) {
+      if (auto attr = dyn_cast<Attribute>(dim)) {
+        if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+          staticExpandedShape.push_back(intAttr.getInt());
+        } else {
+          staticExpandedShape.push_back(ShapedType::kDynamic);
+        }
+      } else {
+        staticExpandedShape.push_back(ShapedType::kDynamic);
+      }
+    }
+
+    auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc,
+        RankedTensorType::get(staticExpandedShape,
+                              padOp.getSource().getType().getElementType()),
+        padOp.getSource(), reassociations);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to fold a tensor.expand_shape op with its producer generic op
 /// by expanding the dimensionality of the loop in the producer op.
 struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2375,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+                                                        controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..3ea0babfa3b9d 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
-                                                    %arg1 : tensor<?x?xi32>,
+                                                    %arg1 : tensor<?x?xi32>,
                                                     %sz0: index, %sz1: index) ->
     tensor<?x?x4x5xi32>
 {
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
 // -----
 
 func.func @reshape_as_consumer_permutation_with_multiple_results
-    (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
+    (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
      %sz1: index, %sz2: index, %sz3: index, %sz4: index)
     -> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
   %c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
 // CHECK:      %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:     ins(%[[EXPANDED]] :
 // CHECK:      return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK:      func @fold_tensor_pad_with_expand(
+// CHECK-SAME:   %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK-DAG:    %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:    %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:        %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+// CHECK:        %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK:        ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK:          tensor.yield %[[CST]] : f32
+// CHECK:        } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+// CHECK:        return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+  %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+  ^bb0(%i: index, %j: index, %k: index):
+    tensor.yield %cst : f32
+  } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+  %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+  return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK:      func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME:   %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK:        %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK:        %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+// CHECK:        %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK:        ^bb0(
+// CHECK:          tensor.yield %[[CST]] : f32
+// CHECK:        return %[[PADDED]]
``````````
https://github.com/llvm/llvm-project/pull/136681