[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)

Ian Wood llvmlistbot at llvm.org
Thu Nov 20 11:24:40 PST 2025


================
@@ -1038,6 +1038,54 @@ class FoldWithProducerReshapeOpByExpansion
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Carries information about a padded dimension.
+struct PadDimInfo {
+  // The resulting shape after padding each dimension.
+  SmallVector<int64_t> paddedShape;
+
+  // Low and high padding amounts for each dimension.
+  SmallVector<OpFoldResult> lowPad;
+  SmallVector<OpFoldResult> highPad;
+};
+
+/// Computes the expanded padding information for the given pad operation based
+/// on the provided expanded shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the padded
+/// size for each dimension, or failure if the expansion is not possible.
+static FailureOr<PadDimInfo>
+computeExpandedPadding(tensor::PadOp padOp, ArrayRef<int64_t> expandedShape,
+                       ArrayRef<ReassociationIndices> reassociations,
+                       PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Expanded dimensions cannot have padding because the resulting padding may
+  // not be representable by a tensor.pad op. There are some special cases where
+  // it is possible (like expanding unit dims), but supporting these cases is
+  // NYI, so disallow it for now.
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
+    if (reInd.size() != 1 && (l != 0 || h != 0))
+      return failure();
+  }
+
+  SmallVector<OpFoldResult> mixedLowPad(padOp.getMixedLowPad());
+  SmallVector<OpFoldResult> mixedHighPad(padOp.getMixedHighPad());
+  ArrayRef<int64_t> paddedShape = padOp.getResultType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.paddedShape.assign(expandedShape);
+  padDimInfo.lowPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(expandedShape.size(), rewriter.getIndexAttr(0));
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.paddedShape[reInd[0]] = paddedShape[idx];
+      padDimInfo.lowPad[reInd[0]] = mixedLowPad[idx];
+      padDimInfo.highPad[reInd[0]] = mixedHighPad[idx];
+    }
+  }
+
+  return padDimInfo;
+}
+
 class FoldPadWithProducerReshapeOpByExpansion
----------------
IanWood1 wrote:

All four of these patterns should bail when `padOp.getConstantPaddingValue()` is null. It should be possible to handle non-constant padding values in a way similar to what linalg ops do for `linalg.index`.
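
To make the bail-out conditions concrete, here is a minimal standalone sketch that models the logic of `computeExpandedPadding` with plain integers instead of MLIR types. `ExpandedPad`, `expandPadding`, and the `hasConstantPaddingValue` flag are hypothetical names used only for illustration; they are not part of the PR. Folding the constant-padding check into the helper is also an assumption, since in the real patterns that check would presumably sit at the call sites.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for PadDimInfo: plain integers replace OpFoldResult.
struct ExpandedPad {
  std::vector<int64_t> paddedShape;
  std::vector<int64_t> lowPad;
  std::vector<int64_t> highPad;
};

// Hypothetical model of computeExpandedPadding. `hasConstantPaddingValue`
// stands in for `padOp.getConstantPaddingValue()` being non-null.
std::optional<ExpandedPad>
expandPadding(const std::vector<int64_t> &low,
              const std::vector<int64_t> &high,
              const std::vector<int64_t> &paddedShape,
              const std::vector<int64_t> &expandedShape,
              const std::vector<std::vector<int64_t>> &reassociations,
              bool hasConstantPaddingValue) {
  // Mirrors the size check implied by llvm::zip_equal in the patch.
  assert(low.size() == reassociations.size());
  assert(high.size() == reassociations.size());
  assert(paddedShape.size() == reassociations.size());

  // Suggested bail-out: a non-constant padding value is not supported.
  if (!hasConstantPaddingValue)
    return std::nullopt;

  // A dimension that is split into several dimensions must not be padded.
  for (std::size_t i = 0; i < reassociations.size(); ++i) {
    if (reassociations[i].size() != 1 && (low[i] != 0 || high[i] != 0))
      return std::nullopt;
  }

  // Start from the expanded (unpadded) shape with zero padding everywhere,
  // then copy padding over for dimensions that map one-to-one.
  ExpandedPad info;
  info.paddedShape = expandedShape;
  info.lowPad.assign(expandedShape.size(), 0);
  info.highPad.assign(expandedShape.size(), 0);
  for (std::size_t i = 0; i < reassociations.size(); ++i) {
    if (reassociations[i].size() == 1) {
      const int64_t d = reassociations[i][0];
      info.paddedShape[d] = paddedShape[i];
      info.lowPad[d] = low[i];
      info.highPad[d] = high[i];
    }
  }
  return info;
}
```

For example, with a 6x10 source padded by low `[0, 1]` and high `[0, 2]` (result 6x13), reassociations `[[0, 1], [2]]`, and expanded shape `[2, 3, 10]`, the sketch yields a padded shape of `[2, 3, 13]` with the padding moved to the last expanded dimension. Padding the first (expanded) dimension, or passing `hasConstantPaddingValue = false`, makes it return `std::nullopt`.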

https://github.com/llvm/llvm-project/pull/168888

