[Mlir-commits] [mlir] [mlir][linalg] Upstream PackOp/UnPackOp's generateScalarImplementation. (PR #182838)
Han-Chung Wang
llvmlistbot at llvm.org
Wed Feb 25 22:49:08 PST 2026
https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/182838
>From 457edcb6e06a1e9b0a0f12369759c30784ba9718 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 03:06:29 -0800
Subject: [PATCH 1/3] [mlir][linalg] Upstream PackOp/UnPackOp's
generateScalarImplementation.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 10 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 228 +++++++++++++-
.../lower-to-loops-using-interface.mlir | 277 ++++++++++++++++++
3 files changed, 508 insertions(+), 7 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 685096de5fbea..2c89dcbadf9d3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5041,8 +5041,10 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
"applies to only pack or unpack operations");
int64_t destRank = op.getDestRank();
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
- reifiedReturnShapes[0] =
- tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+ for (auto dim : llvm::seq<int64_t>(0, destRank)) {
+ reifiedReturnShapes[0][dim] =
+ createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
+ }
return success();
}
@@ -5440,8 +5442,6 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  // NOTE(review): with buffer semantics this op has no results, yet one shape (taken from the dest operand) is now reified; confirm generic callers such as `mlir::reifyResultShapes` tolerate that.
- if (!hasPureTensorSemantics())
- return failure();
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
@@ -6164,8 +6164,6 @@ void UnPackOp::print(OpAsmPrinter &p) {
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  // NOTE(review): with buffer semantics this op has no results, yet one shape (taken from the dest operand) is now reified; confirm generic callers such as `mlir::reifyResultShapes` tolerate that.
- if (!hasPureTensorSemantics())
- return failure();
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 292bfa70441fa..92d72c72b6840 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -725,7 +726,7 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
OpFoldResult zero = builder.getIndexAttr(0);
OpFoldResult one = builder.getIndexAttr(1);
ReifiedRankedShapedTypeDims resultShape;
- (void)reifyResultShapes(builder, op, resultShape);
+ (void)op.reifyResultShapes(builder, resultShape);
SmallVector<Range> loopBounds(rank);
for (auto dim : llvm::seq<int64_t>(0, rank)) {
loopBounds[dim].offset = zero;
@@ -744,6 +745,129 @@ static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
applyPermutationToVector<OpFoldResult>(sizes, permutation);
}
+/// Returns the permutation that sorts `dimsPos` in ascending order: entry `i`
+/// is the index in `dimsPos` of its `i`-th smallest value, e.g. [2, 0] -> [1, 0].
+/// Only values in [0, rank) are considered; other values are dropped.
+static SmallVector<int64_t>
+computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
+ SmallVector<int64_t> interchangeVector;
+ interchangeVector.reserve(dimsPos.size());
+ // First map each dim to its index in dims_pos. For example, dims_pos = [2, 0]
+ // will map to:
+ // [
+ // [ key: 2, value: 0]
+ // [ key: 0, value: 1]
+ // ]
+ // where key is an entry of dims_pos and value is its index in dims_pos.
+ DenseMap<int64_t, int64_t> dimsAndPosMapping;
+ for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++)
+ dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
+
+ // Scan the dims in ascending order and collect the index of each one that
+ // appears in dims_pos; for the example above this yields [1, 0].
+ for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
+ if (dimsAndPosMapping.count(dimsIdx))
+ interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
+ }
+ return interchangeVector;
+}
+
+/// Returns a vector that interchanges `elements` starting at offset `offset`
+/// based on the indexes in `interchangeVector`.
+template <typename T>
+static SmallVector<T> interchange(ArrayRef<T> elements,
+ ArrayRef<int64_t> interchangeVector,
+ int offset = 0) {
+ SmallVector<T> vec = llvm::to_vector(elements); // Slots not covered by `interchangeVector` keep their values.
+ for (auto [idx, val] : llvm::enumerate(interchangeVector))
+ vec[idx + offset] = elements[val + offset]; // Gather: result[offset + i] = elements[offset + interchangeVector[i]].
+ return vec;
+}
+
+/// Generate the innermost-loop body of the scalar `pack` lowering: read the
+/// source element for dest index `ivs` (or the padding value) and store it.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+ OpBuilder &builder,
+ Location loc,
+ ValueRange ivs) {
+ // Note: `ivs` index the dest directly: the outer (tile) ivs, permuted by
+ // `outer_dims_perm`, followed by the point ivs in `inner_dims_pos` order.
+ // Relating each point loop to its tile loop is easier after interchanging
+ // `ivs` into the canonical blocking format ABCabc, where both halves follow
+ // source-dim order: then the point iv of the k-th tiled source dim sits at
+ // position sourceRank + k (`pointLoopsOffset` below counts k).
+ ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
+ ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
+
+ SmallVector<Value> interchangedIvs = ivs;
+ SmallVector<int64_t> interchangeVector =
+ computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
+ interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
+ /*offset=*/packOp.getSourceRank());
+ if (!dimsToOuterBlock.empty()) {
+ interchangeVector =
+ computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
+ interchangedIvs =
+ interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
+ }
+
+ SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
+ SmallVector<OpFoldResult> sourceIndices;
+ size_t pointLoopsOffset = 0;
+ int64_t sourceRank = packOp.getSourceRank();
+ for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+ if (dimAndTileMapping.count(dim)) {
+ AffineExpr i, j, tile;
+ bindDims(builder.getContext(), i, j);
+ bindSymbols(builder.getContext(), tile);
+ OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
+ builder, loc, i * tile + j, // tile iv * tile size + point iv
+ ArrayRef<OpFoldResult>{
+ interchangedIvs[dim],
+ interchangedIvs[pointLoopsOffset + packOp.getSourceRank()],
+ dimAndTileMapping[dim]});
+ sourceIndices.push_back(sourceIndex);
+ ++pointLoopsOffset;
+ } else {
+ sourceIndices.push_back(interchangedIvs[dim]);
+ }
+ }
+
+ auto createLoad = [&]() -> Value {
+ return memref::LoadOp::create(
+ builder, loc, packOp.getSource(),
+ getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
+ };
+ Value scalar;
+ if (auto paddingValue = packOp.getPaddingValue()) {
+ ArithBuilder arithBuilder(builder, loc);
+ Value isInBounds; // AND over all source dims of (index < dim size); lower bounds are not checked.
+ for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+ Value idx =
+ getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
+ Value cond = arithBuilder.slt(
+ idx, createOrFoldDimOp(builder, loc, packOp.getSource(), dim));
+ isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
+ }
+ scalar = scf::IfOp::create(
+ builder, loc, isInBounds, /*thenBuilder=*/
+ [&](OpBuilder &b, Location l) {
+ scf::YieldOp::create(b, l, createLoad());
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &b, Location l) {
+ scf::YieldOp::create(b, l, paddingValue);
+ })
+ .getResult(0);
+ } else {
+ scalar = createLoad();
+ }
+
+ memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs); // Original (non-interchanged) ivs address the dest.
+}
+
struct PackOpTiling
: public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
@@ -899,6 +1023,52 @@ struct PackOpTiling
return tilingResult.value();
}
+  /// Generate the scalar loop nest of a buffer-semantics `pack`. `ivs` index
+  /// the outer (tile) dims of the dest, one per source dim. One `scf.for` is
+  /// created per inner-tile dim of the dest, and the innermost loop body copies
+  /// (or pads) a single element.
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto packOp = cast<PackOp>(op);
+    assert(packOp.hasPureBufferSemantics() &&
+           "expected operation to have buffer semantics");
+    int64_t sourceRank = packOp.getSourceRank();
+    int64_t destRank = packOp.getDestRank();
+    assert(static_cast<int64_t>(ivs.size()) == sourceRank &&
+           "number of ivs must match the rank of the input");
+    // Without inner tiles there is no data tile to iterate over: `ivs` already
+    // address every dest element. Handling this up front also keeps the loop
+    // range below from being built with begin > end.
+    if (destRank == sourceRank) {
+      generatePackOpScalarImplementationBody(packOp, builder, loc, ivs);
+      return success();
+    }
+    OpBuilder::InsertionGuard g(builder);
+    // The `ivs` already represent the position into the output tensor for the
+    // non data-tile dimensions.
+    SmallVector<Value> ivVec = llvm::to_vector(ivs);
+
+    // Get output shape - for memrefs, get dimensions from dest directly.
+    SmallVector<OpFoldResult> outputShape;
+    Value dest = packOp.getDest();
+    for (auto dim : llvm::seq<int64_t>(0, destRank))
+      outputShape.push_back(createOrFoldDimOp(builder, loc, dest, dim));
+
+    // Generate the loops that iterate over the data tile.
+    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+
+    // All loops except the innermost are simple loops that just iterate
+    // over the tile dimensions.
+    for (auto dataTileDim : llvm::seq<int64_t>(sourceRank, destRank - 1)) {
+      Value ub = getValueOrCreateConstantIndexOp(builder, loc,
+                                                 outputShape[dataTileDim]);
+      scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
+      builder.setInsertionPointToStart(loop.getBody());
+      ivVec.push_back(loop.getInductionVar());
+    }
+    // The body of the innermost loop does the actual data movement.
+    scf::ForOp::create(
+        builder, loc, zero,
+        getValueOrCreateConstantIndexOp(builder, loc, outputShape.back()), one,
+        ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+            ValueRange regionIterArgs) {
+          ivVec.push_back(iv);
+          generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
+                                                 ivVec);
+          scf::YieldOp::create(bodyBuilder, bodyLoc);
+        });
+    return success();
+  }
+
/// Method to return the position of iteration domain tile computed by the
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
/// `resultSizes` only cover outer dimensions.
@@ -1292,6 +1462,62 @@ struct UnPackOpTiling
return tilingResult.value();
}
+  /// Generate the scalar body of a buffer-semantics `unpack` for the dest
+  /// element at `ivs`. Each tiled index is split into (tile, point) by a
+  /// floordiv/mod with its tile size. The resulting indices are arranged in
+  /// the source's packed layout, and the element is copied.
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto unpackOp = cast<UnPackOp>(op);
+    assert(unpackOp.hasPureBufferSemantics() &&
+           "expected operation to have buffer semantics");
+    int64_t destRank = unpackOp.getDestRank();
+    assert(static_cast<int64_t>(ivs.size()) == destRank &&
+           "number of ivs must match the rank of the output tensor");
+    OpBuilder::InsertionGuard g(builder);
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        unpackOp.getDimAndTileMapping();
+    // Source indices. This first receives the untiled and tile-loop indices,
+    // one per dest dim, in dest-dim order.
+    SmallVector<Value> inputIvs;
+    inputIvs.reserve(unpackOp.getSourceRank());
+    // Point-loop (intra-tile) indices, keyed by the dest dim they belong to.
+    DenseMap<int64_t, Value> pointLoopIvs;
+    for (int64_t dim : llvm::seq<int64_t>(0, destRank)) {
+      auto it = dimAndTileMapping.find(dim);
+      if (it == dimAndTileMapping.end()) {
+        inputIvs.push_back(ivs[dim]);
+        continue;
+      }
+      affine::DivModValue divMod = affine::getDivMod(
+          builder, loc, ivs[dim],
+          getValueOrCreateConstantIndexOp(builder, loc, it->second));
+      inputIvs.push_back(divMod.quotient);
+      pointLoopIvs[dim] = divMod.remainder;
+    }
+
+    // Interchange the tiled loops induction variables based on
+    // `outer_dims_perm`: source outer dim `j` is dest dim `outerDims[j]`.
+    ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
+    if (!outerDims.empty())
+      inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
+
+    // The trailing source dims are the inner tiles in `inner_dims_pos` order:
+    // the k-th one is the point index of dest dim `innerDims[k]`. Look the
+    // point ivs up by dim. Gathering them with `computeInterchangeFromDimPos`
+    // applies the inverse permutation, which is wrong unless `inner_dims_pos`
+    // is its own inverse (it breaks e.g. for [2, 0, 1]).
+    ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
+    for (int64_t dim : innerDims)
+      inputIvs.push_back(pointLoopIvs.lookup(dim));
+    assert(static_cast<int64_t>(inputIvs.size()) == unpackOp.getSourceRank() &&
+           "expect same number of induction variables equals to input rank");
+
+    Value scalar =
+        memref::LoadOp::create(builder, loc, unpackOp.getSource(), inputIvs);
+    memref::StoreOp::create(builder, loc, scalar, unpackOp.getDest(), ivs);
+    return success();
+  }
+
/// Method to return the position of iteration domain tile computed by the
/// tiled operation.
LogicalResult getIterationDomainTileFromOperandTiles(
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index aa8882d21698c..7fe37a659f5a5 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -400,3 +400,280 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
// CHECK: %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
// CHECK: memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]
+
+// -----
+
+// Pack on buffers: both dims tiled by 32x32, identity dim order, no padding.
+func.func @NC_to_NCnc(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1
+ : memref<128x256xf32> -> memref<4x8x32x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %pack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-LABEL: func @NC_to_NCnc(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[SRC_I:.*]] = affine.apply #[[$MAP]](%[[I]], %[[K]])
+// CHECK-DAG: %[[SRC_J:.*]] = affine.apply #[[$MAP]](%[[J]], %[[L]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[SRC_I]], %[[SRC_J]]] : memref<128x256xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x8x32x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Pack with a padding value: 13x15 is not a multiple of the 8x2 tiles, so reads past the source bounds yield %arg2.
+func.func @NC_to_NCnc_pad(%arg0: memref<13x15xf32>, %arg1: memref<2x8x8x2xf32>, %arg2: f32) {
+ linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+ into %arg1 : memref<13x15xf32> -> memref<2x8x8x2xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %pack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+// CHECK-LABEL: func @NC_to_NCnc_pad(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<13x15xf32>
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<2x8x8x2xf32>
+// CHECK-SAME: %[[PAD:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C13:.*]] = arith.constant 13 : index
+// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-DAG: %[[SRC_I:.*]] = affine.apply #[[$MAP0]](%[[I]], %[[K]])
+// CHECK-DAG: %[[SRC_J:.*]] = affine.apply #[[$MAP1]](%[[J]], %[[L]])
+// CHECK: %[[BOUND_I:.*]] = arith.cmpi slt, %[[SRC_I]], %[[C13]] : index
+// CHECK: %[[BOUND_J:.*]] = arith.cmpi slt, %[[SRC_J]], %[[C15]] : index
+// CHECK: %[[IN_BOUNDS:.*]] = arith.andi %[[BOUND_I]], %[[BOUND_J]] : i1
+// CHECK: %[[VAL:.*]] = scf.if %[[IN_BOUNDS]] -> (f32) {
+// CHECK: %[[LOAD:.*]] = memref.load %[[SRC]][%[[SRC_I]], %[[SRC_J]]] : memref<13x15xf32>
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: } else {
+// CHECK: scf.yield %[[PAD]]
+// CHECK: }
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<2x8x8x2xf32>
+
+// -----
+
+// Pack whose inner tiles are stored in [1, 0] order (the tile of dim 1 comes first).
+func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1
+ : memref<128x256xf32> -> memref<4x8x32x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %pack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-LABEL: func @KC_to_KCck(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[SRC_C:.*]] = affine.apply #[[$MAP]](%[[C]], %[[c]])
+// CHECK-DAG: %[[SRC_K:.*]] = affine.apply #[[$MAP]](%[[K]], %[[k]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[SRC_K]], %[[SRC_C]]] : memref<128x256xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[K]], %[[C]], %[[c]], %[[k]]] : memref<4x8x32x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Unpack on buffers: both dims tiled by 32x32, identity dim order.
+func.func @NCnc_to_NC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+ linalg.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0
+ : memref<4x8x32x32xf32> -> memref<128x256xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[$MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-LABEL: func @NCnc_to_NC(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK-DAG: %[[FLOOR_I:.*]] = affine.apply #[[$MAP_FLOOR]](%[[I]])
+// CHECK-DAG: %[[FLOOR_J:.*]] = affine.apply #[[$MAP_FLOOR]](%[[J]])
+// CHECK-DAG: %[[MOD_I:.*]] = affine.apply #[[$MAP_MOD]](%[[I]])
+// CHECK-DAG: %[[MOD_J:.*]] = affine.apply #[[$MAP_MOD]](%[[J]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_I]], %[[FLOOR_J]], %[[MOD_I]], %[[MOD_J]]] : memref<4x8x32x32xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]]] : memref<128x256xf32>
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Unpack whose inner tiles are stored in [1, 0] order (the tile of dim 1 comes first).
+func.func @KCck_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+ linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg0
+ : memref<4x8x32x32xf32> -> memref<128x256xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[$MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-LABEL: func @KCck_to_KC(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK-DAG: %[[FLOOR_I:.*]] = affine.apply #[[$MAP_FLOOR]](%[[I]])
+// CHECK-DAG: %[[FLOOR_J:.*]] = affine.apply #[[$MAP_FLOOR]](%[[J]])
+// CHECK-DAG: %[[MOD_I:.*]] = affine.apply #[[$MAP_MOD]](%[[I]])
+// CHECK-DAG: %[[MOD_J:.*]] = affine.apply #[[$MAP_MOD]](%[[J]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_I]], %[[FLOOR_J]], %[[MOD_J]], %[[MOD_I]]] : memref<4x8x32x32xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]]] : memref<128x256xf32>
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Pack with outer_dims_perm = [1, 0] and non-square 32x8 tiles.
+func.func @KC_to_CKkc(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) {
+ linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+ into %arg1 : memref<128x256xf32> -> memref<32x4x32x8xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %pack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-LABEL: func @KC_to_CKkc(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK-DAG: %[[SRC_K:.*]] = affine.apply #[[$MAP0]](%[[K]], %[[k]])
+// CHECK-DAG: %[[SRC_C:.*]] = affine.apply #[[$MAP1]](%[[C]], %[[c]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[SRC_K]], %[[SRC_C]]] : memref<128x256xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[C]], %[[K]], %[[k]], %[[c]]] : memref<32x4x32x8xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Unpack with outer_dims_perm = [1, 0] and non-square 32x8 tiles.
+func.func @CKkc_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) {
+ linalg.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+ into %arg0 : memref<32x4x32x8xf32> -> memref<128x256xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @CKkc_to_KC(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK-DAG: %[[FLOOR_K:.*]] = affine.apply #[[$MAP0]](%[[K]])
+// CHECK-DAG: %[[MOD_K:.*]] = affine.apply #[[$MAP1]](%[[K]])
+// CHECK-DAG: %[[FLOOR_C:.*]] = affine.apply #[[$MAP2]](%[[C]])
+// CHECK-DAG: %[[MOD_C:.*]] = affine.apply #[[$MAP3]](%[[C]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_C]], %[[FLOOR_K]], %[[MOD_K]], %[[MOD_C]]] : memref<32x4x32x8xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[K]], %[[C]]] : memref<128x256xf32>
+// CHECK: }
+// CHECK: }
>From 7d7edb45e9a1e9639942be662f449baff6209365 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 03:42:39 -0800
Subject: [PATCH 2/3] Improve code quality and upstream more tests.
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 24 ++--
.../lower-to-loops-using-interface.mlir | 134 ++++++++++++++++++
3 files changed, 151 insertions(+), 10 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2c89dcbadf9d3..bfc03cc7436df 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5041,10 +5041,9 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
"applies to only pack or unpack operations");
int64_t destRank = op.getDestRank();
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
- for (auto dim : llvm::seq<int64_t>(0, destRank)) {
+ for (auto dim : llvm::seq<int64_t>(0, destRank))
reifiedReturnShapes[0][dim] =
createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
- }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 92d72c72b6840..aae2e134108c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -772,8 +772,18 @@ computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
return interchangeVector;
}
-/// Returns a vector that interchanges `elements` starting at offset `offset`
-/// based on the indexes in `interchangeVector`.
+/// Permute the `interchangeVector.size()` elements of `elements` starting at
+/// position `offset`: the result holds `elements[offset + interchangeVector[i]]`
+/// at position `offset + i`. Elements outside that window are unchanged.
+///
+/// Example: interchange([a, b, c, d, e], [2, 0, 1], offset=2)
+/// returns [a, b, e, c, d] (permutes the suffix [c, d, e])
+///
+/// Note: This is similar to `applyPermutationToVector` but supports an offset
+/// for permuting a sub-range of the vector. It is used by the pack/unpack
+/// scalar implementations, e.g. with offset 0 for the outer (tile) dims and
+/// with the source rank as offset for the inner tile dims of `pack`, which
+/// are stored at the end of the index vector.
template <typename T>
static SmallVector<T> interchange(ArrayRef<T> elements,
ArrayRef<int64_t> interchangeVector,
@@ -810,15 +820,13 @@ static void generatePackOpScalarImplementationBody(PackOp packOp,
interchangedIvs =
interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
}
-
- SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
packOp.getDimAndTileMapping();
SmallVector<OpFoldResult> sourceIndices;
size_t pointLoopsOffset = 0;
int64_t sourceRank = packOp.getSourceRank();
for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
- if (dimAndTileMapping.count(dim)) {
+ if (dimAndTileMapping.contains(dim)) {
AffineExpr i, j, tile;
bindDims(builder.getContext(), i, j);
bindSymbols(builder.getContext(), tile);
@@ -1032,7 +1040,7 @@ struct PackOpTiling
OpBuilder::InsertionGuard g(builder);
// The `ivs` already represent the position into the output tensor for the
// non data-tile dimensions.
- SmallVector<Value> ivVec = llvm::to_vector(ivs);
+ SmallVector<Value> ivVec(ivs);
// Get output shape - for memrefs, get dimensions from dest directly.
SmallVector<OpFoldResult> outputShape;
@@ -1498,14 +1506,14 @@ struct UnPackOpTiling
assert(inputIvsPointLoops.size() + inputIvs.size() ==
unpackOp.getSourceRank() &&
"expect same number of induction variables equals to input rank");
- // interchange the point loops induction variables based on `inner_dim_pos`.
+ // Interchange the point loops induction variables based on `inner_dim_pos`.
ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
SmallVector<int64_t> interchangeVector =
computeInterchangeFromDimPos(innerDims, unpackOp.getDestRank());
SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
interchangedInputIvsPointLoops = interchange<Value>(
interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0);
- // interchange the tiled loops induction variables based on
+ // Interchange the tiled loops induction variables based on
// `outer_dims_perm`.
ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
if (!outerDims.empty())
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 7fe37a659f5a5..ec9b606491910 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -677,3 +677,137 @@ module attributes {transform.with_named_sequence} {
// CHECK: memref.store %[[VAL]], %[[DEST]][%[[K]], %[[C]]] : memref<128x256xf32>
// CHECK: }
// CHECK: }
+
+// -----
+
+func.func @KCRSsr_to_KCRS_static(%arg0: memref<1x1x128x64xf32>, %arg1: memref<1x1x4x8x8x32xf32>) {
+ linalg.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg0
+ : memref<1x1x4x8x8x32xf32> -> memref<1x1x128x64xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[$MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-DAG: #[[$MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG: #[[$MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @KCRSsr_to_KCRS_static(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<1x1x128x64xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<1x1x4x8x8x32xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[C64]] step %[[C1]] {
+// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[$MAP_FLOORK]](%[[K]])
+// CHECK-DAG: %[[FLOORL:.*]] = affine.apply #[[$MAP_FLOORL]](%[[L]])
+// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[$MAP_MODK]](%[[K]])
+// CHECK-DAG: %[[MODL:.*]] = affine.apply #[[$MAP_MODL]](%[[L]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<1x1x4x8x8x32xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[C0]], %[[C0]], %[[K]], %[[L]]] : memref<1x1x128x64xf32>
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KCRSsr_to_KCRS_dynamic_with_static_inner_tiles(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x32xf32>) {
+ linalg.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg0
+ : memref<?x?x?x?x8x32xf32> -> memref<?x?x?x?xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG: #[[$MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-DAG: #[[$MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG: #[[$MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @KCRSsr_to_KCRS_dynamic_with_static_inner_tiles(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<?x?x?x?x8x32xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[UBI:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x?x?x?xf32>
+// CHECK-DAG: %[[UBJ:.*]] = memref.dim %[[DEST]], %[[C1]] : memref<?x?x?x?xf32>
+// CHECK-DAG: %[[UBK:.*]] = memref.dim %[[DEST]], %[[C2]] : memref<?x?x?x?xf32>
+// CHECK-DAG: %[[UBL:.*]] = memref.dim %[[DEST]], %[[C3]] : memref<?x?x?x?xf32>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[UBI]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[UBJ]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[UBK]] step %[[C1]] {
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[UBL]] step %[[C1]] {
+// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[$MAP_FLOORK]](%[[K]])
+// CHECK-DAG: %[[FLOORL:.*]] = affine.apply #[[$MAP_FLOORL]](%[[L]])
+// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[$MAP_MODK]](%[[K]])
+// CHECK-DAG: %[[MODL:.*]] = affine.apply #[[$MAP_MODL]](%[[L]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<?x?x?x?x8x32xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KCRSsr_to_KCRS_dynamic_with_dynamic_inner_tiles(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x?xf32>, %block : index) {
+ linalg.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, %block] into %arg0
+ : memref<?x?x?x?x8x?xf32> -> memref<?x?x?x?xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %unpack
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[$MAP_FLOORK:.*]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
+// CHECK-DAG: #[[$MAP_MODK:.*]] = affine_map<(d0)[s0] -> (d0 mod s0)>
+// CHECK-DAG: #[[$MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG: #[[$MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @KCRSsr_to_KCRS_dynamic_with_dynamic_inner_tiles(
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: memref<?x?x?x?x8x?xf32>
+// CHECK-SAME: %[[TILE:[a-zA-Z0-9]+]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[UBI:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x?x?x?xf32>
+// CHECK-DAG: %[[UBJ:.*]] = memref.dim %[[DEST]], %[[C1]] : memref<?x?x?x?xf32>
+// CHECK-DAG: %[[UBK:.*]] = memref.dim %[[DEST]], %[[C2]] : memref<?x?x?x?xf32>
+// CHECK-DAG: %[[UBL:.*]] = memref.dim %[[DEST]], %[[C3]] : memref<?x?x?x?xf32>
+// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[UBI]] step %[[C1]] {
+// CHECK: scf.for %[[J:.*]] = %[[C0]] to %[[UBJ]] step %[[C1]] {
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[UBK]] step %[[C1]] {
+// CHECK: scf.for %[[L:.*]] = %[[C0]] to %[[UBL]] step %[[C1]] {
+// CHECK-DAG: %[[FLOORK:.*]] = affine.apply #[[$MAP_FLOORK]](%[[K]])[%[[TILE]]]
+// CHECK-DAG: %[[FLOORL:.*]] = affine.apply #[[$MAP_FLOORL]](%[[L]])
+// CHECK-DAG: %[[MODK:.*]] = affine.apply #[[$MAP_MODK]](%[[K]])[%[[TILE]]]
+// CHECK-DAG: %[[MODL:.*]] = affine.apply #[[$MAP_MODL]](%[[L]])
+// CHECK: %[[VAL:.*]] = memref.load %[[SRC]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<?x?x?x?x8x?xf32>
+// CHECK: memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<?x?x?x?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
>From 919ab63302134bb646e0b2f50e483f2753fb9ec0 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 25 Feb 2026 22:35:48 -0800
Subject: [PATCH 3/3] address comments
Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
.../lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index aae2e134108c3..fd9c8a7a8eba7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -1038,8 +1038,8 @@ struct PackOpTiling
assert(packOp.hasPureBufferSemantics() &&
"expected operation to have buffer semantics");
OpBuilder::InsertionGuard g(builder);
- // The `ivs` already represent the position into the output tensor for the
- // non data-tile dimensions.
+ // The `ivs` already represent the position into the output for the non
+ // data-tile dimensions.
SmallVector<Value> ivVec(ivs);
// Get output shape - for memrefs, get dimensions from dest directly.
@@ -1482,9 +1482,9 @@ struct UnPackOpTiling
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
unpackOp.getDimAndTileMapping();
- // untiled loops and tile loops induction variables.
+ // Untiled loops and tile loops induction variables.
SmallVector<Value> inputIvs;
- // point loops induction variables.
+ // Point loops induction variables.
SmallVector<Value> inputIvsPointLoops;
inputIvs.reserve(unpackOp.getDestRank());
inputIvsPointLoops.reserve(dimAndTileMapping.size());
More information about the Mlir-commits
mailing list