[Mlir-commits] [mlir] [MLIR][Linalg] Modify `rewriteAsPaddedOp` to not remove pre-padded op (PR #163467)
James Newling
llvmlistbot at llvm.org
Fri Oct 17 16:32:53 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/163467
>From d7c6c3eb448b80d850219e00abf3d7329dd21215 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Tue, 14 Oct 2025 15:38:03 -0700
Subject: [PATCH 1/2] Refactor to return more and do less
Signed-off-by: James Newling <james.newling at gmail.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 29 ++---
.../TransformOps/LinalgTransformOps.cpp | 19 ++--
.../Linalg/Transforms/PadTilingInterface.cpp | 101 +++++++++---------
3 files changed, 75 insertions(+), 74 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085a1f7a8..db75379cc21c0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
@@ -621,35 +620,39 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);
using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
- RewriterBase &, OpOperand &, ArrayRef<Range>,
+ OpBuilder &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
/// * "options.paddingSizes" indicates that each padding dimension should be
/// padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
// interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
// tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- const PadSizeComputationFunction &computePaddingSizeFun =
+struct PadTilingInterfaceResult {
+ /// The operands of the padded op.
+ SmallVector<tensor::PadOp> padOps;
+ /// The padded op, a clone of `toPad` with padded operands.
+ TilingInterface paddedOp;
+ /// Slices of the padded op's results, same types as `toPad`.
+ SmallVector<Value> replacements;
+};
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
+ const PadSizeComputationFunction & =
&computeIndexingMapOpInterfacePaddedShape);
namespace detail {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6192d791f87aa..a4d1d11e85633 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,27 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
}
// Set options.
- TilingInterface paddedOp;
PadTilingInterfaceOptions options;
options.setPaddingValues(paddingValues)
.setPaddingSizes(getMixedPaddingSizes())
.setPadToMultipleOf(getPadToMultipleOf());
- // Apply padding.
- SmallVector<tensor::PadOp> newPadOps;
- FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
- rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
- newPadOps);
- if (failed(maybePaddedOp)) {
+ auto maybePadOps = rewriteAsPaddedOp(
+ rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+ if (failed(maybePadOps)) {
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
+ const auto &[paddedOperands, paddedOp, slicedResults] = *maybePadOps;
+
// Set transform results.
- paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
- padOps.append(newPadOps.begin(), newPadOps.end());
+ paddedOps.push_back(paddedOp);
+ padOps.append(paddedOperands.begin(), paddedOperands.end());
+
+ // erase targetOp:
+ rewriter.replaceOp(targetOp.getOperation(), slicedResults);
}
results.set(cast<OpResult>(getPadded()), paddedOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d771394..513ce2c52ec87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
- RewriterBase &rewriter, TypedValue<RankedTensorType> v,
- AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
- const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> indexingSizes,
+ const PadTilingInterfaceOptions &options) {
Location loc = v.getLoc();
SmallVector<OpFoldResult> paddedShape;
auto tensorType = cast<RankedTensorType>(v.getType());
@@ -198,7 +199,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
auto transferOp =
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -224,7 +225,7 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
/// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
TypedValue<RankedTensorType> v,
ArrayRef<OpFoldResult> paddedShape,
Attribute paddingValueAttr) {
@@ -263,45 +264,44 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
- RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+ OpBuilder &builder, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
const PadSizeComputationFunction &computePaddingSizeFun) {
- LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+ LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+ SmallVector<tensor::PadOp> padOps;
+ Location loc = toPad.getLoc();
- Location loc = opToPad.getLoc();
- PadTilingInterfaceOptions options(constOptions);
// Allow inference of pad values if they are not explicitly specified.
// TODO: be mindful about the value depending on the actual operation.
if (options.paddingValues.empty()) {
- SmallVector<Type> types(opToPad->getOperandTypes());
- llvm::append_range(types, opToPad->getResultTypes());
+ SmallVector<Type> types(toPad->getOperandTypes());
+ llvm::append_range(types, toPad->getResultTypes());
for (Type t : types) {
options.paddingValues.push_back(
- rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+ builder.getZeroAttr(getElementTypeOrSelf(t)));
}
}
- if (llvm::any_of(opToPad->getOperands(),
+ if (llvm::any_of(toPad->getOperands(),
[](Value v) { return isa<MemRefType>(v.getType()); })) {
- return rewriter.notifyMatchFailure(opToPad,
- "expected operation on tensors");
+ LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+ return failure();
}
- OpBuilder::InsertionGuard g(rewriter);
- // Set IP after opToPad because we also take the dims of opToPad's output.
- rewriter.setInsertionPointAfter(opToPad);
+ OpBuilder::InsertionGuard g(builder);
+ // Set IP after toPad because we also take the dims of toPad's output.
+ builder.setInsertionPointAfter(toPad);
// 1. Get the loopUpperBounds from the TilingInterface.
- SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+ SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
// 2. For each operand.
SmallVector<Value> newOperands;
- newOperands.reserve(opToPad->getNumOperands());
- for (OpOperand &opOperand : opToPad->getOpOperands()) {
+ newOperands.reserve(toPad->getNumOperands());
+ for (OpOperand &opOperand : toPad->getOpOperands()) {
Value operand = opOperand.get();
- LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+ LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
// 2.a. Skip scalar-like operands.
Type operandType = operand.getType();
@@ -311,30 +311,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
newOperands.push_back(operand);
continue;
}
+
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
- computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+ computePaddingSizeFun(builder, opOperand, iterationDomain, options);
if (failed(maybePaddedShape)) {
- return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+ LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+ return failure();
}
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
- return rewriter.notifyMatchFailure(opToPad,
- "--no padding value specified");
+ LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+ return failure();
}
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
- Value paddedOperand = padOperand(
- rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
- *maybePaddedShape, paddingValueAttr);
+ Value paddedOperand =
+ padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+ *maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
- // 2.d. Perform actual padding.
newOperands.push_back(paddedOperand);
if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
padOps.push_back(padOp);
@@ -342,38 +343,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
// 3. Form the resulting tensor::ExtractSliceOp.
ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
- LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
- return rewriter.notifyMatchFailure(opToPad,
- "failed to reify result shapes");
+ if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+ LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+ return failure();
}
- assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+ assert(reifiedResultShapes.size() == toPad->getNumResults() &&
"expected same number of results");
- // Clone `opToPad` to operate on the statically padded shapes.
+ // Clone `toPad` to operate on the statically padded shapes.
auto resultTensorTypes =
- ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
- // clone **should** properly notify the rewriter.
+ ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+  // clone **should** properly notify the builder's listener, if any.
TilingInterface paddedOp =
- clone(rewriter, opToPad, resultTensorTypes, newOperands);
+ clone(builder, toPad, resultTensorTypes, newOperands);
LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
- // Recover the slice out of the new static results. This keeps the original
- // opToPad around because it uses the dims of the original results.
+ // Recover the slice out of the new static results.
SmallVector<Value> paddedSubtensorResults;
- paddedSubtensorResults.reserve(opToPad->getNumResults());
+ paddedSubtensorResults.reserve(toPad->getNumResults());
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
- rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
- rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
- return paddedOp;
+ return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults};
}
>From a50f53f36d35aa6b346dd21033b3a8a8e2c35c34 Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Fri, 17 Oct 2025 16:36:08 -0700
Subject: [PATCH 2/2] address Max's comments
Signed-off-by: James Newling <james.newling at gmail.com>
---
.../Dialect/Linalg/Transforms/Transforms.h | 26 ++++++-----
.../TransformOps/LinalgTransformOps.cpp | 5 +--
.../Linalg/Transforms/PadTilingInterface.cpp | 45 +++++++++----------
3 files changed, 38 insertions(+), 38 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index db75379cc21c0..c89fc59c91830 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -620,7 +620,7 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);
@@ -630,17 +630,13 @@ using PadSizeComputationFunction =
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
- OpBuilder &rewriter, OpOperand &operandToPad,
- ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>>
+computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
+ ArrayRef<Range> iterationDomain,
+ const PadTilingInterfaceOptions &);
-/// Pad the iterator dimensions of `toPad`.
-/// * "options.paddingSizes" indicates that each padding dimension should be
-/// padded to the specified padding size.
-/// * "options.padToMultipleOf" indicates that the paddingSizes should be
-// interpreted as the bounding box (dynamic) value to pad to.
-/// * Use "options.paddingValues" to set the padding value of the created
-// tensor::PadOp.
+/// Operations and values created in the process of padding a TilingInterface
+/// operation.
struct PadTilingInterfaceResult {
/// The operands of the padded op.
SmallVector<tensor::PadOp> padOps;
@@ -649,6 +645,14 @@ struct PadTilingInterfaceResult {
/// Slices of the padded op's results, same types as `toPad`.
SmallVector<Value> replacements;
};
+
+/// Pad the iterator dimensions of `toPad`. `toPad` is not replaced or erased.
+/// * "options.paddingSizes" indicates that each padding dimension should be
+/// padded to the specified padding size.
+/// * "options.padToMultipleOf" indicates that the paddingSizes should be
+/// interpreted as the bounding box (dynamic) value to pad to.
+/// * Use "options.paddingValues" to set the padding value of the created
+/// tensor::PadOp.
FailureOr<PadTilingInterfaceResult>
rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
PadTilingInterfaceOptions options,
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a4d1d11e85633..9a8a63e54d02d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2469,14 +2469,11 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
-
- const auto &[paddedOperands, paddedOp, slicedResults] = *maybePadOps;
+ const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
// Set transform results.
paddedOps.push_back(paddedOp);
padOps.append(paddedOperands.begin(), paddedOperands.end());
-
- // erase targetOp:
rewriter.replaceOp(targetOp.getOperation(), slicedResults);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 513ce2c52ec87..3e787a2ad0ef5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -96,7 +96,7 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v,
AffineMap indexingMap,
ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options) {
@@ -110,7 +110,7 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
// "Full-rank" padding specification.
SmallVector<OpFoldResult> paddingSizes =
- getFullRankPaddingSizes(rewriter, indexingSizes, options);
+ getFullRankPaddingSizes(builder, indexingSizes, options);
// For each dimension in the operand's shape, iterate over indexingSizes and
// add the various term contributions.
@@ -148,28 +148,27 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
- bindDims(rewriter.getContext(), d0);
- bindSymbols(rewriter.getContext(), s0);
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
paddingDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, composedMap,
- {indexingSizes[paddingDim], paddingSize},
+ builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
} else {
// Otherwise just set to paddingSize.
paddingDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, projectedMap, paddingSize);
+ builder, loc, projectedMap, paddingSize);
}
// Adjust for the maximum accessed index, which is (paddingSize - 1) *
// multiplier.
AffineExpr d0;
- bindDims(rewriter.getContext(), d0);
+ bindDims(builder.getContext(), d0);
int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
- rewriter, loc, subtractMap, {paddingDimOfr});
+ builder, loc, subtractMap, {paddingDimOfr});
terms.push_back(maxAccessIdx);
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
@@ -178,19 +177,19 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
// If there are no terms, just return the dim.
if (terms.empty()) {
paddedShape[resultIndex] =
- createFoldedDimOp(rewriter, loc, v, resultIndex);
+ createFoldedDimOp(builder, loc, v, resultIndex);
continue;
}
// Sum individual terms' contributions.
SmallVector<AffineExpr> dims(terms.size());
- bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+ bindDimsList(builder.getContext(), MutableArrayRef{dims});
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
// Add 1 to the maximum accessed index and get the final padded size.
- OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
- rewriter, loc, sumExpr + 1, terms);
+ OpFoldResult paddedDimOfr =
+ affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
@@ -199,7 +198,7 @@ linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
- OpBuilder &rewriter, OpOperand &operandToPad,
+ OpBuilder &builder, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
auto transferOp =
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -207,9 +206,9 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
return failure();
// clang-format off
- assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
- return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
- r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+ assert(llvm::all_of(iterationDomain, [&builder](Range r) {
+ return r.offset == OpFoldResult(builder.getIndexAttr(0)) &&
+ r.stride == OpFoldResult(builder.getIndexAttr(1));
}) && "expected 0-offset 1-stride loop ranges");
// clang-format on
SmallVector<OpFoldResult> loopUpperBounds;
@@ -219,13 +218,13 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
return computePaddedShape(
- rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+ builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
indexingMap, loopUpperBounds, options);
}
/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
/// Value.
-static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &builder, TilingInterface opToPad,
TypedValue<RankedTensorType> v,
ArrayRef<OpFoldResult> paddedShape,
Attribute paddingValueAttr) {
@@ -233,15 +232,15 @@ static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(),
complexTy, complexAttr);
}
} else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
- paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(),
getElementTypeOrSelf(v.getType()));
} else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
paddingValue =
- arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
+ arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr);
}
assert(paddingValue && "failed to create value from padding attribute");
@@ -260,7 +259,7 @@ static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
<< paddedTensorType);
- return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+ return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v,
paddingValue, /*nofold=*/false, dynDims);
}
More information about the Mlir-commits
mailing list