[Mlir-commits] [mlir] [mlir][linalg] Add extra_pad_tiles to linalg.pack & unpack (PR #189049)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Mar 27 09:27:48 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: fabrizio-indirli
<details>
<summary>Changes</summary>
- In `linalg.pack`, add an optional `extra_pad_tiles` attribute that appends a chosen number of additional full tiles of high padding to each tiled dimension.
- In `linalg.unpack`, add the complementary optional attribute `drop_last_tiles`, which drops a chosen number of full outer tiles from each tiled dimension before the unpacked result is reconstructed.
---
Patch is 67.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/189049.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td (+37-16)
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+216-39)
- (modified) mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp (+22-16)
- (modified) mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp (+45-18)
- (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+5-3)
- (modified) mlir/python/mlir/dialects/linalg/__init__.py (+4)
- (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+26)
- (modified) mlir/test/Dialect/Linalg/invalid.mlir (+51)
- (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+13)
- (modified) mlir/test/Dialect/Linalg/transform-lower-pack.mlir (+49)
- (modified) mlir/test/python/dialects/linalg/ops.py (+6-4)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index 95383e6262f71..cdecb41db3123 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -164,13 +164,16 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
tiles divide perfectly the corresponding outer dimension in the result
tensor. It is UB if the tile does not perfectly divide the dimension.
- If present, it will pad along high dimensions (high-padding) to make the
- tile complete. Note that it is not allowed to have artificial padding that
- is not strictly required by linalg.pack (i.e., padding past what is needed
- to complete the last tile along each packed dimension). It is UB if extra
- padding is requested.
+ tile complete.
It is not possible to verify the requirements statically with dynamic
shapes, so they are treated as UB.
+ `extra_pad_tiles` (optional) specifies a number of additional full tiles of
+ high-padding to append for each tiled dimension. It is indexed in the same
+ order as `inner_dims_pos` / `inner_tiles`, must have the same length, and
+ defaults to all zeros when omitted. `extra_pad_tiles[i]` adds that many
+ extra full tiles at the end of tiled dimension `inner_dims_pos[i]`.
+
Example:
```mlir
%0 = linalg.pack %arg0 padding_value(%pad : f32) outer_dims_perm = [2, 1, 0]
@@ -200,7 +203,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
- DenseI64ArrayAttr:$static_inner_tiles);
+ DenseI64ArrayAttr:$static_inner_tiles,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$extra_pad_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
let builders = [
@@ -208,7 +212,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
"ArrayRef<int64_t>":$innerDimsPos,
"ArrayRef<OpFoldResult>":$innerTiles,
CArg<"std::optional<Value>", "std::nullopt">:$paddingValue,
- CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+ CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm,
+ CArg<"ArrayRef<int64_t>", "{}">:$extraPadTiles)>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
@@ -218,28 +223,32 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
static SmallVector<OpFoldResult> getResultShape(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> sourceDims,
ArrayRef<OpFoldResult> innerTileDims, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Method to get the `MemRefType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Returns the shape of the packed type. It is a shared helper that helps
// type inference methods in a way that ensures that they agree on which
// dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm = {});
+ ArrayRef<int64_t> outerDimsPerm = {},
+ ArrayRef<int64_t> extraPadTiles = {});
// Returns true if we have enough static information to catch undefined
// behavior when the tile size does not divide perfectly the dimension of
@@ -249,7 +258,8 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles);
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles = {});
// Same as above function but here dynamic dimensions are assumed
// to require padding.
@@ -257,11 +267,13 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles);
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles = {});
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles = {});
/// Build and return a new PackOp that is a clone of the current PackOp with
/// (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
@@ -325,6 +337,12 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
dimensions. If specified, it must have `n - k` elements. If specified, this
permutation is applied before combining any dimensions.
+ `drop_last_tiles` (optional) specifies how many full packed outer tiles to
+ drop for each tiled dimension before reconstructing the unpacked result. It
+ is indexed in the same order as `inner_dims_pos` / `inner_tiles`, must have
+ the same length, and defaults to all zeros when omitted. This can be used
+ to drop the extra pad tiles added by a previous `pack` operation with `extra_pad_tiles`.
+
Note, the unpack operation may drop any padding introduced by the pack
operation and hence the following holds
`NumElementsOf(source) >= NumElementsOf(result)`.
@@ -362,20 +380,23 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
TensorOrMemRef<[AnyType]>:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos, Variadic<Index>:$inner_tiles,
- DenseI64ArrayAttr:$static_inner_tiles);
+ DenseI64ArrayAttr:$static_inner_tiles,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$drop_last_tiles);
let results = (outs Optional<AnyRankedTensor>:$result);
let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
"ArrayRef<int64_t>":$innerDimsPos,
"ArrayRef<OpFoldResult>":$innerTiles,
- CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm)>
+ CArg<"ArrayRef<int64_t>", "{}">:$outerDimsPerm,
+ CArg<"ArrayRef<int64_t>", "{}">:$dropLastTiles)>
];
let extraClassDeclaration = commonExtraClassDeclaration # [{
static Value createDestinationTensor(OpBuilder &b, Location loc,
Value source, ArrayRef<OpFoldResult> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
+ ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> dropLastTiles = {});
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
/// with (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index e9698365765e7..8223ab78d1ef5 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5104,6 +5104,51 @@ static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
return staticTiles;
}
+static inline bool isEmptyOrZeroArray(ArrayRef<int64_t> values) {
+ return values.empty() ||
+ llvm::all_of(values, [](int64_t value) { return value == 0; });
+}
+
+static ArrayRef<int64_t> getExtraTrailingTiles(PackOp op) {
+ return op.getExtraPadTiles();
+}
+
+static ArrayRef<int64_t> getExtraTrailingTiles(UnPackOp op) {
+ return op.getDropLastTiles();
+}
+
+static void addExtraTrailingTiles(SmallVectorImpl<int64_t> &outerDims,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> adjustments) {
+ if (adjustments.empty())
+ return;
+ for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+ if (ShapedType::isDynamic(outerDims[dimPos]))
+ continue;
+ outerDims[dimPos] += adjustments[index];
+ }
+}
+
+static void addExtraTrailingTiles(OpBuilder &builder, Location loc,
+ SmallVectorImpl<OpFoldResult> &outerDims,
+ ArrayRef<int64_t> innerDimsPos,
+ ArrayRef<int64_t> adjustments) {
+ if (adjustments.empty())
+ return;
+ AffineExpr d0, c0;
+ bindDims(builder.getContext(), d0);
+ bindSymbols(builder.getContext(), c0);
+ auto addConstant = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, d0 + c0,
+ builder.getContext());
+ for (auto [index, dimPos] : llvm::enumerate(innerDimsPos)) {
+ if (adjustments[index] == 0)
+ continue;
+ outerDims[dimPos] = affine::makeComposedFoldedAffineApply(
+ builder, loc, addConstant,
+ {outerDims[dimPos], builder.getIndexAttr(adjustments[index])});
+ }
+}
+
/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) It contains duplicate.
/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
@@ -5161,12 +5206,26 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
+ ArrayRef<int64_t> extraPadTiles = getExtraTrailingTiles(packOrUnPack);
if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
return op->emitError("invalid inner_dims_pos vector");
if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
return op->emitError("invalid outer_dims_perm vector");
if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
return op->emitError("outer_dims_perm must be a permutation or empty");
+ if (!extraPadTiles.empty() && extraPadTiles.size() != innerDimsPos.size()) {
+ return op->emitError() << (std::is_same<OpTy, PackOp>::value
+ ? "extra_pad_tiles"
+ : "drop_last_tiles")
+ << " must have the same number of entries as "
+ "inner_dims_pos";
+ }
+ if (llvm::any_of(extraPadTiles, [](int64_t value) { return value < 0; })) {
+ return op->emitError() << (std::is_same<OpTy, PackOp>::value
+ ? "extra_pad_tiles"
+ : "drop_last_tiles")
+ << " must contain only non-negative values";
+ }
// Tiling factors must be less than or equal to the input rank for pack (or
// output rank for unpack), and must match the number of `inner_dims_pos`.
@@ -5196,7 +5255,8 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
unpackedType.getShape(), packOrUnPack.getStaticTiles(),
- packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+ packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm(),
+ extraPadTiles);
for (auto it : llvm::enumerate(llvm::zip(
packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
int64_t dimSize = std::get<0>(it.value());
@@ -5244,6 +5304,7 @@ struct PackOrUnPackTransposeResult {
SmallVector<int64_t> innerDimsPos;
SmallVector<OpFoldResult> innerTiles;
SmallVector<int64_t> outerDimsPerm;
+ SmallVector<int64_t> trailingTileAdjustments;
};
} // namespace
@@ -5261,6 +5322,8 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
metadata.innerTiles =
SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
+ metadata.trailingTileAdjustments =
+ llvm::to_vector(getExtraTrailingTiles(packOrUnPackOp));
int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
? packOrUnPackOp.getSourceRank()
: packOrUnPackOp.getDestRank();
@@ -5274,6 +5337,9 @@ commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
"invalid inner permutation");
applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
applyPermutationToVector(metadata.innerTiles, innerPermutation);
+ if (!metadata.trailingTileAdjustments.empty())
+ applyPermutationToVector(metadata.trailingTileAdjustments,
+ innerPermutation);
}
if (!outerPermutation.empty()) {
assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
@@ -5299,7 +5365,7 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::UnresolvedOperand> paddingValue;
SmallVector<Type> paddingValueType;
SmallVector<int64_t> staticTiles;
- DenseI64ArrayAttr innerDimsPos, outerDimsPerm;
+ DenseI64ArrayAttr innerDimsPos, outerDimsPerm, extraPadTiles;
Type sourceType, destType, resultType;
if (parser.parseOperand(source))
@@ -5352,6 +5418,24 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
for (auto val : staticTilesAttr.asArrayRef())
staticTiles.push_back(val);
+ if (succeeded(parser.parseOptionalKeyword("extra_pad_tiles"))) {
+ if (parser.parseEqual())
+ return failure();
+ SmallVector<int64_t> extraPadTilesVec;
+ if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, [&]() {
+ int64_t value;
+ if (parser.parseInteger(value))
+ return failure();
+ extraPadTilesVec.push_back(value);
+ return success();
+ }))
+ return failure();
+ if (!isEmptyOrZeroArray(extraPadTilesVec)) {
+ extraPadTiles =
+ parser.getBuilder().getDenseI64ArrayAttr(extraPadTilesVec);
+ }
+ }
+
if (parser.parseKeyword("into") || parser.parseOperand(dest))
return failure();
@@ -5395,6 +5479,8 @@ ParseResult PackOp::parse(OpAsmParser &parser, OperationState &result) {
result.addAttribute("inner_dims_pos", innerDimsPos);
if (outerDimsPerm)
result.addAttribute("outer_dims_perm", outerDimsPerm);
+ if (extraPadTiles)
+ result.addAttribute("extra_pad_tiles", extraPadTiles);
SmallVector<int32_t> segmentSizes = {
1, 1, static_cast<int32_t>(paddingValue.size()),
@@ -5429,11 +5515,18 @@ void PackOp::print(OpAsmPrinter &p) {
p << " inner_tiles = ";
printDynamicIndexList(p, *this, getInnerTiles(), getStaticInnerTilesAttr());
+ if (!isEmptyOrZeroArray(getExtraPadTiles())) {
+ p << " extra_pad_tiles = [";
+ llvm::interleaveComma(getExtraPadTiles(), p);
+ p << "]";
+ }
+
p << " into " << getDest();
p.printOptionalAttrDict((*this)->getAttrs(),
{"static_inner_tiles", "inner_dims_pos",
- "outer_dims_perm", "operandSegmentSizes"});
+ "outer_dims_perm", "extra_pad_tiles",
+ "operandSegmentSizes"});
p << " : " << getSource().getType();
p << " -> " << getDest().getType();
@@ -5443,7 +5536,8 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
Value dest, ArrayRef<int64_t> innerDimsPos,
ArrayRef<OpFoldResult> innerTiles,
std::optional<Value> paddingValue,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
assert(innerDimsPos.size() == innerTiles.size() &&
"number of tile sizes specified must match the specified number of "
"original dimensions to be tiled");
@@ -5455,7 +5549,10 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
outerDimsPerm.empty() ? nullptr
: builder.getDenseI64ArrayAttr(outerDimsPerm),
builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
- builder.getDenseI64ArrayAttr(staticTileSizes));
+ builder.getDenseI64ArrayAttr(staticTileSizes),
+ isEmptyOrZeroArray(extraPadTiles)
+ ? nullptr
+ : builder.getDenseI64ArrayAttr(extraPadTiles));
}
LogicalResult
@@ -5504,7 +5601,10 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles) {
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles) {
+ if (!isEmptyOrZeroArray(extraPadTiles))
+ return true;
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
@@ -5535,7 +5635,10 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
ArrayRef<int64_t> outerDimsPerm,
- ArrayRef<OpFoldResult> innerTiles) {
+ ArrayRef<OpFoldResult> innerTiles,
+ ArrayRef<int64_t> extraPadTiles) {
+ if (!isEmptyOrZeroArray(extraPadTiles))
+ return true;
SmallVector<int64_t> outputTileSizes(
outputShape.take_front(inputShape.size()));
if (!outerDimsPerm.empty()) {
@@ -5576,7 +5679,7 @@ LogicalResult PackOp::verify() {
if (!paddingValue &&
requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
getDestType().getShape(), getOuterDimsPerm(),
- getMixedTiles())) {
+ getMixedTiles(), getExtraPadTiles())) {
return emitOpError(
"invalid tile factor or output size provided. Only full tiles are "
"supported when padding_value is not set");
@@ -5602,7 +5705,8 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
- ArrayRef<int64_t> outerDimsPerm) {
+ ArrayRef<int64_t> outerDimsPerm,
+ ArrayRef<int64_t> extraPadTiles) {
SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
@@ -5614,6 +5718,7 @@ SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
resultShape[tiledDim.value()...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/189049
More information about the Mlir-commits
mailing list