[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 20 07:28:27 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: None (Max191)

<details>
<summary>Changes</summary>

The existing `FoldPadWithProducerReshapeOpByExpansion` and `FoldPadWithProducerReshapeOpByCollapsing` patterns did not cover all reshape propagation cases, because they only consider cases where the pad op is the consumer operation. This PR adds 2 new patterns to cover the cases where the pad op is the producer operation, which completes the propagation pattern set for pad op with expand_shape and collapse_shape.

---

Patch is 21.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168888.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+234-49) 
- (modified) mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir (+39) 
- (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+41) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 05fc7cbbb90af..8c5a0c1474408 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases where
+  // it is possible (like expanding unit dims), but supporting these cases is
+  // NYI, so disallow it for now.
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape);
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByExpansion
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1061,38 +1109,92 @@ class FoldPadWithProducerReshapeOpByExpansion
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
+    RankedTensorType expandedType = reshapeOp.getSrcType();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
 
-    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
-      if (reInd.size() != 1 && (l != 0 || h != 0))
-        return failure();
+    Location loc = padOp->getLoc();
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(),
+        expandedPadding.lowPad, expandedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
+        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
+class FoldExpandShapeWithProducerPadOp
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+public:
+  FoldExpandShapeWithProducerPadOp(MLIRContext *context,
+                                   ControlFusionFn foldReshapes,
+                                   PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(expandOp,
+                                         "fusion blocked by control function");
     }
 
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType expandedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
+    RankedTensorType expandedType = expandOp.getResultType();
+    SmallVector<ReassociationIndices> reassociations =
+        expandOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeExpandedPadding = computeExpandedPadding(
+        padOp, expandedType.getShape(), reassociations, rewriter);
+    if (failed(maybeExpandedPadding))
+      return failure();
+    PadDimInfo expandedPadding = maybeExpandedPadding.value();
+
+    Location loc = expandOp->getLoc();
+    SmallVector<OpFoldResult> newExpandedSizes = expandOp.getMixedOutputShape();
+    SmallVector<int64_t> newExpandedShape(expandedType.getShape());
+    rewriter.setInsertionPointAfterValue(padOp.getSource());
+    SmallVector<OpFoldResult> padSrcSizes =
+        tensor::getMixedSizes(rewriter, loc, padOp.getSource());
     for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+      // We know that any reassociation with multiple dims is not padded because
+      // of the requirements of computeExpandedPadding.
       if (reInd.size() == 1) {
-        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
-      }
-      for (size_t i = 0; i < reInd.size(); ++i) {
-        newLow.push_back(padOp.getMixedLowPad()[idx]);
-        newHigh.push_back(padOp.getMixedHighPad()[idx]);
+        newExpandedShape[reInd[0]] = padOp.getSourceType().getDimSize(idx);
+        newExpandedSizes[reInd[0]] = padSrcSizes[idx];
       }
     }
-
-    Location loc = padOp->getLoc();
-    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
+    RankedTensorType newExpandedType = expandedType.clone(newExpandedShape);
+    auto newExpandOp = tensor::ExpandShapeOp::create(
+        rewriter, loc, newExpandedType, padOp.getSource(), reassociations,
+        newExpandedSizes);
+    RankedTensorType expandedPaddedType =
+        padOp.getResultType().clone(expandedPadding.paddedShape);
+    rewriter.setInsertionPoint(expandOp);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, expandedPaddedType, newExpandOp.getResult(),
+        expandedPadding.lowPad, expandedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
-        padOp, padOp.getResultType(), newPadOp.getResult(), reassociations);
+    rewriter.replaceOp(expandOp, newPadOp.getResult());
 
     return success();
   }
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some special
+  // cases where it is possible (like collapsing unit dims), but supporting
+  // these cases is NYI, so disallow it for now.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize padding values for collapsed tensors with zeros
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+      padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
+    }
+    SaturatedInteger collapsedSize = SaturatedInteger::wrap(1);
+    for (int64_t dim : reInd) {
+      collapsedSize =
+          collapsedSize * SaturatedInteger::wrap(expandedPaddedShape[dim]);
+    }
+    padDimInfo.paddedShape.push_back(collapsedSize.asInteger());
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByCollapsing
     : public OpRewritePattern<tensor::PadOp> {
 public:
@@ -1944,49 +2092,34 @@ class FoldPadWithProducerReshapeOpByCollapsing
                                          "fusion blocked by control function");
     }
 
