[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Nov 20 07:28:27 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: None (Max191)
<details>
<summary>Changes</summary>
The existing `FoldPadWithProducerReshapeOpByExpansion` and `FoldPadWithProducerReshapeOpByCollapsing` patterns did not cover all reshape propagation cases because they only consider cases where the pad op is the consumer operation. This PR adds two new patterns to cover the cases where the pad op is the producer operation, which completes the propagation pattern set for tensor.pad with tensor.expand_shape and tensor.collapse_shape.
---
Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168888.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+234-49)
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+39)
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+41)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..8c5a0c1474408 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Carries padding information for every dimension of a (reshaped) tensor.pad
+/// result. All three vectors are indexed by dimension and have one entry per
+/// dimension of the padded tensor.
+struct PadDimInfo {
+ // The resulting shape after padding each dimension.
+ SmallVector<int64_t> paddedShape;
+
+ // Low and high padding amounts for each dimension.
+ SmallVector<OpFoldResult> lowPad;
+ SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a single
+/// PadDimInfo whose vectors are indexed by expanded dimension (padded size and
+/// low/high padding amounts), or failure if any pad dimension that is expanded
+/// into multiple dimensions has padding that is not statically zero.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+ // Expanded dimensions cannot have padding because the resulting padding may
+ // not be representable by a tensor.pad op. There are some special cases where
+ // it is possible (like expanding unit dims), but supporting these cases is
+ // NYI, so disallow it for now.
+ for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+ if (reInd.size() != 1 && (l != 0 || h != 0))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+ PadDimInfo padDimInfo;
+ padDimInfo.paddedShape.assign(expandedShape);
+ padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+ // Dimensions that come from a multi-dim reassociation group keep their
+ // expanded size and zero padding (checked above); singleton groups take the
+ // padded size and padding amounts of the corresponding pad dimension.
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+ padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+ padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+ }
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByExpansion
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
+ RankedTensorType expandedType = reshapeOp.getSrcType();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
- for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
- if (reInd.size() != 1 && (l != 0 || h != 0))
- return failure();
+ Location loc = padOp->getLoc();
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+ expandedPadding.lowPad, expandedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+ padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+ FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ /// Rewrites expand_shape(pad(x)) into pad(expand_shape(x)) when the pad has a
+ /// single use and only pads dimensions that are not expanded.
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
}
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType expandedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+ RankedTensorType expandedType = expandOp.getResultType();
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+ padOp, expandedType.getShape(), reassociations, rewriter);
+ if (failed(maybeExpandedPadding))
+ return failure();
+ PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+ Location loc = expandOp->getLoc();
+ SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+ SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+ rewriter.setInsertionPointAfterValue(padOp.getSource());
+ SmallVector<OpFoldResult> padSrcSizes =
+ tensor::getMixedSizes(rewriter, loc, padOp.getSource());
+ // Only the size queries on the pad source are created right after the
+ // source. The new expand_shape reuses output_shape values of expandOp for
+ // multi-dim groups, and those values are only guaranteed to dominate
+ // expandOp (not the point after the source), so the new expand_shape and
+ // pad must be created at expandOp to keep the IR dominance-correct.
+ rewriter.setInsertionPoint(expandOp);
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ // We know that any reassociation with multiple dims is not padded because
+ // of the requirements of computeExpandedPadding.
if (reInd.size() == 1) {
- expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
- }
- for (size_t i = 0; i < reInd.size(); ++i) {
- newLow.push_back(padOp.getMixedLowPad()[idx]);
- newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+ newExpandedSizes[reInd[0]] = padSrcSizes[idx];
}
}
-
- Location loc = padOp->getLoc();
- RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+ RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+ auto newExpandOp = tensor::ExpandShapeOp::create(
+ rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+ newExpandedSizes);
+ RankedTensorType expandedPaddedType =
+ padOp.getResultType().clone(expandedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+ expandedPadding.lowPad, expandedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
- rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
- padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
return success();
}
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided reassociation indices. Returns a single PadDimInfo whose
+/// vectors are indexed by collapsed dimension (collapsed padded size and
+/// low/high padding amounts), or failure if any pad dimension that is collapsed
+/// together with other dimensions has padding that is not statically zero.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+ ArrayRef<ReassociationIndices> reassociations,
+ PatternRewriter &rewriter) {
+ ArrayRef<int64_t> low = padOp.getStaticLow();
+ ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+ // Collapsed dimensions cannot have padding because this can produce strided
+ // padding that isn't representable by a tensor.pad op. There are some special
+ // cases where it is possible (like collapsing unit dims), but supporting
+ // these cases is NYI, so disallow it for now.
+ for (const ReassociationIndices &reInd : reassociations) {
+ if (reInd.size() == 1)
+ continue;
+ for (int64_t dim : reInd) {
+ if (low[dim] != 0 || high[dim] != 0)
+ return failure();
+ }
+ }
+
+ // Materialize the mixed padding amounts once instead of on every iteration
+ // of the loop below.
+ SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+ SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+ ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+
+ // Initialize padding values for collapsed tensors with zeros.
+ PadDimInfo padDimInfo;
+ padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+ padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+ padDimInfo.paddedShape.reserve(reassociations.size());
+
+ // Update padding for dimensions that are not being collapsed, and compute
+ // the collapsed padded shape as the product of each group's padded sizes.
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ if (reInd.size() == 1) {
+ padDimInfo.lowPad[idx] = mixedLowPad[reInd[0]];
+ padDimInfo.highPad[idx] = mixedHighPad[reInd[0]];
+ }
+ SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+ for (int64_t dim : reInd) {
+ collapsedSize =
+ collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+ }
+ padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+ }
+
+ return padDimInfo;
+}
+
class FoldPadWithProducerReshapeOpByCollapsing
: public OpRewritePattern<tensor::PadOp> {
public:
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
"fusion blocked by control function");
}
- ArrayRef<int64_t> low = padOp.getStaticLow();
- ArrayRef<int64_t> high = padOp.getStaticHigh();
SmallVector<ReassociationIndices> reassociations =
reshapeOp.getReassociationIndices();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
- for (auto reInd : reassociations) {
- if (reInd.size() == 1)
- continue;
- if (llvm::any_of(reInd, [&](int64_t ind) {
- return low[ind] != 0 || high[ind] != 0;
- })) {
- return failure();
- }
- }
-
- SmallVector<OpFoldResult> newLow, newHigh;
- RankedTensorType collapsedType = reshapeOp.getSrcType();
- RankedTensorType paddedType = padOp.getResultType();
- SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
- SmallVector<OpFoldResult> expandedPaddedSizes(
- getMixedValues(reshapeOp.getStaticOutputShape(),
- reshapeOp.getOutputShape(), rewriter));
+ SmallVector<OpFoldResult> expandedPaddedSizes =
+ reshapeOp.getMixedOutputShape();
AffineExpr d0, d1, d2;
bindDims(rewriter.getContext(), d0, d1, d2);
auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
Location loc = reshapeOp->getLoc();
- for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
- OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
- OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+ for (auto [reInd, l, h] :
+ llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+ collapsedPadding.highPad)) {
if (reInd.size() == 1) {
- collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
- OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+ expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
- expandedPaddedSizes[reInd[0]] = paddedSize;
}
- newLow.push_back(l);
- newHigh.push_back(h);
}
RankedTensorType collapsedPaddedType =
- paddedType.clone(collapsedPaddedShape);
+ padOp.getType().clone(collapsedPadding.paddedShape);
auto newPadOp = tensor::PadOp::create(
- rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+ rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
padOp.getConstantPaddingValue(), padOp.getNofold());
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
ControlFusionFn controlFoldingReshapes;
};
+/// Rewrites collapse_shape(pad(x)) into pad(collapse_shape(x)), i.e. moves a
+/// tensor.collapse_shape above its producer tensor.pad. Applies only when the
+/// pad has a single use, the control function allows the fusion, and
+/// computeCollapsedPadding succeeds (no dimension that is collapsed together
+/// with other dimensions carries padding).
+class FoldReshapeWithProducerPadOpByCollapsing
+ : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+ FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+ if (!padOp->hasOneUse())
+ return failure();
+
+ // Report the failure on the op this pattern is rooted at (the
+ // collapse_shape), consistent with FoldExpandShapeWithProducerPadOp.
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(reshapeOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ reshapeOp.getReassociationIndices();
+ // The new pad produces exactly the type of the collapse_shape it replaces.
+ RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+ FailureOr<PadDimInfo> maybeCollapsedPadding =
+ computeCollapsedPadding(padOp, reassociations, rewriter);
+ if (failed(maybeCollapsedPadding))
+ return failure();
+ PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+ Location loc = reshapeOp->getLoc();
+ auto newCollapseOp = tensor::CollapseShapeOp::create(
+ rewriter, loc, padOp.getSource(), reassociations);
+
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+ collapsedPadding.lowPad, collapsedPadding.highPad,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to collapse dimensions.
template <typename LinalgType>
class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
patterns.getContext(), controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+ patterns.getContext(), controlFoldingReshapes);
patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..923bb2ca9c28a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -639,6 +639,45 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
// CHECK-SAME: output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
// CHECK: return %[[EXPAND]]
+// -----
+
+// Checks that a collapse_shape is moved above its producer pad when only the
+// dimensions that are not collapsed (0, 3 and 7) carry padding; the new pad
+// keeps those padding amounts on the collapsed dimensions 0, 2 and 4.
+func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+ %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+ tensor.yield %cst : i32
+ } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
+ : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+ return %collapsed : tensor<8x12x17x336x14xi32>
+}
+// CHECK: func @collapse_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+// CHECK: return %[[PAD]]
+
+// -----
+
+// Dynamic-shape variant: the SSA low/high padding amounts on the dimensions
+// that are not collapsed (0 and 3) must be forwarded to the new pad, which is
+// applied after collapsing the pad source.
+func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>,
+ %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+ %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+ tensor.yield %cst : f32
+ } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+ %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]]
+ : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+ return %collapsed : tensor<?x?x?x?xf32>
+}
+// CHECK: func @collapse_shape_with_producer_pad_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME: %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+// CHECK: return %[[PAD]]
+
// -----
// Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..f6572674d10e2 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -863,6 +863,47 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
// -----
+func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+ %cst = arith.constant 0 : i32
+ %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+ ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+ tensor.yield %cst : i32
+ } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+ %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14]
+ : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+ return %expanded : tensor<8x3x4x17x6x7x8x14xi32>
+}
+// CHECK: func @expand_shape_with_producer_pad
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/168888
More information about the Mlir-commits
mailing list