[Mlir-commits] [mlir] [mlir][linalg] Pack matmul pass (PR #89782)

Adam Siemieniuk llvmlistbot at llvm.org
Thu Apr 25 07:56:02 PDT 2024


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/89782

From ad5de2b01b364733483475a3238f22da7a2a2707 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 22 Apr 2024 18:21:34 +0200
Subject: [PATCH 1/3] [mlir][linalg] Pack matmul pass

Pack a matmul MxNxK operation into a blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
The result is unpacked back to the original layout.

Matmul packing splits the operands into smaller blocks (inner dimensions)
and then block-transposes the block sub-groups (outer dimensions).

This data arrangement minimizes the distance between consecutive blocks,
which improves spatial locality and cache behavior.
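
A rough sketch of the rewrite (shapes and pack attributes taken from the tests
in this patch, assuming block-factors=32,16,64; the destination tensors
produced by tensor.empty and the generic's indexing maps are elided):

```mlir
// Original op.
%0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
                   outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>

// After packing (roughly):
// LHS blocked as [MB][KB][mb][kb].
%A_packed = tensor.pack %A inner_dims_pos = [0, 1] inner_tiles = [32, 64]
    into %A_dst : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
// RHS block-transposed as [NB][KB][kb][nb].
%B_packed = tensor.pack %B outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
    inner_tiles = [64, 16] into %B_dst : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
// Accumulator blocked as [MB][NB][mb][nb].
%C_packed = tensor.pack %C inner_dims_pos = [0, 1] inner_tiles = [32, 16]
    into %C_dst : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
// Blocked matmul over the expanded iteration space, then unpack the result.
%res_packed = linalg.generic ... ins(%A_packed, %B_packed) outs(%C_packed)
%res = tensor.unpack %res_packed inner_dims_pos = [0, 1] inner_tiles = [32, 16]
    into %C : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
```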
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  20 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Dialect/Linalg/Transforms/PackMatmul.cpp  | 177 ++++++++++++++++++
 mlir/test/Dialect/Linalg/pack-matmul.mlir     | 140 ++++++++++++++
 5 files changed, 342 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
 create mode 100644 mlir/test/Dialect/Linalg/pack-matmul.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..d4361c70468bdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,24 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+  let summary = "Convert linalg matmul ops to block layout and back";
+  let description = [{
+    Pack a matmul MxNxK operation into blocked layout
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
+    The result is unpacked back to the original layout.
+
+    Matmul packing splits the operands into smaller blocks (inner dimensions)
+    and then block-transposes the block sub-groups (outer dimensions).
+
+    This data arrangement minimizes distance between consecutive blocks
+    which improves spatial locality and cache behavior.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+  let options = [
+    ListOption<"blockFactors", "block-factors", "int64_t",
+               "Block factors (mb, nb, kb) for relayout">
+  ];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c7012..2bb9277cc7b27e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1628,6 +1628,10 @@ void populateSplitReductionPattern(
 void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                      bool transposeLHS = true);
 
+/// Patterns to pack Linalg matmul ops.
+void populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                ArrayRef<int64_t> blockingFactors);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..e9b104ea5aeb58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  PackMatmul.cpp
   Padding.cpp
   Promotion.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
new file mode 100644
index 00000000000000..304de03a343fdc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  std::optional<int64_t> stride = getConstantIntValue(range.stride);
+  if (!stride || *stride != 1)
+    return std::nullopt;
+  std::optional<int64_t> offset = getConstantIntValue(range.offset);
+  if (!offset)
+    return std::nullopt;
+  std::optional<int64_t> size = getConstantIntValue(range.size);
+  if (!size)
+    return std::nullopt;
+  return (*size - *offset);
+}
+
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<size_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  OpBuilder builder(tileOp);
+  OpBuilder::InsertionGuard guard(builder);
+  SmallVector<Range> iterationDomain =
+      cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
+
+  for (auto dim : llvm::enumerate(dims)) {
+    if (dim.value() >= iterationDomain.size())
+      return false;
+
+    auto tileSize = getConstantIntValue(tiles[dim.index()]);
+    auto rangeOnDim = getConstantRange(iterationDomain[dim.value()]);
+
+    // If the tile factor or the range is non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             ArrayRef<OpFoldResult> mnkTiles) {
+  if (!(isa<linalg::MatmulOp>(matmulOp) ||
+        isa<linalg::BatchMatmulOp>(matmulOp))) {
+    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+  }
+
+  if (mnkTiles.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+  if (matmulOp.hasDynamicShape())
+    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+
+  if (matmulOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+
+  SmallVector<size_t, 3> dims{0, 1, 2};
+  // Skip the batch dimension if present.
+  bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
+  if (isBatchMatmulOp)
+    dims = {1, 2, 3};
+
+  if (!validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
+                               mnkTiles, dims)) {
+    return rewriter.notifyMatchFailure(matmulOp,
+                                       "expect packing full tiles only");
+  }
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // The op is replaced; set the insertion point after it.
+  rewriter.setInsertionPointAfter(matmulOp);
+
+  auto packedCanonicalMatmul = packMatmulGreedily(
+      rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
+      /*mnkOrder=*/{0, 1, 2});
+  if (failed(packedCanonicalMatmul))
+    return failure();
+
+  assert(packedCanonicalMatmul->packOps.size() == 3 && "failed matmul packing");
+  assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
+         "failed matmul unpacking");
+
+  SmallVector<int64_t> innerPerm = {1, 0};
+  SmallVector<int64_t> outerPerm = {1, 0};
+  // Leave the batch dimension as is.
+  if (isBatchMatmulOp)
+    outerPerm = {0, 2, 1};
+
+  auto packedMatmul =
+      packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
+                    packedCanonicalMatmul->packedLinalgOp,
+                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+  if (failed(packedMatmul))
+    return failure();
+
+  return packedMatmul->transposedLinalgOp;
+}
+
+namespace {
+template <typename OpTy>
+struct PackMatmul : public OpRewritePattern<OpTy> {
+  PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+             PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), blockFactors(blockFactors) {}
+
+  LogicalResult matchAndRewrite(OpTy matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (blockFactors.empty())
+      return failure();
+    auto packedMatmul =
+        packMatmulOp(rewriter, matmulOp,
+                     getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors)));
+    if (failed(packedMatmul))
+      return failure();
+    return success();
+  }
+
+private:
+  SmallVector<int64_t> blockFactors;
+};
+
+// Entry point for packing matmul operations.
+// Pack a MatmulOp as follows:
+// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+// Pack a BatchMatmulOp as follows:
+// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
+struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
+  using LinalgPackMatmulBase::LinalgPackMatmulBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(&getContext());
+    linalg::populatePackMatmulPatterns(patterns, blockFactors);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                        ArrayRef<int64_t> blockFactors) {
+  patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>>(
+      patterns.getContext(), blockFactors);
+}
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
new file mode 100644
index 00000000000000..d7023cfc30559b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+func.func @block_matmul(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_constant(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant dense<0.0> : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_producer(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_consumer(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+                  outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[OUT:.+]] = linalg.add
+// CHECK-SAME:  ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_batch_matmul(
+    %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = tensor.empty() : tensor<512x64x64xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %1 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[OUT]] : tensor<512x64x64xf32>

From b2f7ab4158dbe5b2364de3ea9dc914ddaaa6055f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 13:06:58 +0200
Subject: [PATCH 2/3] Improve test naming

---
 mlir/test/Dialect/Linalg/pack-matmul.mlir | 144 +++++++++++-----------
 1 file changed, 72 insertions(+), 72 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
index d7023cfc30559b..c704d0fb7d4fa5 100644
--- a/mlir/test/Dialect/Linalg/pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
 
 func.func @block_matmul(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
-  %0 = linalg.matmul  ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                      outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
 
@@ -12,103 +12,103 @@ func.func @block_matmul(
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
 
 // CHECK-LABEL: func @block_matmul(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
-// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
-// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
 // CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
-// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
-// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_matmul_with_constant(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
-  %cst = arith.constant dense<0.0> : tensor<128x128xf32>
-  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                      outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%cst_acc : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @block_matmul_with_constant(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
-// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_matmul_with_producer(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %cst = arith.constant 0.0 : f32
-  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                      outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %acc = linalg.fill ins(%cst : f32) outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%acc : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %1 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @block_matmul_with_producer(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_matmul_with_consumer(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>, %D: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %0 = tensor.empty() : tensor<128x128xf32>
-  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                     outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %2 = linalg.add ins(%1, %D : tensor<128x128xf32>, tensor<128x128xf32>)
                   outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %2 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @block_matmul_with_consumer(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>, %[[D:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: %[[OUT:.+]] = linalg.add
-// CHECK-SAME:  ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[ADD_RES:.+]] = linalg.add
+// CHECK-SAME:  ins(%[[RES_UNPACKED]], %[[D]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[RES_DST]] : tensor<128x128xf32>)
+// CHECK: return %[[ADD_RES]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_batch_matmul(
-    %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+    %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
   %0 = tensor.empty() : tensor<512x64x64xf32>
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
-                           outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  %1 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
   return %1 : tensor<512x64x64xf32>
 }
 
@@ -117,24 +117,24 @@ func.func @block_batch_matmul(
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
 
 // CHECK-LABEL: func @block_batch_matmul(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
-// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
 // CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
-// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[OUT]] : tensor<512x64x64xf32>
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

From e2ad091d54a879fd042d9b67bc453fc8e533104a Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 13:58:56 +0200
Subject: [PATCH 3/3] Improve description

---
 mlir/include/mlir/Dialect/Linalg/Passes.td | 39 +++++++++++++++++-----
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d4361c70468bdb..907ebb2f9c4be1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -144,15 +144,36 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
 def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
   let summary = "Convert linalg matmul ops to block layout and back";
   let description = [{
-    Pack a matmul MxNxK operation into blocked layout
-    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
-    The result is unpacked back to the original layout.
-
-    Matmul packing splits the operands into smaller blocks (inner dimensions)
-    and then block-transposes the block sub-groups (outer dimensions).
-
-    This data arrangement minimizes distance between consecutive blocks
-    which improves spatial locality and cache behavior.
+    Pack a matmul operation into a blocked layout with two levels of subdivision:
+    - major 2D blocks - outer dimensions, consisting of minor blocks
+    - minor 2D blocks - inner dimensions, consisting of scalar elements
+
+    A 2D matmul MxNxK gets reshaped into a blocked 4D representation
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+    where the (MB, NB, KB) dimensions represent the major blocks,
+    and the (mb, nb, kb) are the minor blocks of their respective
+    original 2D dimensions (M, N, K).
+
+    As part of the packing strategy, the RHS operand gets 'block transposed',
+    i.e., its major blocks [KB][NB] get transposed to [NB][KB] layout.
+    The minor blocks remain unchanged.
+    The final result is unpacked back to the original layout.
+
+    Given a matmul operation:
+    ```mlir
+      %res = linalg.matmul ins(%A, %B) outs(%C)
+    ```
+    the transformation can be represented as:
+    ```mlir
+      %A_packed = pack %A : 2D -> 4D
+      %B_packed = pack %B : 2D -> 4D #block_transposed
+      %C_packed = pack %C : 2D -> 4D
+      %res_packed = linalg.mmt4d ins(%A_packed, %B_packed) outs(%C_packed)
+      %res = unpack %res_packed : 4D -> 2D
+    ```
+
+    This packed data arrangement minimizes the distance between consecutive
+    blocks, which improves spatial locality and cache behavior.
   }];
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
   let options = [


