[Mlir-commits] [mlir] 1508a8e - [MLIR][Linalg] Modify `rewriteAsPaddedOp` to not remove pre-padded op (#163467)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 17 16:46:04 PDT 2025


Author: James Newling
Date: 2025-10-17T16:46:00-07:00
New Revision: 1508a8ec8d62ab1e9bdc8b7e0dbaaec9075b631f

URL: https://github.com/llvm/llvm-project/commit/1508a8ec8d62ab1e9bdc8b7e0dbaaec9075b631f
DIFF: https://github.com/llvm/llvm-project/commit/1508a8ec8d62ab1e9bdc8b7e0dbaaec9075b631f.diff

LOG: [MLIR][Linalg] Modify `rewriteAsPaddedOp` to not remove pre-padded op (#163467)

Refactor/redesign `FailureOr<TilingInterface> rewriteAsPaddedOp(...)` so that
it no longer erases the unpadded operation. This is more in line with how
other transformations such as tiling work, where the user of the
transformation decides when to replace the original operation. Instead of
replacing the op in place, return all created ops and values in a struct.

---------

Signed-off-by: James Newling <james.newling at gmail.com>
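
For illustration, a minimal sketch of how a caller might use the new API
(based on the signatures and the transform-op change in the diff below;
`op`, `rewriter`, and `options` are placeholder names, not part of the patch):

    // Sketch only: the caller now receives a PadTilingInterfaceResult and
    // decides itself if/when to replace the pre-padded op.
    FailureOr<PadTilingInterfaceResult> maybePadded =
        rewriteAsPaddedOp(rewriter, cast<TilingInterface>(op), options);
    if (failed(maybePadded))
      return failure();
    const auto &[padOps, paddedOp, replacements] = *maybePadded;
    // The original op is still in the IR at this point; replace it explicitly.
    rewriter.replaceOp(op, replacements);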

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085a1f7a8..c89fc59c91830 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 namespace bufferization {
@@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
                    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
                    const PadTilingInterfaceOptions &options);
 
 using PadSizeComputationFunction =
     std::function<FailureOr<SmallVector<OpFoldResult>>(
-        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        OpBuilder &, OpOperand &, ArrayRef<Range>,
         const PadTilingInterfaceOptions &)>;
 
 /// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
-    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>>
+computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
+                                         ArrayRef<Range> iterationDomain,
+                                         const PadTilingInterfaceOptions &);
+
+/// Operations and values created in the process of padding a TilingInterface
+/// operation.
+struct PadTilingInterfaceResult {
+  /// The operands of the padded op.
+  SmallVector<tensor::PadOp> padOps;
+  /// The padded op, a clone of `toPad` with padded operands.
+  TilingInterface paddedOp;
+  /// Slices of the padded op's results, same types as `toPad`.
+  SmallVector<Value> replacements;
+};
 
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
 /// * "options.paddingSizes" indicates that each padding dimension should be
 ///   padded to the specified padding size.
 /// * "options.padToMultipleOf" indicates that the paddingSizes should be
 //    interpreted as the bounding box (dynamic) value to pad to.
 /// * Use "options.paddingValues" to set the padding value of the created
 //    tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
-                  const PadTilingInterfaceOptions &constOptions,
-                  SmallVector<tensor::PadOp> &padOps,
-                  const PadSizeComputationFunction &computePaddingSizeFun =
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+                  PadTilingInterfaceOptions options,
+                  const PadSizeComputationFunction & =
                       &computeIndexingMapOpInterfacePaddedShape);
 
 namespace detail {

diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6192d791f87aa..9a8a63e54d02d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
     }
 
     // Set options.
-    TilingInterface paddedOp;
     PadTilingInterfaceOptions options;
     options.setPaddingValues(paddingValues)
         .setPaddingSizes(getMixedPaddingSizes())
         .setPadToMultipleOf(getPadToMultipleOf());
 
-    // Apply padding.
-    SmallVector<tensor::PadOp> newPadOps;
-    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
-        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
-        newPadOps);
-    if (failed(maybePaddedOp)) {
+    auto maybePadOps = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+    if (failed(maybePadOps)) {
       auto diag = emitSilenceableError() << "failed to pad op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
+    const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
 
     // Set transform results.
-    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
-    padOps.append(newPadOps.begin(), newPadOps.end());
+    paddedOps.push_back(paddedOp);
+    padOps.append(paddedOperands.begin(), paddedOperands.end());
+    rewriter.replaceOp(targetOp.getOperation(), slicedResults);
   }
 
   results.set(cast<OpResult>(getPadded()), paddedOps);

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d771394..3e787a2ad0ef5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
 /// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
-    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
-    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
-    const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v,
+                           AffineMap indexingMap,
+                           ArrayRef<OpFoldResult> indexingSizes,
+                           const PadTilingInterfaceOptions &options) {
   Location loc = v.getLoc();
   SmallVector<OpFoldResult> paddedShape;
   auto tensorType = cast<RankedTensorType>(v.getType());
@@ -109,7 +110,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
   // "Full-rank" padding specification.
   SmallVector<OpFoldResult> paddingSizes =
-      getFullRankPaddingSizes(rewriter, indexingSizes, options);
+      getFullRankPaddingSizes(builder, indexingSizes, options);
 
   // For each dimension in the operand's shape, iterate over indexingSizes and
   // add the various term contributions.
@@ -147,28 +148,27 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
       OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
-        bindDims(rewriter.getContext(), d0);
-        bindSymbols(rewriter.getContext(), s0);
+        bindDims(builder.getContext(), d0);
+        bindSymbols(builder.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
         paddingDimOfr = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, composedMap,
-            {indexingSizes[paddingDim], paddingSize},
+            builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
       } else {
         // Otherwise just set to paddingSize.
         paddingDimOfr = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, projectedMap, paddingSize);
+            builder, loc, projectedMap, paddingSize);
       }
 
       // Adjust for the maximum accessed index, which is (paddingSize - 1) *
       // multiplier.
       AffineExpr d0;
-      bindDims(rewriter.getContext(), d0);
+      bindDims(builder.getContext(), d0);
       int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
       AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
       OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, subtractMap, {paddingDimOfr});
+          builder, loc, subtractMap, {paddingDimOfr});
       terms.push_back(maxAccessIdx);
 
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
@@ -177,19 +177,19 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     // If there are no terms, just return the dim.
     if (terms.empty()) {
       paddedShape[resultIndex] =
-          createFoldedDimOp(rewriter, loc, v, resultIndex);
+          createFoldedDimOp(builder, loc, v, resultIndex);
       continue;
     }
 
     // Sum individual terms' contributions.
     SmallVector<AffineExpr> dims(terms.size());
-    bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+    bindDimsList(builder.getContext(), MutableArrayRef{dims});
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
     // Add 1 to the maximum accessed index and get the final padded size.
-    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, sumExpr + 1, terms);
+    OpFoldResult paddedDimOfr =
+        affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms);
     paddedShape[resultIndex] = paddedDimOfr;
   }
 
@@ -198,7 +198,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
 FailureOr<SmallVector<OpFoldResult>>
 linalg::computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &builder, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
   auto transferOp =
       llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -206,9 +206,9 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
     return failure();
 
   // clang-format off
-  assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
-    return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
-    r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+  assert(llvm::all_of(iterationDomain, [&builder](Range r) {
+    return r.offset == OpFoldResult(builder.getIndexAttr(0)) &&
+    r.stride == OpFoldResult(builder.getIndexAttr(1));
   }) && "expected 0-offset 1-stride loop ranges");
   // clang-format on
   SmallVector<OpFoldResult> loopUpperBounds;
@@ -218,13 +218,13 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
 
   AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
   return computePaddedShape(
-      rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+      builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
       indexingMap, loopUpperBounds, options);
 }
 
 /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
 /// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &builder, TilingInterface opToPad,
                         TypedValue<RankedTensorType> v,
                         ArrayRef<OpFoldResult> paddedShape,
                         Attribute paddingValueAttr) {
@@ -232,15 +232,15 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
     if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
-      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+      paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(),
                                                  complexTy, complexAttr);
     }
   } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
-    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+    paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(),
                                         getElementTypeOrSelf(v.getType()));
   } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
     paddingValue =
-        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
+        arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr);
   }
   assert(paddingValue && "failed to create value from padding attribute");
 
@@ -259,49 +259,48 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
       RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                     << paddedTensorType);
-  return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+  return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v,
                                paddingValue, /*nofold=*/false, dynDims);
 }
 
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
-    RewriterBase &rewriter, TilingInterface opToPad,
-    const PadTilingInterfaceOptions &constOptions,
-    SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+    OpBuilder &builder, TilingInterface toPad,
+    PadTilingInterfaceOptions options,
     const PadSizeComputationFunction &computePaddingSizeFun) {
-  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+  SmallVector<tensor::PadOp> padOps;
+  Location loc = toPad.getLoc();
 
-  Location loc = opToPad.getLoc();
-  PadTilingInterfaceOptions options(constOptions);
   // Allow inference of pad values if they are not explicitly specified.
   // TODO: be mindful about the value depending on the actual operation.
   if (options.paddingValues.empty()) {
-    SmallVector<Type> types(opToPad->getOperandTypes());
-    llvm::append_range(types, opToPad->getResultTypes());
+    SmallVector<Type> types(toPad->getOperandTypes());
+    llvm::append_range(types, toPad->getResultTypes());
     for (Type t : types) {
       options.paddingValues.push_back(
-          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+          builder.getZeroAttr(getElementTypeOrSelf(t)));
     }
   }
 
-  if (llvm::any_of(opToPad->getOperands(),
+  if (llvm::any_of(toPad->getOperands(),
                    [](Value v) { return isa<MemRefType>(v.getType()); })) {
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "expected operation on tensors");
+    LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+    return failure();
   }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // Set IP after opToPad because we also take the dims of opToPad's output.
-  rewriter.setInsertionPointAfter(opToPad);
+  OpBuilder::InsertionGuard g(builder);
+  // Set IP after toPad because we also take the dims of toPad's output.
+  builder.setInsertionPointAfter(toPad);
 
   // 1. Get the loopUpperBounds from the TilingInterface.
-  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+  SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
 
   // 2. For each operand.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad->getNumOperands());
-  for (OpOperand &opOperand : opToPad->getOpOperands()) {
+  newOperands.reserve(toPad->getNumOperands());
+  for (OpOperand &opOperand : toPad->getOpOperands()) {
     Value operand = opOperand.get();
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+    LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
 
     // 2.a. Skip scalar-like operands.
     Type operandType = operand.getType();
@@ -311,30 +310,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
       newOperands.push_back(operand);
       continue;
     }
+
     // 2.a. Compute padded shape.
     FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
-        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+        computePaddingSizeFun(builder, opOperand, iterationDomain, options);
     if (failed(maybePaddedShape)) {
-      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+      LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+      return failure();
     }
 
     // 2.b. Expect proper `paddingValues`.
     // TODO: we may want to allow garbage padding in the future, in which case
     // we would just not assert.
     if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "--no padding value specified");
+      LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+      return failure();
     }
     Attribute paddingValueAttr =
         options.paddingValues[opOperand.getOperandNumber()];
 
     // 2.c. Perform actual padding.
-    Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
-        *maybePaddedShape, paddingValueAttr);
+    Value paddedOperand =
+        padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+                   *maybePaddedShape, paddingValueAttr);
     LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
 
-    // 2.d. Perform actual padding.
     newOperands.push_back(paddedOperand);
     if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
       padOps.push_back(padOp);
@@ -342,38 +342,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
 
   // 3. Form the resulting tensor::ExtractSliceOp.
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
-    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "failed to reify result shapes");
+  if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+    LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+    return failure();
   }
-  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+  assert(reifiedResultShapes.size() == toPad->getNumResults() &&
          "expected same number of results");
 
-  // Clone `opToPad` to operate on the statically padded shapes.
+  // Clone `toPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
-  // clone **should** properly notify the rewriter.
+      ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+  // clone **should** properly notify the builder.
   TilingInterface paddedOp =
-      clone(rewriter, opToPad, resultTensorTypes, newOperands);
+      clone(builder, toPad, resultTensorTypes, newOperands);
   LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
 
-  // Recover the slice out of the new static results. This keeps the original
-  // opToPad around because it uses the dims of the original results.
+  // Recover the slice out of the new static results.
   SmallVector<Value> paddedSubtensorResults;
-  paddedSubtensorResults.reserve(opToPad->getNumResults());
+  paddedSubtensorResults.reserve(toPad->getNumResults());
   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
     int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
     paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
-        rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
 
-  rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
-  return paddedOp;
+  return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults};
 }

More information about the Mlir-commits mailing list