-    ArrayRef<int64_t> low = padOp.getStaticLow();
-    ArrayRef<int64_t> high = padOp.getStaticHigh();
     SmallVector<ReassociationIndices> reassociations =
         reshapeOp.getReassociationIndices();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
 
-    for (auto reInd : reassociations) {
-      if (reInd.size() == 1)
-        continue;
-      if (llvm::any_of(reInd, [&](int64_t ind) {
-            return low[ind] != 0 || high[ind] != 0;
-          })) {
-        return failure();
-      }
-    }
-
-    SmallVector<OpFoldResult> newLow, newHigh;
-    RankedTensorType collapsedType = reshapeOp.getSrcType();
-    RankedTensorType paddedType = padOp.getResultType();
-    SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape());
-    SmallVector<OpFoldResult> expandedPaddedSizes(
-        getMixedValues(reshapeOp.getStaticOutputShape(),
-                       reshapeOp.getOutputShape(), rewriter));
+    SmallVector<OpFoldResult> expandedPaddedSizes =
+        reshapeOp.getMixedOutputShape();
     AffineExpr d0, d1, d2;
     bindDims(rewriter.getContext(), d0, d1, d2);
     auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2});
     Location loc = reshapeOp->getLoc();
-    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-      OpFoldResult l = padOp.getMixedLowPad()[reInd[0]];
-      OpFoldResult h = padOp.getMixedHighPad()[reInd[0]];
+    for (auto [reInd, l, h] :
+         llvm::zip_equal(reassociations, collapsedPadding.lowPad,
+                         collapsedPadding.highPad)) {
       if (reInd.size() == 1) {
-        collapsedPaddedShape[idx] = paddedType.getShape()[reInd[0]];
-        OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply(
+        expandedPaddedSizes[reInd[0]] = affine::makeComposedFoldedAffineApply(
             rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]});
-        expandedPaddedSizes[reInd[0]] = paddedSize;
       }
-      newLow.push_back(l);
-      newHigh.push_back(h);
     }
 
     RankedTensorType collapsedPaddedType =
-        paddedType.clone(collapsedPaddedShape);
+        padOp.getType().clone(collapsedPadding.paddedShape);
     auto newPadOp = tensor::PadOp::create(
-        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
+        rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
         padOp.getConstantPaddingValue(), padOp.getNofold());
 
     rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
@@ -2000,6 +2133,54 @@ class FoldPadWithProducerReshapeOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+class FoldReshapeWithProducerPadOpByCollapsing
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+public:
+  FoldReshapeWithProducerPadOpByCollapsing(MLIRContext *context,
+                                           ControlFusionFn foldReshapes,
+                                           PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::CollapseShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    tensor::PadOp padOp = reshapeOp.getSrc().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+    if (!padOp->hasOneUse())
+      return failure();
+
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
+      return rewriter.notifyMatchFailure(padOp,
+                                         "fusion blocked by control function");
+    }
+
+    SmallVector<ReassociationIndices> reassociations =
+        reshapeOp.getReassociationIndices();
+    RankedTensorType collapsedPaddedType = reshapeOp.getResultType();
+    FailureOr<PadDimInfo> maybeCollapsedPadding =
+        computeCollapsedPadding(padOp, reassociations, rewriter);
+    if (failed(maybeCollapsedPadding))
+      return failure();
+    PadDimInfo collapsedPadding = maybeCollapsedPadding.value();
+
+    Location loc = reshapeOp->getLoc();
+    auto newCollapseOp = tensor::CollapseShapeOp::create(
+        rewriter, loc, padOp.getSource(), reassociations);
+
+    auto newPadOp = tensor::PadOp::create(
+        rewriter, loc, collapsedPaddedType, newCollapseOp.getResult(),
+        collapsedPadding.lowPad, collapsedPadding.highPad,
+        padOp.getConstantPaddingValue(), padOp.getNofold());
+
+    rewriter.replaceOp(reshapeOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlFusionFn controlFoldingReshapes;
+};
+
 /// Pattern to collapse dimensions.
 template <typename LinalgType>
 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> {
@@ -2239,6 +2420,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
                                                     controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                         controlFoldingReshapes);
