[Mlir-commits] [mlir] [mlir] Add missing pad reshape propagation patterns (PR #168888)

Ian Wood llvmlistbot at llvm.org
Thu Nov 20 11:24:40 PST 2025


================
@@ -1921,6 +2023,52 @@ struct FoldReshapeWithGenericOpByCollapsing
   ControlFusionFn controlFoldingReshapes;
 };
 
+/// Computes the collapsed padding information for the given pad operation based
+/// on the provided collapsed shape and reassociation indices. Returns a
+/// PadDimInfo containing the low and high padding amounts and the collapsed
+/// shape for each dimension, or failure if the collapse is not possible.
+static FailureOr<PadDimInfo>
+computeCollapsedPadding(tensor::PadOp padOp,
+                        ArrayRef<ReassociationIndices> reassociations,
+                        PatternRewriter &rewriter) {
+  ArrayRef<int64_t> low = padOp.getStaticLow();
+  ArrayRef<int64_t> high = padOp.getStaticHigh();
+
+  // Collapsed dimensions cannot have padding because this can produce strided
+  // padding that isn't representable by a tensor.pad op. There are some special
+  // cases where it is possible (like collapsing unit dims), but supporting
+  // these cases is NYI, so disallow it for now.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    for (int64_t dim : reInd) {
+      if ((low[dim] != 0 || high[dim] != 0) && reInd.size() != 1)
+        return failure();
+    }
+  }
+
+  // Initialize padding values for collapsed tensors with zeros
+  ArrayRef<int64_t> expandedPaddedShape = padOp.getType().getShape();
+  PadDimInfo padDimInfo;
+  padDimInfo.lowPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+  padDimInfo.highPad.assign(reassociations.size(), rewriter.getIndexAttr(0));
+
+  // Update padding for dimensions that are not being collapsed, and compute
+  // the collapsed padded shape.
+  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+    if (reInd.size() == 1) {
+      padDimInfo.lowPad[idx] = padOp.getMixedLowPad()[reInd[0]];
+      padDimInfo.highPad[idx] = padOp.getMixedHighPad()[reInd[0]];
----------------
IanWood1 wrote:

`getMixedLowPad()` and `getMixedHighPad()` both construct a new vector on every call. Can the results be stored in local variables before the loop instead of calling the methods on each iteration?
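
A minimal sketch of the hoisting idea, using a stand-in type rather than the real `tensor::PadOp` so that it compiles on its own. `FakePadOp`, `CollapsedPad`, `collapsePadding`, and the call counters are hypothetical names for illustration only, and plain `int64_t` values stand in for the actual padding values:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for tensor::PadOp. Each getter returns a freshly
// built vector by value and counts how often it was called.
struct FakePadOp {
  std::vector<int64_t> low, high;
  mutable int lowCalls = 0, highCalls = 0;
  std::vector<int64_t> getMixedLowPad() const { ++lowCalls; return low; }
  std::vector<int64_t> getMixedHighPad() const { ++highCalls; return high; }
};

// Hypothetical result type: one low/high padding entry per collapsed dim.
struct CollapsedPad {
  std::vector<int64_t> lowPad, highPad;
};

CollapsedPad
collapsePadding(const FakePadOp &padOp,
                const std::vector<std::vector<int64_t>> &reassociations) {
  // Hoisted: each vector is built once, before the loop.
  const std::vector<int64_t> mixedLow = padOp.getMixedLowPad();
  const std::vector<int64_t> mixedHigh = padOp.getMixedHighPad();

  CollapsedPad info;
  info.lowPad.assign(reassociations.size(), 0);
  info.highPad.assign(reassociations.size(), 0);

  for (std::size_t idx = 0; idx < reassociations.size(); ++idx) {
    const std::vector<int64_t> &reInd = reassociations[idx];
    // Only groups that are not actually collapsed carry padding over.
    if (reInd.size() == 1) {
      info.lowPad[idx] = mixedLow[reInd[0]];
      info.highPad[idx] = mixedHigh[reInd[0]];
    }
  }
  return info;
}
```

With this shape, each getter runs once per call to `collapsePadding` regardless of how many reassociation groups there are, instead of once per non-collapsed group.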

https://github.com/llvm/llvm-project/pull/168888

