[Mlir-commits] [mlir] [mlir][Linalg] enable scalar lowering for linalg.pack (PR #178222)
Ryutaro Okada
llvmlistbot at llvm.org
Tue Jan 27 07:04:43 PST 2026
https://github.com/sakupan102 created https://github.com/llvm/llvm-project/pull/178222
This change is part of an effort to upstream IREE’s code to MLIR. It implements generateScalarImplementation for linalg.pack so that pack can lower to scalar loop code.
>From 0be82a90dd6f97a83cef5200065ad28bee7b05cb Mon Sep 17 00:00:00 2001
From: Ryutaro Okada <1015ryu88 at gmail.com>
Date: Tue, 27 Jan 2026 22:44:24 +0900
Subject: [PATCH] [mlir][Linalg] enable scalar lowering for linalg.pack
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
This change is part of an effort to upstream IREE’s code to MLIR.
It implements generateScalarImplementation for linalg.pack so that pack can
lower to scalar loop code.
Co-authored-by: Han-Chung Wang <hanhan0912 at gmail.com>
Co-authored-by: lorenzo chelini <l.chelini at icloud.com>
Co-authored-by: Hyunsung Lee <ita9naiwa at gmail.com>
Signed-off-by: Ryutaro Okada <1015ryu88 at gmail.com>
---
.../include/mlir/Dialect/Linalg/Utils/Utils.h | 17 +
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 12 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 138 +++++-
mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 26 ++
.../lower-to-loops-using-interface.mlir | 400 ++++++++++++++++++
5 files changed, 587 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index 06e7a472a8182..c1f9db1e3530f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -44,6 +44,23 @@ SmallVector<int64_t> getPackInverseDestPerm(linalg::PackOp packOp,
SmallVector<int64_t> getUnPackInverseSrcPerm(linalg::UnPackOp,
PackingMetadata &metadata);
+/// Returns a copy of `elements` in which, for each `idx`, position
+/// `idx + offset` holds `elements[interchangeVector[idx] + offset]`.
+template <typename T>
+SmallVector<T> interchange(ArrayRef<T> elements,
+ ArrayRef<int64_t> interchangeVector,
+ int offset = 0) {
+ SmallVector<T> vec = llvm::to_vector(elements);
+ for (auto [idx, val] : llvm::enumerate(interchangeVector)) {
+ vec[idx + offset] = elements[val + offset];
+ }
+ return vec;
+}
+
+/// Returns the index in `dimsPos` of each dim in [0, rank) that it contains.
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+ int64_t rank);
+
//===----------------------------------------------------------------------===//
// General utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 40cabe20d1a4b..1162b96a67fc7 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5035,8 +5035,14 @@ reifyResultShapesImpl(OpTy op, OpBuilder &builder,
"applies to only pack or unpack operations");
int64_t destRank = op.getDestRank();
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
- reifiedReturnShapes[0] =
- tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+ if (op.hasPureTensorSemantics()) {
+ reifiedReturnShapes[0] =
+ tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
+ }
+ if (op.hasPureBufferSemantics()) {
+ reifiedReturnShapes[0] =
+ memref::getMixedSizes(builder, op.getLoc(), op.getDest());
+ }
return success();
}
@@ -5434,8 +5440,6 @@ void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
- if (!hasPureTensorSemantics())
- return failure();
return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 5d39c4731dd1b..3ab6e741c7366 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -80,6 +80,90 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
return success();
}
+/// Emits the innermost-loop body of the scalar lowering of `pack`: loads (or
+/// pads) one source element and stores it to the dest memref at `ivs`.
+static void generatePackOpScalarImplementationBody(PackOp packOp,
+ OpBuilder &builder,
+ Location loc,
+ ValueRange ivs) {
+ // Note: `ivs` are already in destination order, possibly interchanged based
+ // on `inner_dims_pos`/`outer_dims_perm`, which makes it hard to tell which
+ // point loop belongs to which tile loop. Interchanging `ivs` once more into
+ // the canonical blocking format ABCabc makes the pairing trivial: the point
+ // loop of the k-th tiled source dimension (in increasing dimension order)
+ // is at position `inputRank + k`, where k is tracked by `pointLoopsOffset`.
+ ArrayRef<int64_t> dimsToInnerBlock = packOp.getInnerDimsPos();
+ ArrayRef<int64_t> dimsToOuterBlock = packOp.getOuterDimsPerm();
+
+ SmallVector<Value> interchangedIvs = ivs;
+ SmallVector<int64_t> interchangeVector =
+ computeInterchangeFromDimPos(dimsToInnerBlock, packOp.getSourceRank());
+ interchangedIvs = interchange<Value>(interchangedIvs, interchangeVector,
+ /*offset=*/packOp.getSourceRank());
+ if (!dimsToOuterBlock.empty()) {
+ interchangeVector =
+ computeInterchangeFromDimPos(dimsToOuterBlock, packOp.getSourceRank());
+ interchangedIvs =
+ interchange<Value>(interchangedIvs, interchangeVector, /*offset=*/0);
+ }
+
+ SmallVector<OpFoldResult> tiles = packOp.getMixedTiles();
+ DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+ packOp.getDimAndTileMapping();
+ SmallVector<OpFoldResult> sourceIndices;
+ size_t pointLoopsOffset = 0;
+ int64_t inputRank = packOp.getSourceRank();
+ for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+ if (dimAndTileMapping.count(dim)) {
+ AffineExpr i, j, tile;
+ bindDims(builder.getContext(), i, j);
+ bindSymbols(builder.getContext(), tile);
+ OpFoldResult sourceIndex = affine::makeComposedFoldedAffineApply(
+ builder, loc, i * tile + j,
+ ArrayRef<OpFoldResult>{
+ interchangedIvs[dim],
+ interchangedIvs[pointLoopsOffset + packOp.getSourceRank()],
+ dimAndTileMapping[dim]});
+ sourceIndices.push_back(sourceIndex);
+ ++pointLoopsOffset;
+ } else {
+ sourceIndices.push_back(interchangedIvs[dim]);
+ }
+ }
+
+ auto createLoad = [&]() -> Value {
+ return memref::LoadOp::create(
+ builder, loc, packOp.getSource(),
+ getValueOrCreateConstantIndexOp(builder, loc, sourceIndices));
+ };
+ Value scalar;
+ if (auto paddingValue = packOp.getPaddingValue()) {
+ ArithBuilder arithBuilder(builder, loc);
+ Value isInBounds;
+ for (auto dim : llvm::seq<int64_t>(0, inputRank)) {
+ Value idx =
+ getValueOrCreateConstantIndexOp(builder, loc, sourceIndices[dim]);
+ Value cond = arithBuilder.slt(
+ idx, createOrFoldDimOp(builder, loc, packOp.getSource(), dim));
+ isInBounds = dim == 0 ? cond : arithBuilder._and(isInBounds, cond);
+ }
+ scalar = scf::IfOp::create(
+ builder, loc, isInBounds, /*thenBuilder=*/
+ [&](OpBuilder &b, Location l) {
+ scf::YieldOp::create(b, l, createLoad());
+ },
+ /*elseBuilder=*/
+ [&](OpBuilder &b, Location l) {
+ scf::YieldOp::create(b, l, paddingValue);
+ })
+ .getResult(0);
+ } else {
+ scalar = createLoad();
+ }
+
+ memref::StoreOp::create(builder, loc, scalar, packOp.getDest(), ivs);
+}
+
//===----------------------------------------------------------------------===//
// External Model for implementing `TilingInterface` for `LinalgOp`s.
//===----------------------------------------------------------------------===//
@@ -725,7 +809,7 @@ static SmallVector<Range> getPackUnPackIterationDomain(OpTy op,
OpFoldResult zero = builder.getIndexAttr(0);
OpFoldResult one = builder.getIndexAttr(1);
ReifiedRankedShapedTypeDims resultShape;
- (void)reifyResultShapes(builder, op, resultShape);
+ (void)op.reifyResultShapes(builder, resultShape);
SmallVector<Range> loopBounds(rank);
for (auto dim : llvm::seq<int64_t>(0, rank)) {
loopBounds[dim].offset = zero;
@@ -865,7 +949,9 @@ struct PackOpTiling
resultOffsets.append(outputRank - inputRank, zeroAttr);
ReifiedRankedShapedTypeDims outputShape;
- (void)reifyResultShapes(b, packOp, outputShape);
+ if (failed(packOp.reifyResultShapes(b, outputShape))) {
+ return packOp.getOperation()->emitOpError("failed to reify result shape");
+ }
resultSizes.assign(sizes.begin(), sizes.end());
for (auto dataTileDim : llvm::seq<unsigned>(inputRank, outputRank))
resultSizes.push_back(outputShape[0][dataTileDim]);
@@ -1058,6 +1144,54 @@ struct PackOpTiling
SmallVector<Value>(tiledPackOp->getResults()),
llvm::to_vector(ArrayRef<Operation *>{sourceSlice, outSlice})};
}
+
+ LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+ Location loc,
+ ValueRange ivs) const {
+ OpBuilder::InsertionGuard g(builder);
+ auto packOp = cast<PackOp>(op);
+ // Emits the loops over the inner data-tile dims of `pack`; `ivs` already
+ // give the position in the dest for the non data-tile dimensions.
+ SmallVector<Value> ivVec = llvm::to_vector(ivs);
+ ReifiedRankedShapedTypeDims outputShape;
+ if (failed(packOp.reifyResultShapes(builder, outputShape))) {
+ return packOp.getOperation()->emitOpError("failed to reify result shape");
+ }
+ if (outputShape.size() != 1 ||
+ outputShape[0].size() != packOp.getDestRank()) {
+ return packOp.getOperation()->emitOpError(
+ "expected shape of one result value of rank ")
+ << packOp.getDestRank();
+ }
+
+ // Generate the loops that iterate over the data tile.
+ Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
+ Value one = arith::ConstantIndexOp::create(builder, loc, 1);
+
+ // All loops except the innermost are simple loops that just iterate
+ // over the tile dimensions.
+ for (auto dataTileDim : llvm::seq<unsigned>(packOp.getSourceRank(),
+ packOp.getDestRank() - 1)) {
+ Value ub = getValueOrCreateConstantIndexOp(builder, loc,
+ outputShape[0][dataTileDim]);
+ scf::ForOp loop = scf::ForOp::create(builder, loc, zero, ub, one);
+ builder.setInsertionPointToStart(loop.getBody());
+ ivVec.push_back(loop.getInductionVar());
+ }
+ // The body of the innermost loop does the actual data movement.
+ scf::ForOp::create(
+ builder, loc, zero,
+ getValueOrCreateConstantIndexOp(builder, loc, outputShape[0].back()),
+ one, ValueRange{},
+ [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
+ ValueRange regionIterArgs) {
+ ivVec.push_back(iv);
+ generatePackOpScalarImplementationBody(packOp, bodyBuilder, bodyLoc,
+ ivVec);
+ scf::YieldOp::create(bodyBuilder, bodyLoc);
+ });
+ return success();
+ }
};
struct UnpackTileDimInfo {
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
index a1ee6b307caf5..c922ee4b7b6f0 100644
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -192,6 +192,32 @@ SmallVector<int64_t> getUnPackInverseSrcPerm(UnPackOp unpackOp,
return unpackInvSrcPerm;
}
+SmallVector<int64_t> computeInterchangeFromDimPos(ArrayRef<int64_t> dimsPos,
+ int64_t rank) {
+ SmallVector<int64_t> interchangeVector;
+ interchangeVector.reserve(dimsPos.size());
+ // First map each dim to its position in `dimsPos`. For example,
+ // dims_pos = [2, 0] will map to:
+ // [
+ // [ key: 2, value: 0]
+ // [ key: 0, value: 1]
+ // ]
+ // where key is a dim listed in dims_pos and value its position in dims_pos.
+ DenseMap<int64_t, int64_t> dimsAndPosMapping;
+ for (int64_t dimsIdx = 0, end = dimsPos.size(); dimsIdx < end; dimsIdx++) {
+ dimsAndPosMapping[dimsPos[dimsIdx]] = dimsIdx;
+ }
+
+ // Scan the dims in increasing order and, for each one present in the map,
+ // append its position to build the interchange vector.
+ for (int64_t dimsIdx = 0; dimsIdx < rank; dimsIdx++) {
+ if (dimsAndPosMapping.count(dimsIdx)) {
+ interchangeVector.push_back(dimsAndPosMapping[dimsIdx]);
+ }
+ }
+ return interchangeVector;
+}
+
bool allIndexingsAreProjectedPermutation(LinalgOp op) {
return llvm::all_of(op.getIndexingMapsArray(), [](AffineMap m) {
return m.isProjectedPermutation(/*allowZeroInResults=*/true);
diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
index aa8882d21698c..50c21ae721ab3 100644
--- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir
@@ -400,3 +400,403 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
// CHECK: %[[ELEM:.*]] = memref.load %[[IN]][%[[I]], %[[K]]]
// CHECK: memref.store %[[ELEM]], %[[OUT]][%[[I]], %[[J]], %[[K]]]
+
+// -----
+
+func.func @NC_to_NCnc(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg1 : memref<128x256xf32> -> memref<4x8x32x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK: func.func @NC_to_NCnc(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[N:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[n:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[applyMapI:.*]] = affine.apply #[[MAP]](%[[N]], %[[n]])
+// CHECK-DAG: %[[applyMapJ:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<128x256xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<4x8x32x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @NC_to_NCnc_pad_static(%arg0: memref<13x15xf32>, %arg1: memref<2x8x8x2xf32>, %arg2: f32) {
+ linalg.pack %arg0 padding_value(%arg2 : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %arg1 : memref<13x15xf32> -> memref<2x8x8x2xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
+// CHECK: func.func @NC_to_NCnc_pad_static(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C13:.*]] = arith.constant 13 : index
+// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
+// CHECK: scf.for %[[N:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[n:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK-DAG: %[[applyMapI:.*]] = affine.apply #[[MAP0]](%[[N]], %[[n]])
+// CHECK-DAG: %[[applyMapJ:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]])
+// CHECK: %[[isIInBound:.*]] = arith.cmpi slt, %[[applyMapI]], %[[C13]] : index
+// CHECK: %[[isJInBound:.*]] = arith.cmpi slt, %[[applyMapJ]], %[[C15]] : index
+// CHECK: %[[isAllInBounds:.*]] = arith.andi %[[isIInBound]], %[[isJInBound]] : i1
+// CHECK: %[[scalar:.*]] = scf.if %[[isAllInBounds]] -> (f32) {
+// CHECK: %[[load:.*]] = memref.load %arg0[%[[applyMapI]], %[[applyMapJ]]] : memref<13x15xf32>
+// CHECK: scf.yield %[[load]]
+// CHECK: } else {
+// CHECK: scf.yield %arg2
+// CHECK: }
+// CHECK: memref.store %[[scalar]], %arg1[%[[N]], %[[C]], %[[n]], %[[c]]] : memref<2x8x8x2xf32>
+
+// -----
+
+func.func @KC_to_KCck(%arg0: memref<128x256xf32>, %arg1: memref<4x8x32x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 32] into %arg1 : memref<128x256xf32> -> memref<4x8x32x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK: func.func @KC_to_KCck(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK-DAG: %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[applyMapC]]] : memref<128x256xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]], %[[k]]] : memref<4x8x32x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// This should be a simple expand shape.
+func.func @KC_to_KCc(%arg0: memref<128x256xf32>, %arg1: memref<128x8x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %arg1 : memref<128x256xf32> -> memref<128x8x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK: func.func @KC_to_KCc(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C128]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: %[[applyMapC:.*]] = affine.apply #[[MAP]](%[[C]], %[[c]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[applyMapC]]] : memref<128x256xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[c]]] : memref<128x8x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KC_to_KCk(%arg0: memref<128x256xf32>, %arg1: memref<4x256x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %arg1 : memref<128x256xf32> -> memref<4x256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK: func.func @KC_to_KCk(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK: scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK: %[[applyMapK:.*]] = affine.apply #[[MAP]](%[[K]], %[[k]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[applyMapK]], %[[C]]] : memref<128x256xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[k]]] : memref<4x256x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KCRS_to_KCRSck(%arg0: memref<128x64x1x1xf32>, %arg1: memref<4x8x1x1x8x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [1, 0] inner_tiles = [8, 32] into %arg1 : memref<128x64x1x1xf32> -> memref<4x8x1x1x8x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: func.func @KCRS_to_KCRSck(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[c:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[k:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[affineMapK:.*]] = affine.apply #[[MAP0]](%[[K]], %[[k]])
+// CHECK-DAG: %[[affineMapC:.*]] = affine.apply #[[MAP1]](%[[C]], %[[c]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[affineMapK]], %[[affineMapC]], %[[C0]], %[[C0]]] : memref<128x64x1x1xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[C0]], %[[C0]], %[[c]], %[[k]]] : memref<4x8x1x1x8x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KCRS_to_KCRSsr(%arg0: memref<1x1x128x64xf32>, %arg1: memref<1x1x4x8x8x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : memref<1x1x128x64xf32> -> memref<1x1x4x8x8x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: func.func @KCRS_to_KCRSsr(
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[R:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[S:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[s:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[r:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])
+// CHECK-DAG: %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[C0]], %[[C0]], %[[affineMapR]], %[[affineMapS]]] : memref<1x1x128x64xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[C0]], %[[C0]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<1x1x4x8x8x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+// Test to check that we properly handle shuffled `inner_dims_pos` and `tiles`.
+// In this example, the dimension at position `0` (aka `128`) is tiled with a factor of `32`.
+// While the dimension at position `2` (aka `2`) is tiled with a factor of `2`.
+func.func @shuffled_dim_pos_and_tiles(%arg0: memref<128x256x2x1000xf32>, %arg1: memref<4x256x1x1000x2x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [2, 0] inner_tiles = [2, 32] into %arg1 : memref<128x256x2x1000xf32> -> memref<4x256x1x1000x2x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK: func.func @shuffled_dim_pos_and_tiles(
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1000:.*]] = arith.constant 1000 : index
+// CHECK-DAG: %[[C256:.*]] = arith.constant 256 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: scf.for %[[i:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
+// CHECK: scf.for %[[j:.*]] = %[[C0]] to %[[C256]] step %[[C1]] {
+// CHECK: scf.for %[[l:.*]] = %[[C0]] to %[[C1000]] step %[[C1]] {
+// CHECK: scf.for %[[m:.*]] = %[[C0]] to %[[C2]] step %[[C1]] {
+// CHECK: scf.for %[[n:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[affineApplyZero:.*]] = affine.apply #[[MAP0]](%[[i]], %[[n]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[affineApplyZero]], %[[j]], %[[m]], %[[l]]] : memref<128x256x2x1000xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[i]], %[[j]], %[[C0]], %[[l]], %[[m]], %[[n]]] : memref<4x256x1x1000x2x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KCRS_to_KCRSsr(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x32xf32>) {
+ linalg.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : memref<?x?x?x?xf32> -> memref<?x?x?x?x8x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 32 + d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: func.func @KCRS_to_KCRSsr(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
+// CHECK-DAG: %[[dimZero:.*]] = memref.dim %arg1, %[[C0]] : memref<?x?x?x?x8x32xf32>
+// CHECK-DAG: %[[dimOne:.*]] = memref.dim %arg1, %[[C1]] : memref<?x?x?x?x8x32xf32>
+// CHECK-DAG: %[[dimTwo:.*]] = memref.dim %arg1, %[[C2]] : memref<?x?x?x?x8x32xf32>
+// CHECK-DAG: %[[dimThree:.*]] = memref.dim %arg1, %[[C3]] : memref<?x?x?x?x8x32xf32>
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[dimZero]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[dimOne]] step %[[C1]] {
+// CHECK: scf.for %[[R:.*]] = %[[C0]] to %[[dimTwo]] step %[[C1]] {
+// CHECK: scf.for %[[S:.*]] = %[[C0]] to %[[dimThree]] step %[[C1]] {
+// CHECK: scf.for %[[s:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[r:.*]] = %[[C0]] to %[[C32]] step %[[C1]] {
+// CHECK-DAG: %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])
+// CHECK-DAG: %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]])
+// CHECK: %[[scalar:.*]] = memref.load %arg0[%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref<?x?x?x?xf32>
+// CHECK: memref.store %[[scalar]], %arg1[%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<?x?x?x?x8x32xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
+// -----
+
+func.func @KCRS_to_KCRSsr(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?x8x?xf32>, %block : index) {
+ linalg.pack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, %block] into %arg1 : memref<?x?x?x?xf32> -> memref<?x?x?x?x8x?xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+ %broadcast = transform.structured.match ops{["linalg.pack"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ %0 = transform.structured.convert_to_loops %broadcast
+ : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0 * 8 + d1)>
+// CHECK: func.func @KCRS_to_KCRSsr
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[dimZero:.*]] = memref.dim %[[ARG1]], %[[C0]] : memref<?x?x?x?x8x?xf32>
+// CHECK-DAG: %[[dimOne:.*]] = memref.dim %[[ARG1]], %[[C1]] : memref<?x?x?x?x8x?xf32>
+// CHECK-DAG: %[[dimTwo:.*]] = memref.dim %[[ARG1]], %[[C2]] : memref<?x?x?x?x8x?xf32>
+// CHECK-DAG: %[[dimThree:.*]] = memref.dim %[[ARG1]], %[[C3]] : memref<?x?x?x?x8x?xf32>
+// CHECK: scf.for %[[K:.*]] = %[[C0]] to %[[dimZero]] step %[[C1]] {
+// CHECK: scf.for %[[C:.*]] = %[[C0]] to %[[dimOne]] step %[[C1]] {
+// CHECK: scf.for %[[R:.*]] = %[[C0]] to %[[dimTwo]] step %[[C1]] {
+// CHECK: scf.for %[[S:.*]] = %[[C0]] to %[[dimThree]] step %[[C1]] {
+// CHECK: %[[dimFive:.*]] = memref.dim %[[ARG1]], %[[C5]] : memref<?x?x?x?x8x?xf32>
+// CHECK: scf.for %[[s:.*]] = %[[C0]] to %[[C8]] step %[[C1]] {
+// CHECK: scf.for %[[r:.*]] = %[[C0]] to %[[dimFive]] step %[[C1]] {
+// CHECK-DAG: %[[affineMapR:.*]] = affine.apply #[[MAP0]](%[[R]], %[[r]])[%[[ARG2]]]
+// CHECK-DAG: %[[affineMapS:.*]] = affine.apply #[[MAP1]](%[[S]], %[[s]])
+// CHECK: %[[scalar:.*]] = memref.load %[[ARG0]][%[[K]], %[[C]], %[[affineMapR]], %[[affineMapS]]] : memref<?x?x?x?xf32>
+// CHECK: memref.store %[[scalar]], %[[ARG1]][%[[K]], %[[C]], %[[R]], %[[S]], %[[s]], %[[r]]] : memref<?x?x?x?x8x?xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+
More information about the Mlir-commits
mailing list