[Mlir-commits] [mlir] [mlir][Linalg] enable scalar lowering for linalg.pack (PR #178222)

Ryutaro Okada llvmlistbot at llvm.org
Tue Jan 27 07:04:43 PST 2026


https://github.com/sakupan102 created https://github.com/llvm/llvm-project/pull/178222

This change is part of an effort to upstream IREE’s code to MLIR. It implements generateScalarImplementation for linalg.pack so that pack can lower to scalar loop code.

>From 0be82a90dd6f97a83cef5200065ad28bee7b05cb Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 27 Jan 2026 22:44:24 +0900
Subject: [PATCH] [mlir][Linalg] enable scalar lowering for linalg.pack
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This change is part of an effort to upstream IREE’s code to MLIR.
It implements generateScalarImplementation for linalg.pack so that pack can
lower to scalar loop code.

Co-authored-by: Han-Chung Wang <hanhan0912 at gmail.com>
Co-authored-by: lorenzo chelini <l.chelini at icloud.com>
Co-authored-by: Hyunsung Lee <ita9naiwa at gmail.com>
Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
 .../include/mlir/Dialect/Linalg/Utils/Utils.h |  17 +
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  12 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 138 +++++-
 mlir/lib/Dialect/Linalg/Utils/Utils.cpp       |  26 ++
 .../lower-to-loops-using-interface.mlir       | 400 ++++++++++++++++++
 5 files changed, 587 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 06e7a472a8182..c1f9db1e3530f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -44,6 +44,23 @@ SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
 SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp,
                                              PackingMetadata &metadata);
 
+/// Returns a vector that interchanges `elements` starting at offset `offset`
+/// based on the indexes in `interchangeVector`.
+template <typename T>
+SmallVector<T> interchange(ArrayRef<T> elements,
+                           ArrayRef<int64_t> interchangeVector,
+                           int offset = 0) {
+  SmallVector<T> vec = llvm::to_vector(elements);
+  for (auto [idx, val] : llvm::enumerate(interchangeVector)) {
+    vec[idx + offset] = elements[val + offset];
+  }
+  return vec;
+}
+
+/// Returns the `interchangeVector` based on `dimsPos`.
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+                                                  int64_t rank);
+
 //===----------------------------------------------------------------------===//
 // General utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 40cabe20d1a4b..1162b96a67fc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5035,8 +5035,14 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                 "applies to only pack or unpack operations");
   int64_t destRank = op.getDestRank();
   reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
-  reifiedReturnShapes[0] =
-      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+  if (op.hasPureTensorSemantics()) {
+    reifiedReturnShapes[0] =
+        tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+  }
+  if (op.hasPureBufferSemantics()) {
+    reifiedReturnShapes[0] =
+        memref::getMixedSizes(builder, op.getLoc(), op.getDest());
+  }
   return success();
 }
 
@@ -5434,8 +5440,6 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
 LogicalResult
 PackOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
-  if (!hasPureTensorSemantics())
-    return failure();
   return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5d39c4731dd1b..3ab6e741c7366 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -80,6 +80,90 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
   return success();
 }
 
