[Mlir-commits] [mlir] [mlir][linalg] Upstream PackOp/UnPackOp's generateScalarImplementation. (PR #182838)

Han-Chung Wang llvmlistbot at llvm.org
Wed Feb 25 22:49:08 PST 2026


https://github.com/hanhanW updated https://github.com/llvm/llvm-project/pull/182838

>From 457edcb6e06a1e9b0a0f12369759c30784ba9718 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 03:06:29 -0800
Subject: [PATCH 1/3] [mlir][linalg] Upstream PackOp/UnPackOp's
 generateScalarImplementation.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  10 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 228 +++++++++++++-
 .../lower-to-loops-using-interface.mlir       | 277 ++++++++++++++++++
 3 files changed, 508 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 685096de5fbea..2c89dcbadf9d3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5041,8 +5041,10 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
-  reifiedReturnShapes[0] =
-      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+  for (auto dim : llvm::seq<int64_t>(0, destRank)) {
+    reifiedReturnShapes[0][dim] =
+        createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
+  }
   return success();
 }
 
@@ -5440,8 +5442,6 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
 LogicalResult
 PackOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  if (!hasPureTensorSemantics())
-    return failure();
   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
 }
 
@@ -6164,8 +6164,6 @@ void UnPackOp::print(OpAsmPrinter &p) {
 LogicalResult
 UnPackOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  if (!hasPureTensorSemantics())
-    return failure();
   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 292bfa70441fa..92d72c72b6840 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -725,7 +726,7 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
   OpFoldResult zero = builder.getIndexAttr(0);
   OpFoldResult one = builder.getIndexAttr(1);
   ReifiedRankedShapedTypeDims resultShape;
-  (void)reifyResultShapes(builder, op, resultShape);
+  (void)op.reifyResultShapes(builder, resultShape);
   SmallVector<Range> loopBounds(rank);
   for (auto dim : llvm::seq<int64_t>(0, rank)) {
     loopBounds[dim].offset = zero;
@@ -744,6 +745,129 @@ static void applyPermToRange(SmallVector<OpFoldResult> &offsets,
   applyPermutationToVector<OpFoldResult>(sizes, permutation);
 }
 
+/// For each dim in `[0, rank)` that occurs in `dimsPos`, visited in increasing
+/// order, returns the index at which it occurs in `dimsPos`. For example,
+/// `dimsPos = [2, 0]` yields `[1, 0]`. Dims outside `[0, rank)` are ignored.
+static SmallVector<int64_t>
+computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
+  SmallVector<int64_t> interchangeVector;
+  interchangeVector.reserve(dimsPos.size());
+  // First map each dim to its index in `dimsPos`. For example,
+  // `dimsPos = [2, 0]` maps to:
+  // [
+  //  [ key: 2, value: 0]
+  //  [ key: 0, value: 1]
+  // ]
+  // where the key is a dim listed in `dimsPos` and the value is its index.
+  DenseMap<int64_t, int64_t> dimsAndPosMapping;
+  for (auto [dimsIdx, dim] : llvm::enumerate(dimsPos))
+    dimsAndPosMapping[dim] = dimsIdx;
+
+  // Visit the dims in increasing order and emit the index of each mapped dim
+  // to form the interchange vector (single hash lookup per dim).
+  for (int64_t dim = 0; dim < rank; ++dim) {
+    if (auto it = dimsAndPosMapping.find(dim); it != dimsAndPosMapping.end())
+      interchangeVector.push_back(it->second);
+  }
+  return interchangeVector;
+}
+
+/// Returns a vector that interchanges `elements` starting at offset `offset`
+/// based on the indexes in `interchangeVector`.
+template <typename T>
+static SmallVector<T> interchange(ArrayRef<T> elements,
+                                  ArrayRef<int64_t> interchangeVector,
+                                  int offset = 0) {
+  SmallVector<T> vec = llvm::to_vector(elements); // Prefix stays as-is.
+  for (auto [idx, val] : llvm::enumerate(interchangeVector))
+    vec[idx + offset] = elements[val + offset]; // Gather, not scatter.
+  return vec;
+}
+
+/// Emits the innermost-loop body of the scalar `pack` lowering: copies one
+/// source element (or the padding value) into `dest` at position `ivs`.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+                                                   OpBuilder &builder,
+                                                   Location loc,
+                                                   ValueRange ivs) {
+  // Note: `ivs` are already in the correct order, possibly interchanged based
+  // on `dims_pos`. However, connecting the loops with the access patterns is
+  // difficult - What is the relation between the position of the tile loop and
+  // the point loop? However, if we interchange `ivs` once more to go to the
+  // canonical blocking format: ABCabc, this connection becomes trivial: Each
+  // point loop is pointLoopsOffset + inputRank away from the tiled loop.
+  ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
+
+  SmallVector<Value> interchangedIvs = ivs;
+  SmallVector<int64_t> interchangeVector =
+      computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
+  interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
+                                       /*offset=*/packOp.getSourceRank());
+  if (!dimsToOuterBlock.empty()) {
+    interchangeVector =
+        computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
+    interchangedIvs =
+        interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
+  }
+
+  SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      packOp.getDimAndTileMapping();
+  SmallVector<OpFoldResult> sourceIndices;
+  size_t pointLoopsOffset = 0;
+  int64_t sourceRank = packOp.getSourceRank();
+  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+    if (dimAndTileMapping.count(dim)) {
+      AffineExpr i, j, tile;
+      bindDims(builder.getContext(), i, j);
+      bindSymbols(builder.getContext(), tile);
+      OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
+          builder, loc, i * tile + j,
+          ArrayRef<OpFoldResult>{interchangedIvs[dim],
+                                 interchangedIvs[pointLoopsOffset + sourceRank],
+                                 dimAndTileMapping.lookup(dim)});
+      sourceIndices.push_back(sourceIndex);
+      ++pointLoopsOffset;
+    } else {
+      sourceIndices.push_back(interchangedIvs[dim]);
+    }
+  }
+  SmallVector<Value> sourceIndexValues =
+      getValueOrCreateConstantIndexOp(builder, loc, sourceIndices);
+  Value paddingValue = packOp.getPaddingValue();
+  Value isInBounds;
+  if (paddingValue) {
+    ArithBuilder arithBuilder(builder, loc);
+    for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
+      if (!dimAndTileMapping.contains(dim))
+        continue; // Untiled dims are never out of bounds.
+      Value dimSize = createOrFoldDimOp(builder, loc, packOp.getSource(), dim);
+      Value cond = arithBuilder.slt(sourceIndexValues[dim], dimSize);
+      isInBounds = isInBounds ? arithBuilder._and(isInBounds, cond) : cond;
+    }
+  }
+  auto createLoad = [&](OpBuilder &b, Location l) -> Value {
+    return memref::LoadOp::create(b, l, packOp.getSource(), sourceIndexValues);
+  };
+  Value scalar;
+  if (isInBounds) {
+    scalar = scf::IfOp::create(
+                 builder, loc, isInBounds,
+                 [&](OpBuilder &b, Location l) {
+                   scf::YieldOp::create(b, l, createLoad(b, l));
+                 },
+                 [&](OpBuilder &b, Location l) {
+                   scf::YieldOp::create(b, l, paddingValue);
+                 })
+                 .getResult(0);
+  } else {
+    scalar = createLoad(builder, loc);
+  }
+
+  memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
+}
+
 struct PackOpTiling
     : public TilingInterface::ExternalModel<PackOpTiling, linalg::PackOp> {
 
@@ -899,6 +1023,52 @@ struct PackOpTiling
     return tilingResult.value();
   }
 
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto packOp = cast<PackOp>(op);
+    assert(packOp.hasPureBufferSemantics() &&
+           "expected operation to have buffer semantics");
+    OpBuilder::InsertionGuard g(builder);
+    // The `ivs` already represent the position into the output tensor for the
+    // non data-tile dimensions.
+    SmallVector<Value> ivVec = llvm::to_vector(ivs);
+
+    // Get output shape - for memrefs, get dimensions from dest directly.
+    SmallVector<OpFoldResult> outputShape;
+    Value dest = packOp.getDest();
+    for (auto dim : llvm::seq<int64_t>(0, packOp.getDestRank()))
+      outputShape.push_back(createOrFoldDimOp(builder, loc, dest, dim));
+    // `ivs` cover dims [0, sourceRank); loop only over the data-tile dims.
+    ArrayRef<OpFoldResult> tileShape = outputShape;
+    tileShape = tileShape.drop_front(packOp.getSourceRank());
+    if (tileShape.empty()) {
+      generatePackOpScalarImplementationBody(packOp, builder, loc, ivVec);
+      return success();
+    }
+    // Data-tile loops; all but the innermost only add an induction variable.
+    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+    for (OpFoldResult tileSize : tileShape.drop_back()) {
+      Value ub = getValueOrCreateConstantIndexOp(builder, loc, tileSize);
+      scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
+      builder.setInsertionPointToStart(loop.getBody());
+      ivVec.push_back(loop.getInductionVar());
+    }
+    // The body of the innermost loop does the actual data movement.
+    scf::ForOp::create(
+        builder, loc, zero,
+        getValueOrCreateConstantIndexOp(builder, loc, tileShape.back()), one,
+        ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, ValueRange) {
+          ivVec.push_back(iv);
+          generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
+                                                 ivVec);
+          scf::YieldOp::create(bodyBuilder, bodyLoc);
+        });
+    return success();
+  }
+
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
   /// `resultSizes` only cover outer dimensions.
