[Mlir-commits] [mlir] [MLIR][Linalg] Modify `rewriteAsPaddedOp` to not remove pre-padded op (PR #163467)

James Newling llvmlistbot at llvm.org
Tue Oct 14 15:41:05 PDT 2025


https://github.com/newling created https://github.com/llvm/llvm-project/pull/163467

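Refactor/redesign `FailureOr<TilingInterface> rewriteAsPaddedOp(...)` so that it no longer removes the original (unpadded) operation. Instead, it returns the created pad ops, the padded clone, and the replacement values, and the caller decides when to replace the original operation.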

I previously found it difficult to work with this API (in IREE), as the original (pre-padded) operation was still useful for a while after its replacement was created. I believe @Groverkss also has a use case where he wants the pre-padded operation to stick around.

TODO(newling): rename the function, as it no longer performs a rewrite.

>From f3a357bf3dc2cbe51ecba3e8b622275f9121505a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 14 Oct 2025 15:38:03 -0700
Subject: [PATCH] refactor

Signed-off-by: James Newling <james.newling at gmail.com>
---
 .../Dialect/Linalg/Transforms/Transforms.h    |  28 ++---
 .../TransformOps/LinalgTransformOps.cpp       |  17 +--
 .../Linalg/Transforms/PadTilingInterface.cpp  | 100 +++++++++---------
 3 files changed, 74 insertions(+), 71 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085a1f7a8..98c1af43fe67a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 namespace bufferization {
@@ -621,35 +620,40 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
                    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
                    const PadTilingInterfaceOptions &options);
 
 using PadSizeComputationFunction =
     std::function<FailureOr<SmallVector<OpFoldResult>>(
-        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        OpBuilder &, OpOperand &, ArrayRef<Range>,
         const PadTilingInterfaceOptions &)>;
 
 /// Specific helper for Linalg ops.
 FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &rewriter, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
 
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
 /// * "options.paddingSizes" indicates that each padding dimension should be
 ///   padded to the specified padding size.
 /// * "options.padToMultipleOf" indicates that the paddingSizes should be
 //    interpreted as the bounding box (dynamic) value to pad to.
 /// * Use "options.paddingValues" to set the padding value of the created
 //    tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
 
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
-                  const PadTilingInterfaceOptions &constOptions,
-                  SmallVector<tensor::PadOp> &padOps,
-                  const PadSizeComputationFunction &computePaddingSizeFun =
+struct PadTilingInterfaceResult {
+  /// Padded operands of `toPad`.
+  SmallVector<tensor::PadOp> padOps;
+  /// Slices of the padded op that have the same shapes as `toPad` results.
+  SmallVector<Value> replacements;
+  /// The cloned and padded version of `toPad`.
+  TilingInterface paddedOp;
+};
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+                  PadTilingInterfaceOptions options,
+                  const PadSizeComputationFunction & =
                       &computeIndexingMapOpInterfacePaddedShape);
 
 namespace detail {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d8f983f98ae77..ce444f3db16e9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,27 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
     }
 
     // Set options.
-    TilingInterface paddedOp;
     PadTilingInterfaceOptions options;
     options.setPaddingValues(paddingValues)
         .setPaddingSizes(getMixedPaddingSizes())
         .setPadToMultipleOf(getPadToMultipleOf());
 
-    // Apply padding.
-    SmallVector<tensor::PadOp> newPadOps;
-    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
-        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
-        newPadOps);
-    if (failed(maybePaddedOp)) {
+    auto maybePadOps = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+    if (failed(maybePadOps)) {
       auto diag = emitSilenceableError() << "failed to pad op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
 
+    const auto &[newPadOps, replacementValues, newPaddedOp] = *maybePadOps;
+
     // Set transform results.
-    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
+    paddedOps.push_back(newPaddedOp);
     padOps.append(newPadOps.begin(), newPadOps.end());
+
+    // erase targetOp:
+    rewriter.replaceOp(targetOp.getOperation(), replacementValues);
   }
 
   results.set(cast<OpResult>(getPadded()), paddedOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d771394..8d865f39669c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
 /// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
-    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
-    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
-    const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+                           AffineMap indexingMap,
+                           ArrayRef<OpFoldResult> indexingSizes,
+                           const PadTilingInterfaceOptions &options) {
   Location loc = v.getLoc();
   SmallVector<OpFoldResult> paddedShape;
   auto tensorType = cast<RankedTensorType>(v.getType());
@@ -198,7 +199,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
 FailureOr<SmallVector<OpFoldResult>>
 linalg::computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &rewriter, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
   auto transferOp =
       llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -224,7 +225,7 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
 
 /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
 /// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
                         TypedValue<RankedTensorType> v,
                         ArrayRef<OpFoldResult> paddedShape,
                         Attribute paddingValueAttr) {
@@ -263,45 +264,44 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
                                paddingValue, /*nofold=*/false, dynDims);
 }
 
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
-    RewriterBase &rewriter, TilingInterface opToPad,
-    const PadTilingInterfaceOptions &constOptions,
-    SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+    OpBuilder &builder, TilingInterface toPad,
+    PadTilingInterfaceOptions options,
     const PadSizeComputationFunction &computePaddingSizeFun) {
-  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+  SmallVector<tensor::PadOp> padOps;
+  Location loc = toPad.getLoc();
 
-  Location loc = opToPad.getLoc();
-  PadTilingInterfaceOptions options(constOptions);
   // Allow inference of pad values if they are not explicitly specified.
   // TODO: be mindful about the value depending on the actual operation.
   if (options.paddingValues.empty()) {
-    SmallVector<Type> types(opToPad->getOperandTypes());
-    llvm::append_range(types, opToPad->getResultTypes());
+    SmallVector<Type> types(toPad->getOperandTypes());
+    llvm::append_range(types, toPad->getResultTypes());
     for (Type t : types) {
       options.paddingValues.push_back(
-          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+          builder.getZeroAttr(getElementTypeOrSelf(t)));
     }
   }
 
-  if (llvm::any_of(opToPad->getOperands(),
+  if (llvm::any_of(toPad->getOperands(),
                    [](Value v) { return isa<MemRefType>(v.getType()); })) {
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "expected operation on tensors");
+    LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+    return failure();
   }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // Set IP after opToPad because we also take the dims of opToPad's output.
-  rewriter.setInsertionPointAfter(opToPad);
+  OpBuilder::InsertionGuard g(builder);
+  // Set IP after toPad because we also take the dims of toPad's output.
+  builder.setInsertionPointAfter(toPad);
 
   // 1. Get the loopUpperBounds from the TilingInterface.
-  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+  SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
 
   // 2. For each operand.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad->getNumOperands());
-  for (OpOperand &opOperand : opToPad->getOpOperands()) {
+  newOperands.reserve(toPad->getNumOperands());
+  for (OpOperand &opOperand : toPad->getOpOperands()) {
     Value operand = opOperand.get();
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+    LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
 
     // 2.a. Skip scalar-like operands.
     Type operandType = operand.getType();
@@ -311,27 +311,29 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
       newOperands.push_back(operand);
       continue;
     }
+
     // 2.a. Compute padded shape.
     FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
-        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+        computePaddingSizeFun(builder, opOperand, iterationDomain, options);
     if (failed(maybePaddedShape)) {
-      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+      LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+      return failure();
     }
 
     // 2.b. Expect proper `paddingValues`.
     // TODO: we may want to allow garbage padding in the future, in which case
     // we would just not assert.
     if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "--no padding value specified");
+      LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+      return failure();
     }
     Attribute paddingValueAttr =
         options.paddingValues[opOperand.getOperandNumber()];
 
     // 2.c. Perform actual padding.
-    Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
-        *maybePaddedShape, paddingValueAttr);
+    Value paddedOperand =
+        padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+                   *maybePaddedShape, paddingValueAttr);
     LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
 
     // 2.d. Perform actual padding.
@@ -342,38 +344,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
 
   // 3. Form the resulting tensor::ExtractSliceOp.
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
-    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "failed to reify result shapes");
+  if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+    LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+    return failure();
   }
-  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+  assert(reifiedResultShapes.size() == toPad->getNumResults() &&
          "expected same number of results");
 
-  // Clone `opToPad` to operate on the statically padded shapes.
+  // Clone `toPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
-  // clone **should** properly notify the rewriter.
+      ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+  // clone **should** properly notify the builder.
   TilingInterface paddedOp =
-      clone(rewriter, opToPad, resultTensorTypes, newOperands);
+      clone(builder, toPad, resultTensorTypes, newOperands);
   LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
 
-  // Recover the slice out of the new static results. This keeps the original
-  // opToPad around because it uses the dims of the original results.
+  // Recover the slice out of the new static results.
   SmallVector<Value> paddedSubtensorResults;
-  paddedSubtensorResults.reserve(opToPad->getNumResults());
+  paddedSubtensorResults.reserve(toPad->getNumResults());
   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
     int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
     paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
-        rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
 
-  rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
-  return paddedOp;
+  return PadTilingInterfaceResult{padOps, paddedSubtensorResults, paddedOp};
 }



More information about the Mlir-commits mailing list