+/// Generate the body of the innermost loop of the scalar implementation
+/// of `pack` operation.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+                                                   OpBuilder &builder,
+                                                   Location loc,
+                                                   ValueRange ivs) {
+  // Note: `ivs` are already in the correct order, possibly interchanged based
+  // on `dims_pos`. However, connecting the loops with the access patterns is
+  // difficult - What is the relation between the position of the tile loop
+  // and the point loop? However, if we interchange `ivs` once more to go to
+  // the canonical blocking format: ABCabc, this connection becomes trivial:
+  // Each point loop is pointLoopsOffset + inputRank away from the tiled loop.
+  ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
+  ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
+
+  SmallVector<Value> interchangedIvs = ivs;
+  SmallVector<int64_t> interchangeVector =
+      computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
+  interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
+                                       /*offset=*/packOp.getSourceRank());
+  if (!dimsToOuterBlock.empty()) {
+    interchangeVector =
+        computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
+    interchangedIvs =
+        interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
+  }
+
+  SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
+  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+      packOp.getDimAndTileMapping();
+  SmallVector<OpFoldResult> sourceIndices;
+  size_t pointLoopsOffset = 0;
+  int64_t inputRank = packOp.getSourceRank();
+  for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+    if (dimAndTileMapping.count(dim)) {
+      AffineExpr i, j, tile;
+      bindDims(builder.getContext(), i, j);
+      bindSymbols(builder.getContext(), tile);
+      OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
+          builder, loc, i * tile + j,
+          ArrayRef<OpFoldResult>{
+              interchangedIvs[dim],
+              interchangedIvs[pointLoopsOffset + packOp.getSourceRank()],
+              dimAndTileMapping[dim]});
+      sourceIndices.push_back(sourceIndex);
+      ++pointLoopsOffset;
+    } else {
+      sourceIndices.push_back(interchangedIvs[dim]);
+    }
+  }
+
+  auto createLoad = [&]() -> Value {
+    return memref::LoadOp::create(
+        builder, loc, packOp.getSource(),
+        getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
+  };
+  Value scalar;
+  if (auto paddingValue = packOp.getPaddingValue()) {
+    ArithBuilder arithBuilder(builder, loc);
+    Value isInBounds;
+    for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+      Value idx =
+          getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
+      Value cond = arithBuilder.slt(
+          idx, createOrFoldDimOp(builder, loc, packOp.getSource(), dim));
+      isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
+    }
+    scalar = scf::IfOp::create(
+                 builder, loc, isInBounds, /*thenBuilder=*/
+                 [&](OpBuilder &b, Location l) {
+                   scf::YieldOp::create(b, l, createLoad());
+                 },
+                 /*elseBuilder=*/
+                 [&](OpBuilder &b, Location l) {
+                   scf::YieldOp::create(b, l, paddingValue);
+                 })
+                 .getResult(0);
+  } else {
+    scalar = createLoad();
+  }
+
+  memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
+}
+
 //===----------------------------------------------------------------------===//
 // External Model for implementing `TilingInterface` for `LinalgOp`s.
 //===----------------------------------------------------------------------===//
@@ -725,7 +809,7 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
   OpFoldResult zero = builder.getIndexAttr(0);
   OpFoldResult one = builder.getIndexAttr(1);
   ReifiedRankedShapedTypeDims resultShape;
-  (void)reifyResultShapes(builder, op, resultShape);
+  (void)op.reifyResultShapes(builder, resultShape);
   SmallVector<Range> loopBounds(rank);
   for (auto dim : llvm::seq<int64_t>(0, rank)) {
     loopBounds[dim].offset = zero;
@@ -865,7 +949,9 @@ struct PackOpTiling
     resultOffsets.append(outputRank - inputRank, zeroAttr);
 
     ReifiedRankedShapedTypeDims outputShape;
-    (void)reifyResultShapes(b, packOp, outputShape);
+    if (failed(packOp.reifyResultShapes(b, outputShape))) {
+      return packOp.getOperation()->emitOpError("failed to reify result shape");
+    }
     resultSizes.assign(sizes.begin(), sizes.end());
     for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
       resultSizes.push_back(outputShape[0][dataTileDim]);
@@ -1058,6 +1144,54 @@ struct PackOpTiling
         SmallVector<Value>(tiledPackOp->getResults()),
         llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
   }
