[Mlir-commits] [mlir] [MLIR] Add expand_shape propagation through tensor.pad (PR #136681)
Hyunsung Lee
llvmlistbot at llvm.org
Sun Jul 13 17:59:50 PDT 2025
https://github.com/ita9naiwa updated https://github.com/llvm/llvm-project/pull/136681
>From 0d8c636d4fd4d5c9636cfd3599c804e4a89e81e6 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Tue, 22 Apr 2025 19:04:02 +0900
Subject: [PATCH 1/2] Add FoldReshapeWithProducerPadOpByExpansion
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 142 ++++++++++++++++++
mlir/test/Dialect/Linalg/reshape_fusion.mlir | 51 ++++++-
2 files changed, 191 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index bf70597d5ddfe..dd4ac89e98090 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1101,6 +1101,146 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+ auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+ if (auto attr = dyn_cast<Attribute>(padValue)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 0;
+ }
+
+ if (auto val = dyn_cast<Value>(padValue)) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+ if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+ return attr.getInt() == 0;
+ }
+ }
+
+    // When the padding is dynamic and not a known constant, we cannot prove
+    // that it is zero, so conservatively return false.
+ return false;
+ };
+
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[idx];
+ OpFoldResult h = high[idx];
+ if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> newLow, newHigh;
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (size_t i = 0; i < reInd.size(); ++i) {
+ newLow.push_back(padOp.getMixedLowPad()[idx]);
+ newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ }
+ }
+
+ Location loc = expandOp.getLoc();
+ auto finalType = cast<RankedTensorType>(expandOp.getType());
+ ArrayRef<int64_t> finalShape = finalType.getShape();
+
+ SmallVector<OpFoldResult> expandedShape;
+ for (int64_t dimSize : finalShape) {
+ if (dimSize == ShapedType::kDynamic) {
+ expandedShape.push_back(OpFoldResult{});
+ } else {
+ expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+ }
+ }
+
+ for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[inDimIdx];
+ OpFoldResult h = high[inDimIdx];
+
+ if (!isZeroPadding(l) || !isZeroPadding(h)) {
+ auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+ int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+ OpFoldResult originalSizeOFR;
+ if (originalSize == ShapedType::kDynamic) {
+ Value orgSizeVal =
+ rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+ originalSizeOFR = orgSizeVal;
+ } else {
+ originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+ }
+
+ for (auto outDimIdx : outGroup) {
+ expandedShape[outDimIdx] = originalSizeOFR;
+ }
+ }
+ }
+
+ for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+ if (dimSize == ShapedType::kDynamic &&
+ !isa<Value>(expandedShape[outDimIdx]) &&
+ !isa<Attribute>(expandedShape[outDimIdx])) {
+ Value actualSize =
+ rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+ expandedShape[outDimIdx] = actualSize;
+ }
+ }
+
+ SmallVector<int64_t> staticExpandedShape;
+ for (OpFoldResult dim : expandedShape) {
+ if (auto attr = dyn_cast<Attribute>(dim)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ staticExpandedShape.push_back(intAttr.getInt());
+ } else {
+ staticExpandedShape.push_back(ShapedType::kDynamic);
+ }
+ } else {
+ staticExpandedShape.push_back(ShapedType::kDynamic);
+ }
+ }
+
+ auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ RankedTensorType::get(staticExpandedShape,
+ padOp.getSource().getType().getElementType()),
+ padOp.getSource(), reassociations);
+
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2249,6 +2389,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..3ea0babfa3b9d 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
- %arg1 : tensor<?x?xi32>,
+ %arg1 : tensor<?x?xi32>,
%sz0: index, %sz1: index) ->
tensor<?x?x4x5xi32>
{
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// -----
func.func @reshape_as_consumer_permutation_with_multiple_results
- (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
+ (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
%sz1: index, %sz2: index, %sz3: index, %sz4: index)
-> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
%c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPANDED]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+ %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %cst : f32
+ } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+ %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+ return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+// CHECK: return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+ %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %cst : f32
+ } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+ %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+ return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK: ^bb0(
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: return %[[PADDED]]
>From 57ec65705339764b1a472f32b382c015909b25e8 Mon Sep 17 00:00:00 2001
From: Hyunsung Lee <ita9naiwa at gmail.com>
Date: Mon, 14 Jul 2025 09:59:33 +0900
Subject: [PATCH 2/2] add collapse_shape
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 176 +++++++++++++++---
.../fuse-with-reshape-by-collapsing.mlir | 53 +++++-
2 files changed, 204 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 39eed6dd4cba4..e65228ae0e3eb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -26,6 +26,8 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/LogicalResult.h"
#include <optional>
#include <utility>
@@ -1100,6 +1102,20 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+bool isZero(OpFoldResult value) {
+ if (auto attr = dyn_cast<Attribute>(value)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 0;
+ }
+ if (auto val = dyn_cast<Value>(value)) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+ if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+ return attr.getInt() == 0;
+ }
+ }
+ return false;
+}
+
/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
/// by bubbling the expand_shape before the pad.
struct FoldReshapeWithProducerPadOpByExpansion
@@ -1125,41 +1141,29 @@ struct FoldReshapeWithProducerPadOpByExpansion
"fusion blocked by control function");
}
+ Value constantPaddingValue = padOp.getConstantPaddingValue();
+ if (!constantPaddingValue) {
+ return rewriter.notifyMatchFailure(
+ expandOp, "cannot fold with non-constant padding value");
+ }
+
SmallVector<ReassociationIndices> reassociations =
expandOp.getReassociationIndices();
SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
- auto isZeroPadding = [](OpFoldResult padValue) -> bool {
- if (auto attr = dyn_cast<Attribute>(padValue)) {
- if (auto intAttr = dyn_cast<IntegerAttr>(attr))
- return intAttr.getInt() == 0;
- }
-
- if (auto val = dyn_cast<Value>(padValue)) {
- if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
- if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
- return attr.getInt() == 0;
- }
- }
-
- // when padding is dynamic and not constant, we don't know if it's zero or
- // not. so we return false here.
- return false;
- };
-
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[idx];
OpFoldResult h = high[idx];
- if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+ if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
return failure();
}
SmallVector<OpFoldResult> newLow, newHigh;
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
for (size_t i = 0; i < reInd.size(); ++i) {
- newLow.push_back(padOp.getMixedLowPad()[idx]);
- newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ newLow.push_back(low[idx]);
+ newHigh.push_back(high[idx]);
}
}
@@ -1176,11 +1180,11 @@ struct FoldReshapeWithProducerPadOpByExpansion
}
}
- for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+ for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
OpFoldResult l = low[inDimIdx];
OpFoldResult h = high[inDimIdx];
- if (!isZeroPadding(l) || !isZeroPadding(h)) {
+ if (!isZero(l) || !isZero(h)) {
auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
int64_t originalSize = srcType.getDimSize(inDimIdx);
@@ -1193,7 +1197,7 @@ struct FoldReshapeWithProducerPadOpByExpansion
originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
}
- for (auto outDimIdx : outGroup) {
+ for (auto outDimIdx : reInd) {
expandedShape[outDimIdx] = originalSizeOFR;
}
}
@@ -1240,6 +1244,125 @@ struct FoldReshapeWithProducerPadOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor.collapse_shape op with its producer tensor.pad op
+/// by bubbling the collapse_shape before the pad.
+struct FoldReshapeWithProducerPadOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+
+ FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = collapseOp.getSrc().getDefiningOp<tensor::PadOp>();
+
+ if (!padOp)
+ return failure();
+
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&collapseOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(collapseOp,
+ "fusion blocked by control function");
+ }
+
+ Value constantPaddingValue = padOp.getConstantPaddingValue();
+ if (!constantPaddingValue) {
+ return rewriter.notifyMatchFailure(
+ collapseOp, "cannot fold with non-constant padding value");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ collapseOp.getReassociationIndices();
+ SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() > 1) {
+ for (auto dimIdx : reInd) {
+ if (!isZero(low[dimIdx]) || !isZero(high[dimIdx])) {
+ return failure();
+ }
+ }
+ }
+ }
+
+ SmallVector<OpFoldResult> newLow, newHigh;
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ newLow.push_back(low[reInd[0]]);
+ newHigh.push_back(high[reInd[0]]);
+ }
+
+ Location loc = collapseOp.getLoc();
+ auto resultType = collapseOp.getResultType();
+
+ auto finalType = cast<RankedTensorType>(collapseOp.getType());
+ ArrayRef<int64_t> finalShape = finalType.getShape();
+
+ SmallVector<OpFoldResult> collapsedShape;
+ for (int64_t dimSize : finalShape) {
+ if (dimSize == ShapedType::kDynamic) {
+ collapsedShape.push_back(OpFoldResult{});
+ } else {
+ collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+ }
+ }
+
+ for (auto [inDimIdx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[reInd[0]];
+ OpFoldResult h = high[reInd[0]];
+
+ if (!isZero(l) || !isZero(h)) {
+ auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+ int64_t originalSize = srcType.getDimSize(reInd[0]);
+
+ OpFoldResult originalSizeOFR;
+ if (originalSize == ShapedType::kDynamic) {
+ Value orgSizeVal =
+ rewriter.create<tensor::DimOp>(loc, padOp.getSource(), reInd[0]);
+ originalSizeOFR = orgSizeVal;
+ } else {
+ originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+ }
+ collapsedShape[inDimIdx] = originalSizeOFR;
+ }
+ }
+
+ SmallVector<int64_t> staticCollapsedShape;
+ for (OpFoldResult dim : collapsedShape) {
+ if (auto attr = dyn_cast<Attribute>(dim)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ staticCollapsedShape.push_back(intAttr.getInt());
+ } else {
+ staticCollapsedShape.push_back(ShapedType::kDynamic);
+ }
+ } else {
+ staticCollapsedShape.push_back(ShapedType::kDynamic);
+ }
+ }
+
+ auto newCollapseType = RankedTensorType::get(
+ staticCollapsedShape, padOp.getSource().getType().getElementType());
+ auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
+ loc, newCollapseType, padOp.getSource(), reassociations);
+
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, resultType, newCollapseOp.getResult(), newLow, newHigh,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(collapseOp, newPadOp.getResult());
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2388,6 +2511,11 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
+
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..0ac1686361bf7 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -232,7 +232,7 @@ func.func @fuse_by_collapsing_dynamic_2(%arg0 : tensor<?xf32>, %sz0: index, %sz1
%1 = linalg.generic {
indexing_maps = [#map0, #map0],
iterator_types = ["parallel", "parallel"]}
- ins(%0 : tensor<?x?xf32>)
+ ins(%0 : tensor<?x?xf32>)
outs(%init : tensor<?x?xf32>) {
^bb0(%b0 : f32, %b1 : f32):
%out = arith.negf %b0 : f32
@@ -858,3 +858,54 @@ func.func @partial_fuse_by_collapsing(%arg0: tensor<4x?x32x128x192xf16>, %arg1:
// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME: tensor<4x128x192x?x32xf32> into tensor<512x192x?xf32>
// CHECK: return %[[COLLAPSED]] : tensor<512x192x?xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+ %padded = tensor.pad %0 low[0, 0, 1, 1] high[0, 0, 1, 1] {
+ ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+ tensor.yield %cst : f32
+ } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+ %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+ : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+ return %collapsed : tensor<512x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_collapse(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+// CHECK: return %[[PADDED]] : tensor<512x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_collapse_dynamic_pad_zero(%arg0: tensor<32x16x256x256xf32>) -> tensor<512x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x16x256x256xf32>) -> tensor<32x16x256x256xf32>
+ %padded = tensor.pad %0 low[%c0, %c0, %c1, %c1] high[%c0, %c0, %c1, %c1] {
+ ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
+ tensor.yield %cst : f32
+ } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+ %collapsed = tensor.collapse_shape %padded [[0, 1], [2], [3]]
+ : tensor<32x16x258x258xf32> into tensor<512x258x258xf32>
+ return %collapsed : tensor<512x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_collapse_dynamic_pad_zero(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<32x16x256x256xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[FILLED]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK-SAME: : tensor<32x16x256x256xf32> into tensor<512x256x256xf32>
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[COLLAPSED]] low[0, 1, 1] high[0, 1, 1]
+// CHECK: ^bb0(
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: return %[[PADDED]]
More information about the Mlir-commits
mailing list