@@ -1292,6 +1462,62 @@ struct UnPackOpTiling
     return tilingResult.value();
   }
 
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto unpackOp = cast<UnPackOp>(op);
+    assert(unpackOp.hasPureBufferSemantics() &&
+           "expected operation to have buffer semantics");
+    assert(static_cast<int64_t>(ivs.size()) == unpackOp.getDestRank() &&
+           "number of ivs must match the rank of the output tensor");
+    OpBuilder::InsertionGuard g(builder);
+
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        unpackOp.getDimAndTileMapping();
+    // Untiled-loop and tile-loop ivs, collected in dest dim order.
+    SmallVector<Value> inputIvs;
+    // Point-loop ivs, collected in increasing tiled-dim order.
+    SmallVector<Value> inputIvsPointLoops;
+    inputIvs.reserve(unpackOp.getDestRank());
+    inputIvsPointLoops.reserve(dimAndTileMapping.size());
+    for (auto dim : llvm::seq<int64_t>(0, unpackOp.getDestRank())) {
+      if (dimAndTileMapping.contains(dim)) {
+        affine::DivModValue divMod =
+            affine::getDivMod(builder, loc, ivs[dim],
+                              getValueOrCreateConstantIndexOp(
+                                  builder, loc, dimAndTileMapping.lookup(dim)));
+        inputIvsPointLoops.push_back(divMod.remainder);
+        inputIvs.push_back(divMod.quotient);
+      } else {
+        inputIvs.push_back(ivs[dim]);
+      }
+    }
+
+    // TODO: (lorenzo) simplify the logic a bit. There is `ivs`,
+    // `inputIvsPointLoops` and `inputIvs`.
+    assert(inputIvsPointLoops.size() + inputIvs.size() ==
+               unpackOp.getSourceRank() &&
+           "expect same number of induction variables equals to input rank");
+    // Point-loop ivs are in increasing dim order; the source lays them out in
+    // `inner_dims_pos` order, which needs the inverse of the dim->pos map.
+    ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
+    SmallVector<int64_t> interchangeVector = invertPermutationVector(
+        computeInterchangeFromDimPos(innerDims, unpackOp.getDestRank()));
+    SmallVector<Value> interchangedInputIvsPointLoops = interchange<Value>(
+        inputIvsPointLoops, interchangeVector, /*offset=*/0);
+    // interchange the tiled loops induction variables based on
+    // `outer_dims_perm`.
+    ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
+    if (!outerDims.empty())
+      inputIvs = interchange<Value>(inputIvs, outerDims, /*offset=*/0);
+
+    llvm::append_range(inputIvs, interchangedInputIvsPointLoops);
+    Value scalar =
+        memref::LoadOp::create(builder, loc, unpackOp.getSource(), inputIvs);
+    memref::StoreOp::create(builder, loc, scalar, unpackOp.getDest(), ivs);
+    return success();
+  }
+
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation.
   LogicalResult getIterationDomainTileFromOperandTiles(
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index aa8882d21698c..7fe37a659f5a5 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -400,3 +400,280 @@ module attributes {transform.with_named_sequence} {
 // CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
 // CHECK:           %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
 // CHECK:           memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]
+
+// -----
+// Pack with identity inner_dims_pos (no outer_dims_perm, no padding).
+func.func @NC_to_NCnc(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1
+    : memref<128x256xf32> -> memref<4x8x32x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %pack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-LABEL: func @NC_to_NCnc(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[L:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[SRC_I:.*]] = affine.apply #[[$MAP]](%[[I]], %[[K]])
+// CHECK-DAG:             %[[SRC_J:.*]] = affine.apply #[[$MAP]](%[[J]], %[[L]])
+// CHECK:                 %[[VAL:.*]] = memref.load %[[SRC]][%[[SRC_I]], %[[SRC_J]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x8x32x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+// Pack with padding: positions past the source bounds take the padding value.
+func.func @NC_to_NCnc_pad(%arg0: memref<13x15xf32>, %arg1: memref<2x8x8x2xf32>, %arg2: f32) {
+  linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+    into %arg1 : memref<13x15xf32> -> memref<2x8x8x2xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %pack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG:   #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+// CHECK-LABEL: func @NC_to_NCnc_pad(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<13x15xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<2x8x8x2xf32>
+// CHECK-SAME:    %[[PAD:[a-zA-Z0-9]+]]: f32
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C13:.*]] = arith.constant 13 : index
+// CHECK-DAG:     %[[C15:.*]] = arith.constant 15 : index
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[K:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:               scf.for %[[L:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-DAG:             %[[SRC_I:.*]] = affine.apply #[[$MAP0]](%[[I]], %[[K]])
+// CHECK-DAG:             %[[SRC_J:.*]] = affine.apply #[[$MAP1]](%[[J]], %[[L]])
+// CHECK:                 %[[BOUND_I:.*]] = arith.cmpi slt, %[[SRC_I]], %[[C13]] : index
+// CHECK:                 %[[BOUND_J:.*]] = arith.cmpi slt, %[[SRC_J]], %[[C15]] : index
+// CHECK:                 %[[IN_BOUNDS:.*]] = arith.andi %[[BOUND_I]], %[[BOUND_J]] : i1
+// CHECK:                 %[[VAL:.*]] = scf.if %[[IN_BOUNDS]] -> (f32) {
+// CHECK:                   %[[LOAD:.*]] = memref.load %[[SRC]][%[[SRC_I]], %[[SRC_J]]] : memref<13x15xf32>
+// CHECK:                   scf.yield %[[LOAD]]
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[PAD]]
+// CHECK:                 }
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<2x8x8x2xf32>
+
+// -----
+// Pack with inner_dims_pos = [1, 0]: tiles are laid out in reverse dim order.
+func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1
+    : memref<128x256xf32> -> memref<4x8x32x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %pack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[$MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-LABEL: func @KC_to_KCck(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[SRC_C:.*]] = affine.apply #[[$MAP]](%[[C]], %[[c]])
+// CHECK-DAG:             %[[SRC_K:.*]] = affine.apply #[[$MAP]](%[[K]], %[[k]])
+// CHECK:                 %[[VAL:.*]] = memref.load %[[SRC]][%[[SRC_K]], %[[SRC_C]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[K]], %[[C]], %[[c]], %[[k]]] : memref<4x8x32x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+// Unpack with identity inner_dims_pos: source indices are floordiv/mod of ivs.
+func.func @NCnc_to_NC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.unpack %arg1 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg0
+    : memref<4x8x32x32xf32> -> memref<128x256xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[$MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-LABEL: func @NCnc_to_NC(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C256:.*]] = arith.constant 256 : index
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK-DAG:         %[[FLOOR_I:.*]] = affine.apply #[[$MAP_FLOOR]](%[[I]])
+// CHECK-DAG:         %[[FLOOR_J:.*]] = affine.apply #[[$MAP_FLOOR]](%[[J]])
+// CHECK-DAG:         %[[MOD_I:.*]] = affine.apply #[[$MAP_MOD]](%[[I]])
+// CHECK-DAG:         %[[MOD_J:.*]] = affine.apply #[[$MAP_MOD]](%[[J]])
+// CHECK:             %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_I]], %[[FLOOR_J]], %[[MOD_I]], %[[MOD_J]]] : memref<4x8x32x32xf32>
+// CHECK:             memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]]] : memref<128x256xf32>
+// CHECK:           }
+// CHECK:         }
+
+// -----
+// Unpack with inner_dims_pos = [1, 0]: the two point indices are swapped.
+func.func @KCck_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.unpack %arg1 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg0
+    : memref<4x8x32x32xf32> -> memref<128x256xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP_FLOOR:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[$MAP_MOD:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-LABEL: func @KCck_to_KC(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<4x8x32x32xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C256:.*]] = arith.constant 256 : index
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK-DAG:         %[[FLOOR_I:.*]] = affine.apply #[[$MAP_FLOOR]](%[[I]])
+// CHECK-DAG:         %[[FLOOR_J:.*]] = affine.apply #[[$MAP_FLOOR]](%[[J]])
+// CHECK-DAG:         %[[MOD_I:.*]] = affine.apply #[[$MAP_MOD]](%[[I]])
+// CHECK-DAG:         %[[MOD_J:.*]] = affine.apply #[[$MAP_MOD]](%[[J]])
+// CHECK:             %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_I]], %[[FLOOR_J]], %[[MOD_J]], %[[MOD_I]]] : memref<4x8x32x32xf32>
+// CHECK:             memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]]] : memref<128x256xf32>
+// CHECK:           }
+// CHECK:         }
+
+// -----
+// Pack with outer_dims_perm = [1, 0]: the tile loops are transposed.
+func.func @KC_to_CKkc(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) {
+  linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+    into %arg1 : memref<128x256xf32> -> memref<32x4x32x8xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %pack = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %pack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG:   #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-LABEL: func @KC_to_CKkc(
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[C:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:           scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:             scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[c:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK-DAG:             %[[SRC_K:.*]] = affine.apply #[[$MAP0]](%[[K]], %[[k]])
+// CHECK-DAG:             %[[SRC_C:.*]] = affine.apply #[[$MAP1]](%[[C]], %[[c]])
+// CHECK:                 %[[VAL:.*]] = memref.load %[[SRC]][%[[SRC_K]], %[[SRC_C]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[C]], %[[K]], %[[k]], %[[c]]] : memref<32x4x32x8xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+// Unpack with outer_dims_perm = [1, 0]: the tile indices are transposed.
+func.func @CKkc_to_KC(%arg0: memref<128x256xf32>, %arg1: memref<32x4x32x8xf32>) {
+  linalg.unpack %arg1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 8]
+    into %arg0 : memref<32x4x32x8xf32> -> memref<128x256xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP0:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[$MAP1:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-DAG:   #[[$MAP2:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG:   #[[$MAP3:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @CKkc_to_KC(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<128x256xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<32x4x32x8xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C256:.*]] = arith.constant 256 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK-DAG:         %[[FLOOR_K:.*]] = affine.apply #[[$MAP0]](%[[K]])
+// CHECK-DAG:         %[[MOD_K:.*]] = affine.apply #[[$MAP1]](%[[K]])
+// CHECK-DAG:         %[[FLOOR_C:.*]] = affine.apply #[[$MAP2]](%[[C]])
+// CHECK-DAG:         %[[MOD_C:.*]] = affine.apply #[[$MAP3]](%[[C]])
+// CHECK:             %[[VAL:.*]] = memref.load %[[SRC]][%[[FLOOR_C]], %[[FLOOR_K]], %[[MOD_K]], %[[MOD_C]]] : memref<32x4x32x8xf32>
+// CHECK:             memref.store %[[VAL]], %[[DEST]][%[[K]], %[[C]]] : memref<128x256xf32>
+// CHECK:           }
+// CHECK:         }

>From 7d7edb45e9a1e9639942be662f449baff6209365 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Mon, 23 Feb 2026 03:42:39 -0800
Subject: [PATCH 2/3] Improve code quality and upstream more tests.

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |   3 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  24 ++--
 .../lower-to-loops-using-interface.mlir       | 134 ++++++++++++++++++
 3 files changed, 151 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2c89dcbadf9d3..bfc03cc7436df 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5041,10 +5041,9 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
-  for (auto dim : llvm::seq<int64_t>(0, destRank)) {
+  for (auto dim : llvm::seq<int64_t>(0, destRank))
     reifiedReturnShapes[0][dim] =
         createFoldedDimOp(builder, op.getLoc(), op.getDest(), dim);
-  }
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 92d72c72b6840..aae2e134108c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -772,8 +772,18 @@ computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos, int64_t rank) {
   return interchangeVector;
 }
 
-/// Returns a vector that interchanges `elements` starting at offset `offset`
-/// based on the indexes in `interchangeVector`.
+/// Permute the elements of `vec` starting at position `offset` according to
+/// `interchangeVector`. The permutation maps position `i` in the permuted range
+/// to position `interchangeVector[i]` in the original range. Elements before
+/// `offset` are unchanged.
+///
+/// Example: interchange([a, b, c, d, e], [2, 0, 1], offset=2)
+///          returns [a, b, e, c, d] (permutes the suffix [c, d, e])
+///
+/// Note: This is similar to `applyPermutationToVector` but supports an offset
+/// for permuting a suffix of the vector. It is only used for pack/unpack scalar
+/// implementation where we need to permute inner tile dimensions which are
+/// stored at the end of the index vector.
 template <typename T>
 static SmallVector<T> interchange(ArrayRef<T> elements,
                                   ArrayRef<int64_t> interchangeVector,
@@ -810,15 +820,13 @@ static void generatePackOpScalarImplementationBody(PackOp packOp,
     interchangedIvs =
         interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
   }
-
-  SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
   SmallVector<OpFoldResult> sourceIndices;
   size_t pointLoopsOffset = 0;
   int64_t sourceRank = packOp.getSourceRank();
   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
-    if (dimAndTileMapping.count(dim)) {
+    if (dimAndTileMapping.contains(dim)) {
       AffineExpr i, j, tile;
       bindDims(builder.getContext(), i, j);
       bindSymbols(builder.getContext(), tile);
@@ -1032,7 +1040,7 @@ struct PackOpTiling
     OpBuilder::InsertionGuard g(builder);
     // The `ivs` already represent the position into the output tensor for the
     // non data-tile dimensions.
-    SmallVector<Value> ivVec = llvm::to_vector(ivs);
+    SmallVector<Value> ivVec(ivs);
 
     // Get output shape - for memrefs, get dimensions from dest directly.
     SmallVector<OpFoldResult> outputShape;
@@ -1498,14 +1506,14 @@ struct UnPackOpTiling
     assert(inputIvsPointLoops.size() + inputIvs.size() ==
                unpackOp.getSourceRank() &&
            "expect same number of induction variables equals to input rank");
-    // interchange the point loops induction variables based on `inner_dim_pos`.
+    // Interchange the point loops induction variables based on `inner_dim_pos`.
     ArrayRef<int64_t> innerDims = unpackOp.getInnerDimsPos();
     SmallVector<int64_t> interchangeVector =
         computeInterchangeFromDimPos(innerDims, unpackOp.getDestRank());
     SmallVector<Value> interchangedInputIvsPointLoops = inputIvsPointLoops;
     interchangedInputIvsPointLoops = interchange<Value>(
         interchangedInputIvsPointLoops, interchangeVector, /*offset=*/0);
-    // interchange the tiled loops induction variables based on
+    // Interchange the tiled loops induction variables based on
     // `outer_dims_perm`.
     ArrayRef<int64_t> outerDims = unpackOp.getOuterDimsPerm();
     if (!outerDims.empty())
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index 7fe37a659f5a5..ec9b606491910 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -677,3 +677,137 @@ module attributes {transform.with_named_sequence} {
 // CHECK:             memref.store %[[VAL]], %[[DEST]][%[[K]], %[[C]]] : memref<128x256xf32>
 // CHECK:           }
 // CHECK:         }
+
+// -----
+
+func.func @KCRSsr_to_KCRS_static(%arg0: memref<1x1x128x64xf32>, %arg1: memref<1x1x4x8x8x32xf32>) {
+  linalg.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg0
+    : memref<1x1x4x8x8x32xf32> -> memref<1x1x128x64xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[$MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-DAG:   #[[$MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG:   #[[$MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @KCRSsr_to_KCRS_static(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<1x1x128x64xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<1x1x4x8x8x32xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C64:.*]] = arith.constant 64 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[L:.*]] = %[[C0]] to %[[C64]] step %[[C1]] {
+// CHECK-DAG:         %[[FLOORK:.*]] = affine.apply #[[$MAP_FLOORK]](%[[K]])
+// CHECK-DAG:         %[[FLOORL:.*]] = affine.apply #[[$MAP_FLOORL]](%[[L]])
+// CHECK-DAG:         %[[MODK:.*]] = affine.apply #[[$MAP_MODK]](%[[K]])
+// CHECK-DAG:         %[[MODL:.*]] = affine.apply #[[$MAP_MODL]](%[[L]])
+// CHECK:             %[[VAL:.*]] = memref.load %[[SRC]][%[[C0]], %[[C0]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<1x1x4x8x8x32xf32>
+// CHECK:             memref.store %[[VAL]], %[[DEST]][%[[C0]], %[[C0]], %[[K]], %[[L]]] : memref<1x1x128x64xf32>
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KCRSsr_to_KCRS_dynamic_with_static_inner_tiles(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x32xf32>) {
+  linalg.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg0
+    : memref<?x?x?x?x8x32xf32> -> memref<?x?x?x?xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP_FLOORK:.*]] = affine_map<(d0) -> (d0 floordiv 32)>
+// CHECK-DAG:   #[[$MAP_MODK:.*]] = affine_map<(d0) -> (d0 mod 32)>
+// CHECK-DAG:   #[[$MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG:   #[[$MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @KCRSsr_to_KCRS_dynamic_with_static_inner_tiles(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<?x?x?x?x8x32xf32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[UBI:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x?x?x?xf32>
+// CHECK-DAG:     %[[UBJ:.*]] = memref.dim %[[DEST]], %[[C1]] : memref<?x?x?x?xf32>
+// CHECK-DAG:     %[[UBK:.*]] = memref.dim %[[DEST]], %[[C2]] : memref<?x?x?x?xf32>
+// CHECK-DAG:     %[[UBL:.*]] = memref.dim %[[DEST]], %[[C3]] : memref<?x?x?x?xf32>
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[UBI]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[UBJ]] step %[[C1]] {
+// CHECK:             scf.for %[[K:.*]] = %[[C0]] to %[[UBK]] step %[[C1]] {
+// CHECK:               scf.for %[[L:.*]] = %[[C0]] to %[[UBL]] step %[[C1]] {
+// CHECK-DAG:             %[[FLOORK:.*]] = affine.apply #[[$MAP_FLOORK]](%[[K]])
+// CHECK-DAG:             %[[FLOORL:.*]] = affine.apply #[[$MAP_FLOORL]](%[[L]])
+// CHECK-DAG:             %[[MODK:.*]] = affine.apply #[[$MAP_MODK]](%[[K]])
+// CHECK-DAG:             %[[MODL:.*]] = affine.apply #[[$MAP_MODL]](%[[L]])
+// CHECK:                 %[[VAL:.*]] = memref.load %[[SRC]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<?x?x?x?x8x32xf32>
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<?x?x?x?xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KCRSsr_to_KCRS_dynamic_with_dynamic_inner_tiles(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x?xf32>, %block : index) {
+  linalg.unpack %arg1 inner_dims_pos = [3, 2] inner_tiles = [8, %block] into %arg0
+    : memref<?x?x?x?x8x?xf32> -> memref<?x?x?x?xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["linalg.unpack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %unpack
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[$MAP_FLOORK:.*]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
+// CHECK-DAG:   #[[$MAP_MODK:.*]] = affine_map<(d0)[s0] -> (d0 mod s0)>
+// CHECK-DAG:   #[[$MAP_FLOORL:.*]] = affine_map<(d0) -> (d0 floordiv 8)>
+// CHECK-DAG:   #[[$MAP_MODL:.*]] = affine_map<(d0) -> (d0 mod 8)>
+// CHECK-LABEL: func @KCRSsr_to_KCRS_dynamic_with_dynamic_inner_tiles(
+// CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]: memref<?x?x?x?xf32>
+// CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]: memref<?x?x?x?x8x?xf32>
+// CHECK-SAME:    %[[TILE:[a-zA-Z0-9]+]]: index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[UBI:.*]] = memref.dim %[[DEST]], %[[C0]] : memref<?x?x?x?xf32>
+// CHECK-DAG:     %[[UBJ:.*]] = memref.dim %[[DEST]], %[[C1]] : memref<?x?x?x?xf32>
+// CHECK-DAG:     %[[UBK:.*]] = memref.dim %[[DEST]], %[[C2]] : memref<?x?x?x?xf32>
+// CHECK-DAG:     %[[UBL:.*]] = memref.dim %[[DEST]], %[[C3]] : memref<?x?x?x?xf32>
+// CHECK:         scf.for %[[I:.*]] = %[[C0]] to %[[UBI]] step %[[C1]] {
+// CHECK:           scf.for %[[J:.*]] = %[[C0]] to %[[UBJ]] step %[[C1]] {
+// CHECK:             scf.for %[[K:.*]] = %[[C0]] to %[[UBK]] step %[[C1]] {
+// CHECK:               scf.for %[[L:.*]] = %[[C0]] to %[[UBL]] step %[[C1]] {
+// CHECK-DAG:             %[[FLOORK:.*]] = affine.apply #[[$MAP_FLOORK]](%[[K]])[%[[TILE]]]
+// CHECK-DAG:             %[[FLOORL:.*]] = affine.apply #[[$MAP_FLOORL]](%[[L]])
+// CHECK-DAG:             %[[MODK:.*]] = affine.apply #[[$MAP_MODK]](%[[K]])[%[[TILE]]]
+// CHECK-DAG:             %[[MODL:.*]] = affine.apply #[[$MAP_MODL]](%[[L]])
+// CHECK:                 %[[VAL:.*]] = memref.load %[[SRC]][%[[I]], %[[J]], %[[FLOORK]], %[[FLOORL]], %[[MODL]], %[[MODK]]] : memref<?x?x?x?x8x?xf32>
+// CHECK:                 memref.store %[[VAL]], %[[DEST]][%[[I]], %[[J]], %[[K]], %[[L]]] : memref<?x?x?x?xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }

>From 919ab63302134bb646e0b2f50e483f2753fb9ec0 Mon Sep 17 00:00:00 2001
From: hanhanW <hanhan0912 at gmail.com>
Date: Wed, 25 Feb 2026 22:35:48 -0800
Subject: [PATCH 3/3] address comments

Signed-off-by: hanhanW <hanhan0912 at gmail.com>
---
 .../lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index aae2e134108c3..fd9c8a7a8eba7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -1038,8 +1038,8 @@ struct PackOpTiling
     assert(packOp.hasPureBufferSemantics() &&
            "expected operation to have buffer semantics");
     OpBuilder::InsertionGuard g(builder);
-    // The `ivs` already represent the position into the output tensor for the
-    // non data-tile dimensions.
+    // The `ivs` already represent the position into the output for the non
+    // data-tile dimensions.
     SmallVector<Value> ivVec(ivs);
 
     // Get output shape - for memrefs, get dimensions from dest directly.
@@ -1482,9 +1482,9 @@ struct UnPackOpTiling
 
     DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
         unpackOp.getDimAndTileMapping();
-    // untiled loops and tile loops induction variables.
+    // Untiled loops and tile loops induction variables.
     SmallVector<Value> inputIvs;
-    // point loops induction variables.
+    // Point loops induction variables.
     SmallVector<Value> inputIvsPointLoops;
     inputIvs.reserve(unpackOp.getDestRank());
     inputIvsPointLoops.reserve(dimAndTileMapping.size());



More information about the Mlir-commits mailing list