+
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    OpBuilder::InsertionGuard g(builder);
+    auto packOp = cast<PackOp>(op);
+    // The `ivs` already represent the position into the output tensor for the
+    // non data-tile dimensions.
+    SmallVector<Value> ivVec = llvm::to_vector(ivs);
+    ReifiedRankedShapedTypeDims outputShape;
+    if (failed(packOp.reifyResultShapes(builder, outputShape))) {
+      return packOp.getOperation()->emitOpError("failed to reify result shape");
+    }
+    if (outputShape.size() != 1 ||
+        outputShape[0].size() != packOp.getDestRank()) {
+      return packOp.getOperation()->emitOpError(
+                 "expected shape of one result value of rank")
+             << packOp.getDestRank();
+    }
+
+    // Generate the loops that iterate over the data tile.
+    Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+
+    // All loops except the innermost are simple loops that just iterate
+    // over the tile dimensions.
+    for (auto dataTileDim : llvm::seq<unsigned>(packOp.getSourceRank(),
+                                                packOp.getDestRank() - 1)) {
+      Value ub = getValueOrCreateConstantIndexOp(builder, loc,
+                                                 outputShape[0][dataTileDim]);
+      scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
+      builder.setInsertionPointToStart(loop.getBody());
+      ivVec.push_back(loop.getInductionVar());
+    }
+    // The body of the innermost loop does the actual data movement.
+    scf::ForOp::create(
+        builder, loc, zero,
+        getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()),
+        one, ValueRange{},
+        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+            ValueRange regionIterArgs) {
+          ivVec.push_back(iv);
+          generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
+                                                 ivVec);
+          scf::YieldOp::create(bodyBuilder, bodyLoc);
+        });
+    return success();
+  }
 };
 
 struct UnpackTileDimInfo {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index a1ee6b307caf5..c922ee4b7b6f0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -192,6 +192,32 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
   return unpackInvSrcPerm;
 }
 
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+                                                  int64_t rank) {
+  SmallVector<int64_t> interchangeVector;
+  interchangeVector.reserve(dimsPos.size());
+  // First map dims and their position. For example, dims_pos = [2, 0] will
+  // map to:
+  // [
+  //  [ key: 2, value: 0]
+  //  [ key: 0, value: 1]
+  // ]
+  // where key is an entry of dims_pos and value is its position in dims_pos.
+  DenseMap<int64_t, int64_t> dimsAndPosMapping;
+  for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) {
+    dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
+  }
+
+  // Scan the dimensions in order and append the mapped position of each
+  // dimension found in the map to build the interchange vector.
+  for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
+    if (dimsAndPosMapping.count(dimsIdx)) {
+      interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
+    }
+  }
+  return interchangeVector;
+}
+
 bool allIndexingsAreProjectedPermutation(LinalgOp op) {
   return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
     return m.isProjectedPermutation(/*allowZeroInResults=*/true);
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index aa8882d21698c..50c21ae721ab3 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -400,3 +400,403 @@ module attributes {transform.with_named_sequence} {
 // CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
 // CHECK:           %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
 // CHECK:           memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]
+
+// -----
+
+func.func @NC_to_NCnc(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : memref<128x256xf32> -> memref<4x8x32x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @NC_to_NCnc(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[n:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[applyMapI:.*]] = affine.apply #[[MAP]](%[[N]], %[[n]])
+// CHECK-DAG:             %[[applyMapJ:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<4x8x32x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @NC_to_NCnc_pad_static(%arg0: memref<13x15xf32>, %arg1: memref<2x8x8x2xf32>, %arg2: f32) {
+  linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : memref<13x15xf32> -> memref<2x8x8x2xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+// CHECK:       func.func @NC_to_NCnc_pad_static(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C13:.*]] = arith.constant 13 : index
+// CHECK-DAG:     %[[C15:.*]] = arith.constant 15 : index
+// CHECK:           scf.for %[[N:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK:             scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:               scf.for %[[n:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:                 scf.for %[[c:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-DAG:               %[[applyMapI:.*]] = affine.apply #[[MAP0]](%[[N]], %[[n]])
+// CHECK-DAG:               %[[applyMapJ:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]])
+// CHECK:                   %[[isIInBound:.*]] = arith.cmpi slt, %[[applyMapI]], %[[C13]] : index
+// CHECK:                   %[[isJInBound:.*]] = arith.cmpi slt, %[[applyMapJ]], %[[C15]] : index
+// CHECK:                   %[[isAllInBounds:.*]] = arith.andi %[[isIInBound]], %[[isJInBound]] : i1
+// CHECK:                   %[[scalar:.*]] = scf.if %[[isAllInBounds]] -> (f32) {
+// CHECK:                     %[[load:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<13x15xf32>
+// CHECK:                     scf.yield %[[load]]
+// CHECK:                   } else {
+// CHECK:                     scf.yield %arg2
+// CHECK:                   }
+// CHECK:                   memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<2x8x8x2xf32>
+
+// -----
+
+func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1 : memref<128x256xf32> -> memref<4x8x32x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @KC_to_KCck(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK-DAG:             %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[applyMapC]]] : memref<128x256xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]], %[[k]]] : memref<4x8x32x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+// This should be a simple expand shape.
+func.func @KC_to_KCc(%arg0: memref<128x256xf32>, %arg1: memref<128x8x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : memref<128x256xf32> -> memref<128x8x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @KC_to_KCc(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK:               %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[applyMapC]]] : memref<128x256xf32>
+// CHECK:               memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]]] : memref<128x8x32xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KC_to_KCk(%arg0: memref<128x256xf32>, %arg1: memref<4x256x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %arg1 : memref<128x256xf32> -> memref<4x256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK:       #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @KC_to_KCk(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C256:.*]] = arith.constant 256 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK:             scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK:               %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]])
+// CHECK:               %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[C]]] : memref<128x256xf32>
+// CHECK:               memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[k]]] : memref<4x256x32xf32>
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KCRS_to_KCRSck(%arg0: memref<128x64x1x1xf32>, %arg1: memref<4x8x1x1x8x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [8, 32] into %arg1 : memref<128x64x1x1xf32> -> memref<4x8x1x1x8x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK:       func.func @KCRS_to_KCRSck(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[c:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:               scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[affineMapK:.*]] = affine.apply #[[MAP0]](%[[K]], %[[k]])
+// CHECK-DAG:             %[[affineMapC:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[affineMapK]], %[[affineMapC]], %[[C0]], %[[C0]]] : memref<128x64x1x1xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[C0]], %[[C0]], %[[c]], %[[k]]] : memref<4x8x1x1x8x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KCRS_to_KCRSsr(%arg0: memref<1x1x128x64xf32>, %arg1: memref<1x1x4x8x8x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : memref<1x1x128x64xf32> -> memref<1x1x4x8x8x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK:       func.func @KCRS_to_KCRSsr(
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         scf.for %[[R:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[S:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:             scf.for %[[s:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:               scf.for %[[r:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])
+// CHECK-DAG:             %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[C0]], %[[C0]], %[[affineMapR]], %[[affineMapS]]] : memref<1x1x128x64xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[C0]], %[[C0]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<1x1x4x8x8x32xf32>
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+
+// -----
+
+// Test to check that we properly handle shuffled `inner_dims_pos` and `tiles`.
+// In this example, the dimension at position `0` (aka `128`) is tiled with a factor of `32`.
+// While the dimension at position `2` (aka `2`) is tiled with a factor of `2`.
+func.func @shuffled_dim_pos_and_tiles(%arg0: memref<128x256x2x1000xf32>, %arg1: memref<4x256x1x1000x2x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [2, 0] inner_tiles = [2, 32] into %arg1 : memref<128x256x2x1000xf32> -> memref<4x256x1x1000x2x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK:       func.func @shuffled_dim_pos_and_tiles(
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C1000:.*]] = arith.constant 1000 : index
+// CHECK-DAG:     %[[C256:.*]] = arith.constant 256 : index
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK:         scf.for %[[i:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK:           scf.for %[[j:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK:             scf.for %[[l:.*]] = %[[C0]] to %[[C1000]] step %[[C1]] {
+// CHECK:               scf.for %[[m:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK:                 scf.for %[[n:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:             %[[affineApplyZero:.*]] = affine.apply #[[MAP0]](%[[i]], %[[n]])
+// CHECK:                 %[[scalar:.*]] = memref.load %arg0[%[[affineApplyZero]], %[[j]], %[[m]], %[[l]]] : memref<128x256x2x1000xf32>
+// CHECK:                 memref.store %[[scalar]], %arg1[%[[i]], %[[j]], %[[C0]], %[[l]], %[[m]], %[[n]]] : memref<4x256x1x1000x2x32xf32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KCRS_to_KCRSsr(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x32xf32>) {
+  linalg.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : memref<?x?x?x?xf32> -> memref<?x?x?x?x8x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG:   #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG:   #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK:       func.func @KCRS_to_KCRSsr(
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG:     %[[dimZero:.*]] = memref.dim %arg1, %[[C0]] : memref<?x?x?x?x8x32xf32>
+// CHECK-DAG:     %[[dimOne:.*]] = memref.dim %arg1, %[[C1]] : memref<?x?x?x?x8x32xf32>
+// CHECK-DAG:     %[[dimTwo:.*]] = memref.dim %arg1, %[[C2]] : memref<?x?x?x?x8x32xf32>
+// CHECK-DAG:     %[[dimThree:.*]] = memref.dim %arg1, %[[C3]] : memref<?x?x?x?x8x32xf32>
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[dimZero]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[dimOne]] step %[[C1]] {
+// CHECK:             scf.for %[[R:.*]] = %[[C0]] to %[[dimTwo]] step %[[C1]] {
+// CHECK:               scf.for %[[S:.*]] = %[[C0]] to %[[dimThree]] step %[[C1]] {
+// CHECK:                 scf.for %[[s:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:                   scf.for %[[r:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG:                 %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])
+// CHECK-DAG:                 %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]])
+// CHECK:                     %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref<?x?x?x?xf32>
+// CHECK:                     memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<?x?x?x?x8x32xf32>
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+
+// -----
+
+func.func @KCRS_to_KCRSsr(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x?xf32>, %block : index) {
+  linalg.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, %block] into %arg1 : memref<?x?x?x?xf32> -> memref<?x?x?x?x8x?xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.convert_to_loops %broadcast
+      : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-DAG:  #[[MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
+// CHECK-DAG:  #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK:      func.func @KCRS_to_KCRSsr
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG:     %[[dimZero:.*]] = memref.dim %[[ARG1]], %[[C0]] : memref<?x?x?x?x8x?xf32>
+// CHECK-DAG:     %[[dimOne:.*]] = memref.dim %[[ARG1]], %[[C1]] : memref<?x?x?x?x8x?xf32>
+// CHECK-DAG:     %[[dimTwo:.*]] = memref.dim %[[ARG1]], %[[C2]] : memref<?x?x?x?x8x?xf32>
+// CHECK-DAG:     %[[dimThree:.*]] = memref.dim %[[ARG1]], %[[C3]] : memref<?x?x?x?x8x?xf32>
+// CHECK:         scf.for %[[K:.*]] = %[[C0]] to %[[dimZero]] step %[[C1]] {
+// CHECK:           scf.for %[[C:.*]] = %[[C0]] to %[[dimOne]] step %[[C1]] {
+// CHECK:             scf.for %[[R:.*]] = %[[C0]] to %[[dimTwo]] step %[[C1]] {
+// CHECK:               scf.for %[[S:.*]] = %[[C0]] to %[[dimThree]] step %[[C1]] {
+// CHECK:                 %[[dimFive:.*]] = memref.dim %[[ARG1]], %[[C5]] : memref<?x?x?x?x8x?xf32>
+// CHECK:                 scf.for %[[s:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK:                   scf.for %[[r:.*]] = %[[C0]] to %[[dimFive]] step %[[C1]] {
+// CHECK-DAG:                 %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])[%[[ARG2]]]
+// CHECK-DAG:                 %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]])
+// CHECK:                     %[[scalar:.*]] = memref.load %[[ARG0]][%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref<?x?x?x?xf32>
+// CHECK:                     memref.store %[[scalar]], %[[ARG1]][%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<?x?x?x?x8x?xf32>
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+



More information about the Mlir-commits mailing list