+  patterns.add<FoldExpandShapeWithProducerPadOp>(patterns.getContext(),
+                                                 controlFoldingReshapes);
   patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
@@ -2250,6 +2433,8 @@ void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns(
                                                       controlFoldingReshapes);
   patterns.add<FoldPadWithProducerReshapeOpByCollapsing>(
       patterns.getContext(), controlFoldingReshapes);
+  patterns.add<FoldReshapeWithProducerPadOpByCollapsing>(
+      patterns.getContext(), controlFoldingReshapes);
   patterns.add<FoldReshapeWithGenericOpByCollapsing>(patterns.getContext(),
                                                      controlFoldingReshapes);
 }
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
index 2bf3d21c35526..923bb2ca9c28a 100644
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -639,6 +639,45 @@ func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
 // CHECK-SAME:       output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
 //      CHECK:   return %[[EXPAND]]
 
+// -----
+
+func.func @collapse_shape_with_producer_pad(%arg0: tensor<2x3x4x5x6x7x8x9xi32>) -> tensor<8x12x17x336x14xi32> {
+  %cst = arith.constant 0 : i32
+  %padded = tensor.pad %arg0 low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
+       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
+    tensor.yield %cst : i32
+  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
+  %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]]
+    : tensor<8x3x4x17x6x7x8x14xi32> into tensor<8x12x17x336x14xi32>
+  return %collapsed : tensor<8x12x17x336x14xi32>
+}
+//      CHECK: func @collapse_shape_with_producer_pad
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x4x5x6x7x8x9xi32>
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
+//      CHECK:   return %[[PAD]]
+
+// -----
+
+func.func @collapse_shape_with_producer_pad_dynamic(%arg0: tensor<?x?x?x?x?x?xf32>,
+    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %padded = tensor.pad %arg0 low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %cst : f32
+  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
+  %collapsed = tensor.collapse_shape %padded [[0], [1, 2], [3], [4, 5]]
+    : tensor<?x?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
+  return %collapsed : tensor<?x?x?x?xf32>
+}
+//      CHECK: func @collapse_shape_with_producer_pad_dynamic
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?x?x?xf32>
+// CHECK-SAME:   %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
+//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5]]
+//      CHECK:   %[[PAD:.+]] = tensor.pad %[[COLLAPSE]] low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
+//      CHECK:   return %[[PAD]]
+
 // -----
 // Static problem sizes. Checks all aspects of fusion by collapsing with bubbling up collapse shapes.
 #map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..f6572674d10e2 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -863,6 +863,47 @@ func.func @fuse_by_expanding_dynamic_pad(%arg0 : tensor<?x?x?x?x?x?xi32>, %l0: i
 
 // -----
 
+func.func @expand_shape_with_producer_pad(%arg0: tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
+  %cst = arith.constant 0 : i32
+  %padded = tensor.pad %arg0 low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index):
+    tensor.yield %cst : i32
+  } : tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
+  %expanded = tensor.expand_shape %padded [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [8, 3, 4, 17, 6, 7, 8, 14]
+    : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
+  return %expanded : tensor<8x3x4x17x6x7x8x14xi32>
+}
+//      CHECK: func @expand_shape_with_producer_pad
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>
+//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/168888


More information about the Mlir-commits mailing list