[Mlir-commits] [mlir] [MLIR][Linalg] Modify `rewriteAsPaddedOp` to not remove pre-padded op (PR #163467)
James Newling
llvmlistbot at llvm.org
Tue Oct 14 15:41:05 PDT 2025
https://github.com/newling created https://github.com/llvm/llvm-project/pull/163467
Refactor/redesign `FailureOr<TilingInterface> rewriteAsPaddedOp(...)` so that it no longer removes the unpadded operation.
I previously found this API difficult to work with (in IREE), as the original (pre-padded) operation was still useful for a while after its replacement was created. I believe @Groverkss also has a use case where he wants the pre-padded op to stick around.
TODO(newling): rename the function, as it no longer performs a rewrite.
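For illustration, a minimal sketch of the new calling convention, mirroring the updated `LinalgTransformOps.cpp` code in the patch below (`toPad`, `options`, and `rewriter` are placeholder names; `rewriter` is assumed to be a `RewriterBase` so that `replaceOp` is available):

```cpp
// The callee now only creates new IR; `toPad` itself is left untouched.
FailureOr<linalg::PadTilingInterfaceResult> maybePadded =
    linalg::rewriteAsPaddedOp(rewriter, toPad, options);
if (failed(maybePadded))
  return failure();
const auto &[padOps, replacements, paddedOp] = *maybePadded;

// The pre-padded `toPad` is still in the IR at this point, so callers can
// keep inspecting/using it as long as they need. To recover the previous
// behavior, replace (and thereby erase) it explicitly:
rewriter.replaceOp(toPad, replacements);
```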
From f3a357bf3dc2cbe51ecba3e8b622275f9121505a Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 14 Oct 2025 15:38:03 -0700
Subject: [PATCH] refactor
Signed-off-by: James Newling <james.newling at gmail.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 28 ++---
.../TransformOps/LinalgTransformOps.cpp | 17 +--
.../Linalg/Transforms/PadTilingInterface.cpp | 100 +++++++++---------
3 files changed, 74 insertions(+), 71 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085a1f7a8..98c1af43fe67a 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
@@ -621,35 +620,40 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);
using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
- RewriterBase &, OpOperand &, ArrayRef<Range>,
+ OpBuilder &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
/// * "options.paddingSizes" indicates that each padding dimension should be
/// padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
// interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
// tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- const PadSizeComputationFunction &computePaddingSizeFun =
+struct PadTilingInterfaceResult {
+ /// Padded operands of `toPad`.
+ SmallVector<tensor::PadOp> padOps;
+ /// Slices of the padded op that have the same shapes as `toPad` results.
+ SmallVector<Value> replacements;
+ /// The cloned and padded version of `toPad`.
+ TilingInterface paddedOp;
+};
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
+ const PadSizeComputationFunction & =
&computeIndexingMapOpInterfacePaddedShape);
namespace detail {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d8f983f98ae77..ce444f3db16e9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,27 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
}
// Set options.
- TilingInterface paddedOp;
PadTilingInterfaceOptions options;
options.setPaddingValues(paddingValues)
.setPaddingSizes(getMixedPaddingSizes())
.setPadToMultipleOf(getPadToMultipleOf());
- // Apply padding.
- SmallVector<tensor::PadOp> newPadOps;
- FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
- rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
- newPadOps);
- if (failed(maybePaddedOp)) {
+ auto maybePadOps = rewriteAsPaddedOp(
+ rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+ if (failed(maybePadOps)) {
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
+ const auto &[newPadOps, replacementValues, newPaddedOp] = *maybePadOps;
+
// Set transform results.
- paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
+ paddedOps.push_back(newPaddedOp);
padOps.append(newPadOps.begin(), newPadOps.end());
+
+ // Replace targetOp with the padded results (this erases targetOp).
+ rewriter.replaceOp(targetOp.getOperation(), replacementValues);
}
results.set(cast<OpResult>(getPadded()), paddedOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d771394..8d865f39669c5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
- RewriterBase &rewriter, TypedValue<RankedTensorType> v,
- AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
- const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> indexingSizes,
+ const PadTilingInterfaceOptions &options) {
Location loc = v.getLoc();
SmallVector<OpFoldResult> paddedShape;
auto tensorType = cast<RankedTensorType>(v.getType());
@@ -198,7 +199,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
auto transferOp =
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -224,7 +225,7 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
/// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
TypedValue<RankedTensorType> v,
ArrayRef<OpFoldResult> paddedShape,
Attribute paddingValueAttr) {
@@ -263,45 +264,44 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
- RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+ OpBuilder &builder, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
const PadSizeComputationFunction &computePaddingSizeFun) {
- LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+ LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+ SmallVector<tensor::PadOp> padOps;
+ Location loc = toPad.getLoc();
- Location loc = opToPad.getLoc();
- PadTilingInterfaceOptions options(constOptions);
// Allow inference of pad values if they are not explicitly specified.
// TODO: be mindful about the value depending on the actual operation.
if (options.paddingValues.empty()) {
- SmallVector<Type> types(opToPad->getOperandTypes());
- llvm::append_range(types, opToPad->getResultTypes());
+ SmallVector<Type> types(toPad->getOperandTypes());
+ llvm::append_range(types, toPad->getResultTypes());
for (Type t : types) {
options.paddingValues.push_back(
- rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+ builder.getZeroAttr(getElementTypeOrSelf(t)));
}
}
- if (llvm::any_of(opToPad->getOperands(),
+ if (llvm::any_of(toPad->getOperands(),
[](Value v) { return isa<MemRefType>(v.getType()); })) {
- return rewriter.notifyMatchFailure(opToPad,
- "expected operation on tensors");
+ LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+ return failure();
}
- OpBuilder::InsertionGuard g(rewriter);
- // Set IP after opToPad because we also take the dims of opToPad's output.
- rewriter.setInsertionPointAfter(opToPad);
+ OpBuilder::InsertionGuard g(builder);
+ // Set IP after toPad because we also take the dims of toPad's output.
+ builder.setInsertionPointAfter(toPad);
// 1. Get the loopUpperBounds from the TilingInterface.
- SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+ SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
// 2. For each operand.
SmallVector<Value> newOperands;
- newOperands.reserve(opToPad->getNumOperands());
- for (OpOperand &opOperand : opToPad->getOpOperands()) {
+ newOperands.reserve(toPad->getNumOperands());
+ for (OpOperand &opOperand : toPad->getOpOperands()) {
Value operand = opOperand.get();
- LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+ LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
// 2.a. Skip scalar-like operands.
Type operandType = operand.getType();
@@ -311,27 +311,29 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
newOperands.push_back(operand);
continue;
}
+
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
- computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+ computePaddingSizeFun(builder, opOperand, iterationDomain, options);
if (failed(maybePaddedShape)) {
- return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+ LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+ return failure();
}
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
- return rewriter.notifyMatchFailure(opToPad,
- "--no padding value specified");
+ LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+ return failure();
}
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
- Value paddedOperand = padOperand(
- rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
- *maybePaddedShape, paddingValueAttr);
+ Value paddedOperand =
+ padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+ *maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
// 2.d. Perform actual padding.
@@ -342,38 +344,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
// 3. Form the resulting tensor::ExtractSliceOp.
ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
- LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
- return rewriter.notifyMatchFailure(opToPad,
- "failed to reify result shapes");
+ if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+ LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+ return failure();
}
- assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+ assert(reifiedResultShapes.size() == toPad->getNumResults() &&
"expected same number of results");
- // Clone `opToPad` to operate on the statically padded shapes.
+ // Clone `toPad` to operate on the statically padded shapes.
auto resultTensorTypes =
- ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
- // clone **should** properly notify the rewriter.
+ ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+ // clone **should** properly notify the builder.
TilingInterface paddedOp =
- clone(rewriter, opToPad, resultTensorTypes, newOperands);
+ clone(builder, toPad, resultTensorTypes, newOperands);
LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
- // Recover the slice out of the new static results. This keeps the original
- // opToPad around because it uses the dims of the original results.
+ // Recover the slice out of the new static results.
SmallVector<Value> paddedSubtensorResults;
- paddedSubtensorResults.reserve(opToPad->getNumResults());
+ paddedSubtensorResults.reserve(toPad->getNumResults());
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
- rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
- rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
- return paddedOp;
+ return PadTilingInterfaceResult{padOps, paddedSubtensorResults, paddedOp};
}