[Mlir-commits] [mlir] [mlir][linalg] Pack matmul pass (PR #89782)

Adam Siemieniuk llvmlistbot at llvm.org
Tue May 7 06:48:58 PDT 2024


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/89782

>From ad5de2b01b364733483475a3238f22da7a2a2707 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 22 Apr 2024 18:21:34 +0200
Subject: [PATCH 01/22] [mlir][linalg] Pack matmul pass

Pack a matmul MxNxK operation into blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
The result is unpacked back to the original layout.

Matmul packing splits the operands into smaller blocks (inner dimensions)
and then block-transposes the block sub-groups (outer dimensions).

This data arrangement minimizes distance between consecutive blocks,
which improves spatial locality and cache behavior.
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  20 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |   4 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   1 +
 .../Dialect/Linalg/Transforms/PackMatmul.cpp  | 177 ++++++++++++++++++
 mlir/test/Dialect/Linalg/pack-matmul.mlir     | 140 ++++++++++++++
 5 files changed, 342 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
 create mode 100644 mlir/test/Dialect/Linalg/pack-matmul.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..d4361c70468bdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,24 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+  let summary = "Convert linalg matmul ops to block layout and back";
+  let description = [{
+    Pack a matmul MxNxK operation into blocked layout
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
+    The result is unpacked back to the original layout.
+    Batch matmul operations are packed the same way, with the batch
+    dimension kept outermost.
+
+    Matmul packing splits the operands into smaller blocks (inner dimensions)
+    and then block-transposes the block sub-groups (outer dimensions).
+
+    This data arrangement minimizes distance between consecutive blocks,
+    which improves spatial locality and cache behavior.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+  let options = [
+    ListOption<"blockFactors", "block-factors", "int64_t",
+               "Block factors (mb, nb, kb) for relayout">
+  ];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c7012..2bb9277cc7b27e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1628,6 +1628,10 @@ void populateSplitReductionPattern(
 void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                      bool transposeLHS = true);
 
+/// Patterns to block-pack Linalg matmul and batch_matmul ops.
+///
+/// `blockFactors` holds the (mb, nb, kb) block sizes. An op is rewritten only
+/// when exactly three factors are given, the op has static shapes and tensor
+/// semantics, and each factor evenly divides the corresponding M/N/K
+/// dimension.
+void populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                ArrayRef<int64_t> blockFactors);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..e9b104ea5aeb58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  PackMatmul.cpp
   Padding.cpp
   Promotion.cpp
   Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
new file mode 100644
index 00000000000000..304de03a343fdc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns `size - offset` for a unit-stride `range` whose offset and size are
+/// both compile-time constants, and std::nullopt otherwise (non-constant or
+/// non-1 stride, non-constant offset, or non-constant size).
+/// NOTE(review): this treats `range.size` as an upper bound. For iteration
+/// domains starting at offset 0 it equals the extent; TODO confirm for
+/// non-zero offsets.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  // An empty optional also compares unequal to 1, covering the dynamic case.
+  std::optional<int64_t> constStride = getConstantIntValue(range.stride);
+  if (constStride != 1)
+    return std::nullopt;
+
+  std::optional<int64_t> constOffset = getConstantIntValue(range.offset);
+  std::optional<int64_t> constSize = getConstantIntValue(range.size);
+  if (!constOffset || !constSize)
+    return std::nullopt;
+  return *constSize - *constOffset;
+}
+
+/// Returns true if, for every pair (tiles[i], dims[i]), the tile size is a
+/// strictly positive compile-time constant that evenly divides the constant
+/// extent of iteration-domain dimension dims[i] of `tileOp`.
+///
+/// Returns false when `tiles` is empty, when `tiles` and `dims` differ in
+/// length, when a dimension index is out of range, when a tile size is not
+/// positive, or when a tile size or a dimension extent is not a constant.
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<size_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  // The builder is only needed to query the iteration domain; it is local,
+  // so there is no outer insertion point to preserve.
+  OpBuilder builder(tileOp);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+
+  for (auto [tile, dim] : llvm::zip_equal(tiles, dims)) {
+    if (dim >= iterationDomain.size())
+      return false;
+
+    std::optional<int64_t> tileSize = getConstantIntValue(tile);
+    std::optional<int64_t> rangeOnDim = getConstantRange(iterationDomain[dim]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // Tile sizes must be strictly positive. This also guards the modulo
+    // below against division by zero.
+    if (*tileSize <= 0)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+/// Block-packs a linalg.matmul or linalg.batch_matmul using the (mb, nb, kb)
+/// tile sizes in `mnkTiles`, then block-transposes the packed RHS into the
+/// [NB][KB][kb][nb] layout. For batch matmul the batch dim stays outermost.
+///
+/// Emits a match failure for other ops, for a tile count other than 3, for
+/// dynamic shapes, for pure buffer semantics, and for tiles that do not evenly
+/// divide the M/N/K dimensions. On success returns the transposed packed op.
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             ArrayRef<OpFoldResult> mnkTiles) {
+  if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(matmulOp))
+    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+  if (mnkTiles.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+  if (matmulOp.hasDynamicShape())
+    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+  if (matmulOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+
+  // Batch matmul has a leading batch iteration dim; M/N/K follow it.
+  const bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
+  SmallVector<size_t, 3> dims;
+  if (isBatchMatmulOp)
+    dims = {1, 2, 3};
+  else
+    dims = {0, 1, 2};
+
+  bool fullTiles = validateFullTilesOnDims(
+      cast<TilingInterface>(matmulOp.getOperation()), mnkTiles, dims);
+  if (!fullTiles)
+    return rewriter.notifyMatchFailure(matmulOp,
+                                       "expect packing full tiles only");
+
+  // The op gets replaced, so new IR is created right after it.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(matmulOp);
+
+  auto canonical = packMatmulGreedily(
+      rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
+      /*mnkOrder=*/{0, 1, 2});
+  if (failed(canonical))
+    return failure();
+
+  assert(canonical->packOps.size() == 3 && "failed matmul packing");
+  assert(canonical->unPackOps.size() == 1 && "failed matmul unpacking");
+
+  // Swap both the outer and the inner packed dims of the RHS (packOps[1]).
+  // For batch matmul the batch dim is left in place.
+  const SmallVector<int64_t> innerPerm = {1, 0};
+  const SmallVector<int64_t> outerPerm =
+      isBatchMatmulOp ? SmallVector<int64_t>{0, 2, 1}
+                      : SmallVector<int64_t>{1, 0};
+
+  auto transposed = packTranspose(rewriter, canonical->packOps[1],
+                                  canonical->packedLinalgOp,
+                                  /*maybeUnPackOp=*/nullptr, outerPerm,
+                                  innerPerm);
+  if (failed(transposed))
+    return failure();
+  return transposed->transposedLinalgOp;
+}
+
+namespace {
+/// Rewrite pattern that block-packs ops of type `OpTy` with the configured
+/// (mb, nb, kb) block factors. Fails to match when no factors are configured.
+template <typename OpTy>
+struct PackMatmul : public OpRewritePattern<OpTy> {
+  PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+             PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), blockFactors(blockFactors) {}
+
+  LogicalResult matchAndRewrite(OpTy matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (blockFactors.empty())
+      return failure();
+
+    SmallVector<OpFoldResult> mnkTiles =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors));
+    return success(succeeded(packMatmulOp(rewriter, matmulOp, mnkTiles)));
+  }
+
+private:
+  SmallVector<int64_t> blockFactors;
+};
+
+/// Pass entry point for block-packing matmul operations.
+/// MatmulOp is packed as:
+///   [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+/// BatchMatmulOp is packed as:
+///   [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
+struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
+  using LinalgPackMatmulBase::LinalgPackMatmulBase;
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    linalg::populatePackMatmulPatterns(patterns, blockFactors);
+
+    Operation *root = getOperation();
+    if (failed(applyPatternsAndFoldGreedily(root, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Registers the block-packing patterns for linalg.matmul and
+/// linalg.batch_matmul, both using the same (mb, nb, kb) `blockFactors`.
+void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                        ArrayRef<int64_t> blockFactors) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<PackMatmul<linalg::MatmulOp>>(ctx, blockFactors);
+  patterns.add<PackMatmul<linalg::BatchMatmulOp>>(ctx, blockFactors);
+}
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
new file mode 100644
index 00000000000000..d7023cfc30559b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+func.func @block_matmul(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul  ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_constant(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant dense<0.0> : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_producer(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant 0.0 : f32
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_consumer(
+    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+                  outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[OUT:.+]] = linalg.add
+// CHECK-SAME:  ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_batch_matmul(
+    %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = tensor.empty() : tensor<512x64x64xf32>
+  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %1 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[OUT]] : tensor<512x64x64xf32>

>From b2f7ab4158dbe5b2364de3ea9dc914ddaaa6055f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 13:06:58 +0200
Subject: [PATCH 02/22] Improve test naming

---
 mlir/test/Dialect/Linalg/pack-matmul.mlir | 144 +++++++++++-----------
 1 file changed, 72 insertions(+), 72 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
index d7023cfc30559b..c704d0fb7d4fa5 100644
--- a/mlir/test/Dialect/Linalg/pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
 
 func.func @block_matmul(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
-  %0 = linalg.matmul  ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                      outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
 
@@ -12,103 +12,103 @@ func.func @block_matmul(
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
 
 // CHECK-LABEL: func @block_matmul(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
-// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
-// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
 // CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
-// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
-// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_matmul_with_constant(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
-  %cst = arith.constant dense<0.0> : tensor<128x128xf32>
-  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                      outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%cst_acc : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %0 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @block_matmul_with_constant(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
-// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_matmul_with_producer(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %cst = arith.constant 0.0 : f32
-  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                      outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %acc = linalg.fill ins(%cst : f32) outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%acc : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %1 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @block_matmul_with_producer(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_matmul_with_consumer(
-    %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>, %D: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %0 = tensor.empty() : tensor<128x128xf32>
-  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
-                     outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
-  %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %2 = linalg.add ins(%1, %D : tensor<128x128xf32>, tensor<128x128xf32>)
                   outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
   return %2 : tensor<128x128xf32>
 }
 
 // CHECK-LABEL: func @block_matmul_with_consumer(
-// CHECK-SAME:    %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>, %[[D:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: %[[OUT:.+]] = linalg.add
-// CHECK-SAME:  ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[ADD_RES:.+]] = linalg.add
+// CHECK-SAME:  ins(%[[RES_UNPACKED]], %[[D]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[RES_DST]] : tensor<128x128xf32>)
+// CHECK: return %[[ADD_RES]] : tensor<128x128xf32>
 
 // -----
 
 func.func @block_batch_matmul(
-    %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+    %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
   %0 = tensor.empty() : tensor<512x64x64xf32>
-  %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
-                           outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  %1 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
   return %1 : tensor<512x64x64xf32>
 }
 
@@ -117,24 +117,24 @@ func.func @block_batch_matmul(
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
 
 // CHECK-LABEL: func @block_batch_matmul(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
-// CHECK-SAME:  into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
-// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
 // CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
-// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME:  into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[OUT]] : tensor<512x64x64xf32>
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

>From e2ad091d54a879fd042d9b67bc453fc8e533104a Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 13:58:56 +0200
Subject: [PATCH 03/22] Improve description

---
 mlir/include/mlir/Dialect/Linalg/Passes.td | 39 +++++++++++++++++-----
 1 file changed, 30 insertions(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d4361c70468bdb..907ebb2f9c4be1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -144,15 +144,36 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
 def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
   let summary = "Convert linalg matmul ops to block layout and back";
   let description = [{
-    Pack a matmul MxNxK operation into blocked layout
-    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
-    The result is unpacked back to the original layout.
-
-    Matmul packing splits the operands into smaller blocks (inner dimensions)
-    and then block-transposes the block sub-groups (outer dimensions).
-
-    This data arrangement minimizes distance between consecutive blocks
-    which improves spacial locality and cache behavior.
+    Pack a matmul operation into blocked layout with two levels of subdivision:
+    - major 2D blocks - outer dimensions, consist of minor blocks
+    - minor 2D blocks - inner dimensions, consist of scalar elements
+
+    A 2D matmul MxNxK gets reshaped into blocked 4D representation
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+    where the (MB, NB, KB) dimensions represent the major blocks,
+    and the (mb, nb, kb) are the minor blocks of their respective
+    original 2D dimensions (M, N, K).
+
+    As a part of packing strategy, the RHS operand gets 'block transposed'
+    i.e., the major blocks [KB][NB] get transposed to [NB][KB] layout.
+    The minor blocks remain unchanged.
+    The final result is unpacked back to the original layout.
+
+    Given a matmul operation:
+    ```mlir
+      %res = linalg.matmul ins(%A, %B) outs(%C)
+    ```
+    the transformation can be represented as:
+    ```mlir
+      %A_packed = pack %A : 2D -> 4D
+      %B_packed = pack %B : 2D -> 4D #block_transposed
+      %C_packed = pack %C : 2D -> 4D
+      %res_packed = linalg.mmt4d ins(%A_packed, %B_packed) outs(%C_packed)
+      %res = unpack %res_packed : 4D -> 2D
+    ```
+
+    This packed data arrangement minimizes distance between consecutive
+    blocks, which improves spatial locality and cache behavior.
   }];
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
   let options = [

>From 5e6aff876ddda70d67c456b249a1744cf4cbe5a7 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 16:56:32 +0200
Subject: [PATCH 04/22] Pack options

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   9 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |  22 +++-
 .../Dialect/Linalg/Transforms/PackMatmul.cpp  | 111 ++++++++++++------
 3 files changed, 107 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 907ebb2f9c4be1..e9fd990c38ab7b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -178,7 +178,14 @@ def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
   let options = [
     ListOption<"blockFactors", "block-factors", "int64_t",
-               "Block factors (mb, nb, kb) for relayout">
+               "Block factors (mb, nb, kb) for relayout">,
+    ListOption<"mnkOrder", "mnk-order", "int64_t",
+               "Permutation of (mb, nb, kb) dimensions order">,
+    ListOption<"mnkPaddedSizesNextMultipleOf", "mnk-padded-multiples", "int64_t",
+               "Packing sizes next multiple">,
+    Option<"allowPadding", "allow-padding", "bool",
+           /*default=*/"true",
+           "Allow packing padding">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2bb9277cc7b27e..299cbd80768900 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1162,6 +1162,26 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                    ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                    ArrayRef<int64_t> mnkOrder);
 
+struct PackMatmulOptions {
+  /// Minor block factors for packing relayout in the 'mnkOrder'.
+  SmallVector<int64_t, 3> blockFactors;
+  /// Order of packed dimensions (mb, nb, kb) - permutation of the default
+  /// order.
+  SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
+  SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
+  bool allowPadding = true;
+};
+/// Function type which is used to control matmul block packing.
+/// It is expected to return valid packing configuration for each operation.
+/// Lack of options indicates no valid configuration could be assigned and
+/// will prevent any packing from occuring.
+using ControlPackMatmulFn =
+    std::function<std::optional<PackMatmulOptions>(linalg::LinalgOp)>;
+
+FailureOr<PackResult>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             const ControlPackMatmulFn &controlPackMatmul);
+
 /// Rewrite tensor.from_elements to linalg.generic.
 FailureOr<Operation *>
 rewriteInDestinationPassingStyle(RewriterBase &rewriter,
@@ -1630,7 +1650,7 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 
 /// Patterns to pack Linalg matmul ops.
 void populatePackMatmulPatterns(RewritePatternSet &patterns,
-                                ArrayRef<int64_t> blockingFactors);
+                                const ControlPackMatmulFn &controlFn);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
index 304de03a343fdc..befd6dc7b2cdff 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -41,7 +41,7 @@ static std::optional<int64_t> getConstantRange(const Range &range) {
 
 static bool validateFullTilesOnDims(TilingInterface tileOp,
                                     ArrayRef<OpFoldResult> tiles,
-                                    ArrayRef<size_t> dims) {
+                                    ArrayRef<int64_t> dims) {
   if (dims.size() != tiles.size() || tiles.empty())
     return false;
 
@@ -51,7 +51,7 @@ static bool validateFullTilesOnDims(TilingInterface tileOp,
       cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
 
   for (auto dim : llvm::enumerate(dims)) {
-    if (dim.value() >= iterationDomain.size())
+    if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
       return false;
 
     auto tileSize = getConstantIntValue(tiles[dim.index()]);
@@ -70,42 +70,62 @@ static bool validateFullTilesOnDims(TilingInterface tileOp,
   return true;
 }
 
-static FailureOr<linalg::LinalgOp>
-packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-             ArrayRef<OpFoldResult> mnkTiles) {
+FailureOr<PackResult>
+linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+                     const ControlPackMatmulFn &controlPackMatmul) {
   if (!(isa<linalg::MatmulOp>(matmulOp) ||
-        isa<linalg::BatchMatmulOp>(matmulOp))) {
+        isa<linalg::BatchMatmulOp>(matmulOp) ||
+        isa<linalg::MatmulTransposeAOp>(matmulOp) ||
+        isa<linalg::MatmulTransposeBOp>(matmulOp) ||
+        isa<linalg::BatchMatmulTransposeAOp>(matmulOp) ||
+        isa<linalg::BatchMatmulTransposeBOp>(matmulOp))) {
     return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
   }
 
-  if (mnkTiles.size() != 3)
-    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
-
   if (matmulOp.hasDynamicShape())
     return rewriter.notifyMatchFailure(matmulOp, "require static shape");
 
   if (matmulOp.hasPureBufferSemantics())
     return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
 
-  SmallVector<size_t, 3> dims{0, 1, 2};
+  std::optional<PackMatmulOptions> options = controlPackMatmul(matmulOp);
+  if (!options)
+    return rewriter.notifyMatchFailure(matmulOp, "invalid packing options");
+
+  if (options->blockFactors.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+  auto mnkTiles =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
+
+  SmallVector<int64_t, 3> dims{options->mnkOrder};
   // Skip the batch dimension if present.
-  bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
-  if (isBatchMatmulOp)
-    dims = {1, 2, 3};
+  bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp) ||
+                         isa<linalg::BatchMatmulTransposeAOp>(matmulOp) ||
+                         isa<linalg::BatchMatmulTransposeBOp>(matmulOp);
+  if (isBatchMatmulOp) {
+    // Offset all dimensions.
+    for (size_t i = 0; i < dims.size(); i++)
+      ++dims[i];
+  }
 
-  if (!validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
+  if (!options->allowPadding &&
+      !validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
                                mnkTiles, dims)) {
     return rewriter.notifyMatchFailure(matmulOp,
                                        "expect packing full tiles only");
   }
 
+  bool isTransposedRhs = isa<linalg::MatmulTransposeBOp>(matmulOp) ||
+                         isa<linalg::BatchMatmulTransposeBOp>(matmulOp);
+
   OpBuilder::InsertionGuard guard(rewriter);
   // The op is replaced, we need to set the insertion point after it.
   rewriter.setInsertionPointAfter(matmulOp);
 
   auto packedCanonicalMatmul = packMatmulGreedily(
-      rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
-      /*mnkOrder=*/{0, 1, 2});
+      rewriter, matmulOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
+      options->mnkOrder);
   if (failed(packedCanonicalMatmul))
     return failure();
 
@@ -113,11 +133,21 @@ packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
          "failed matmul unpacking");
 
-  SmallVector<int64_t> innerPerm = {1, 0};
-  SmallVector<int64_t> outerPerm = {1, 0};
+  SmallVector<int64_t> innerPerm{1, 0};
+  SmallVector<int64_t> outerPerm{1, 0};
+  // No need to block transpose if the RHS matrix is already transposed.
+  if (isTransposedRhs)
+    outerPerm = {0, 1};
+
   // Leave the batch dimension as is.
-  if (isBatchMatmulOp)
-    outerPerm = {0, 2, 1};
+  if (isBatchMatmulOp) {
+    // Account for the batch dimension.
+    SmallVector<int64_t> newOuterPerms{0};
+    // Offset all permutations.
+    for (auto perm : outerPerm)
+      newOuterPerms.push_back(++perm);
+    outerPerm = newOuterPerms;
+  }
 
   auto packedMatmul =
       packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
@@ -126,30 +156,28 @@ packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   if (failed(packedMatmul))
     return failure();
 
-  return packedMatmul->transposedLinalgOp;
+  packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
+
+  return packedCanonicalMatmul;
 }
 
 namespace {
 template <typename OpTy>
 struct PackMatmul : public OpRewritePattern<OpTy> {
-  PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+  PackMatmul(MLIRContext *context, ControlPackMatmulFn fun,
              PatternBenefit benefit = 1)
-      : OpRewritePattern<OpTy>(context, benefit), blockFactors(blockFactors) {}
+      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
 
   LogicalResult matchAndRewrite(OpTy matmulOp,
                                 PatternRewriter &rewriter) const override {
-    if (blockFactors.empty())
-      return failure();
-    auto packedMatmul =
-        packMatmulOp(rewriter, matmulOp,
-                     getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors)));
+    auto packedMatmul = packMatmulOp(rewriter, matmulOp, controlFn);
     if (failed(packedMatmul))
       return failure();
     return success();
   }
 
 private:
-  SmallVector<int64_t> blockFactors;
+  ControlPackMatmulFn controlFn;
 };
 
 // Entry point for packing matmul operations.
@@ -163,7 +191,20 @@ struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(&getContext());
-    linalg::populatePackMatmulPatterns(patterns, blockFactors);
+
+    ControlPackMatmulFn controlFn =
+        [&](linalg::LinalgOp op) -> PackMatmulOptions {
+      PackMatmulOptions options;
+      options.blockFactors = SmallVector<int64_t>{*blockFactors};
+      if (!mnkOrder.empty())
+        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
+      options.mnkPaddedSizesNextMultipleOf =
+          SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
+      options.allowPadding = allowPadding;
+      return options;
+    };
+
+    linalg::populatePackMatmulPatterns(patterns, controlFn);
     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
       return signalPassFailure();
   }
@@ -171,7 +212,11 @@ struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
 } // namespace
 
 void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
-                                        ArrayRef<int64_t> blockFactors) {
-  patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>>(
-      patterns.getContext(), blockFactors);
+                                        const ControlPackMatmulFn &controlFn) {
+  patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>,
+               PackMatmul<linalg::MatmulTransposeAOp>,
+               PackMatmul<linalg::BatchMatmulTransposeAOp>,
+               PackMatmul<linalg::MatmulTransposeBOp>,
+               PackMatmul<linalg::BatchMatmulTransposeBOp>>(
+      patterns.getContext(), controlFn);
 }

>From 80b1aae4c0ffd3eda87c745a5a949601701da9a8 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 2 May 2024 11:46:52 +0200
Subject: [PATCH 05/22] Comments

---
 mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
index befd6dc7b2cdff..e36c689fa193d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -123,6 +123,10 @@ linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   // The op is replaced, we need to set the insertion point after it.
   rewriter.setInsertionPointAfter(matmulOp);
 
+  // Pack the matmul operation into blocked layout with two levels of
+  // subdivision:
+  //   - major 2D blocks - outer dimensions, consist of minor blocks
+  //   - minor 2D blocks - inner dimensions, consist of scalar elements
   auto packedCanonicalMatmul = packMatmulGreedily(
       rewriter, matmulOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
       options->mnkOrder);
@@ -149,6 +153,9 @@ linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
     outerPerm = newOuterPerms;
   }
 
+  // Block transpose the packed matmul i.e., transpose the outer dimensions
+  // layout of the RHS matrix. The inner dimensions (minor blocks) remain
+  // unchanged.
   auto packedMatmul =
       packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
                     packedCanonicalMatmul->packedLinalgOp,

>From 13bc86a2a3cbc5db924d47aa291f0f2d3aa40af3 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 2 May 2024 11:48:37 +0200
Subject: [PATCH 06/22] Expand autos

---
 mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
index e36c689fa193d7..ca7e4287020992 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -54,8 +54,9 @@ static bool validateFullTilesOnDims(TilingInterface tileOp,
     if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
       return false;
 
-    auto tileSize = getConstantIntValue(tiles[dim.index()]);
-    auto rangeOnDim = getConstantRange(iterationDomain[dim.value()]);
+    std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
+    std::optional<int64_t> rangeOnDim =
+        getConstantRange(iterationDomain[dim.value()]);
 
     // If the tile factor or the range are non-constant, the tile size is
     // considered to be invalid.
@@ -95,7 +96,7 @@ linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   if (options->blockFactors.size() != 3)
     return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
 
-  auto mnkTiles =
+  SmallVector<OpFoldResult> mnkTiles =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
 
   SmallVector<int64_t, 3> dims{options->mnkOrder};
@@ -127,7 +128,7 @@ linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   // subdivision:
   //   - major 2D blocks - outer dimensions, consist of minor blocks
   //   - minor 2D blocks - inner dimensions, consist of scalar elements
-  auto packedCanonicalMatmul = packMatmulGreedily(
+  FailureOr<PackResult> packedCanonicalMatmul = packMatmulGreedily(
       rewriter, matmulOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
       options->mnkOrder);
   if (failed(packedCanonicalMatmul))
@@ -156,7 +157,7 @@ linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   // Block transpose the packed matmul i.e., transpose the outer dimensions
   // layout of the RHS matrix. The inner dimensions (minor blocks) remain
   // unchanged.
-  auto packedMatmul =
+  FailureOr<PackTransposeResult> packedMatmul =
       packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
                     packedCanonicalMatmul->packedLinalgOp,
                     /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
@@ -177,7 +178,8 @@ struct PackMatmul : public OpRewritePattern<OpTy> {
 
   LogicalResult matchAndRewrite(OpTy matmulOp,
                                 PatternRewriter &rewriter) const override {
-    auto packedMatmul = packMatmulOp(rewriter, matmulOp, controlFn);
+    FailureOr<PackResult> packedMatmul =
+        packMatmulOp(rewriter, matmulOp, controlFn);
     if (failed(packedMatmul))
       return failure();
     return success();

>From f49aaf0b06e211e26d93881733f9986bda9cf6ef Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 2 May 2024 12:24:07 +0200
Subject: [PATCH 07/22] Docs

---
 .../Dialect/Linalg/Transforms/Transforms.h    | 29 +++++++++++++++++--
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 299cbd80768900..40530e2448114f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1165,19 +1165,42 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
 struct PackMatmulOptions {
   /// Minor block factors for packing relayout in the 'mnkOrder'.
   SmallVector<int64_t, 3> blockFactors;
+
   /// Order of packed dimensions (mb, nb, kb) - permutation of the default
   /// order.
   SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
+
   SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
+
+  /// If true, allows packing of dimensions that only partially fit into the
+  /// block factors.
   bool allowPadding = true;
 };
-/// Function type which is used to control matmul block packing.
+
+/// Function type which is used to control matmul packing.
 /// It is expected to return valid packing configuration for each operation.
-/// Lack of options indicates no valid configuration could be assigned and
-/// will prevent any packing from occuring.
+/// Lack of packing options indicates that no valid configuration could be
+/// assigned and the operation will not be packed.
 using ControlPackMatmulFn =
     std::function<std::optional<PackMatmulOptions>(linalg::LinalgOp)>;
 
+/// Pack a matmul operation into blocked 4D layout.
+///
+/// Relayout a matmul operation into blocked layout with two levels of
+/// subdivision:
+///   - major 2D blocks - outer dimensions, consist of minor blocks
+///   - minor 2D blocks - inner dimensions, consist of scalar elements
+///
+/// A 2D matmul MxNxK gets reshaped into blocked 4D representation
+/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+/// where the (MB, NB, KB) dimensions represent the major blocks,
+/// and the (mb, nb, kb) are the minor blocks of their respective
+/// original 2D dimensions (M, N, K).
+///
+/// As a part of packing strategy, the RHS operand gets 'block transposed'
+/// i.e., the major blocks [KB][NB] get transposed to [NB][KB] layout.
+/// The minor blocks remain unchanged.
+/// The final result is unpacked back to the original layout.
 FailureOr<PackResult>
 packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
              const ControlPackMatmulFn &controlPackMatmul);

>From ef541d7d6662209c43eae78d26ff18cb40e28e1d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 3 May 2024 13:41:24 +0200
Subject: [PATCH 08/22] Rename to block pack matmul

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  2 +-
 .../Dialect/Linalg/Transforms/Transforms.h    | 16 +++----
 .../{PackMatmul.cpp => BlockPackMatmul.cpp}   | 47 ++++++++++---------
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  2 +-
 ...ack-matmul.mlir => block-pack-matmul.mlir} |  2 +-
 5 files changed, 35 insertions(+), 34 deletions(-)
 rename mlir/lib/Dialect/Linalg/Transforms/{PackMatmul.cpp => BlockPackMatmul.cpp} (83%)
 rename mlir/test/Dialect/Linalg/{pack-matmul.mlir => block-pack-matmul.mlir} (98%)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index e9fd990c38ab7b..b783f95fbb5538 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,7 +141,7 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
-def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
   let summary = "Convert linalg matmul ops to block layout and back";
   let description = [{
     Pack a matmul operation into blocked layout with two levels of subdivision:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 40530e2448114f..684183989fdfb3 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1163,11 +1163,11 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                    ArrayRef<int64_t> mnkOrder);
 
 struct PackMatmulOptions {
-  /// Minor block factors for packing relayout in the 'mnkOrder'.
+  /// Minor block factors (mb, nb, kb) for packing relayout where mb, nb are
+  /// the parallel dimensions and kb is the reduction dimension.
   SmallVector<int64_t, 3> blockFactors;
 
-  /// Order of packed dimensions (mb, nb, kb) - permutation of the default
-  /// order.
+  /// Order of the packed dimensions (mb, nb, kb).
   SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
 
   SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
@@ -1202,8 +1202,8 @@ using ControlPackMatmulFn =
 /// The minor blocks remain unchanged.
 /// The final result is unpacked back to the original layout.
 FailureOr<PackResult>
-packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-             const ControlPackMatmulFn &controlPackMatmul);
+blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+                  const ControlPackMatmulFn &controlPackMatmul);
 
 /// Rewrite tensor.from_elements to linalg.generic.
 FailureOr<Operation *>
@@ -1671,9 +1671,9 @@ void populateSplitReductionPattern(
 void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                      bool transposeLHS = true);
 
-/// Patterns to pack Linalg matmul ops.
-void populatePackMatmulPatterns(RewritePatternSet &patterns,
-                                const ControlPackMatmulFn &controlFn);
+/// Patterns to block pack Linalg matmul ops.
+void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
+                                     const ControlPackMatmulFn &controlFn);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
similarity index 83%
rename from mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
rename to mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index ca7e4287020992..c38423b12288ba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -1,4 +1,4 @@
-//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -19,13 +19,14 @@
 #include <optional>
 
 namespace mlir {
-#define GEN_PASS_DEF_LINALGPACKMATMUL
+#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
 #include "mlir/Dialect/Linalg/Passes.h.inc"
 } // namespace mlir
 
 using namespace mlir;
 using namespace mlir::linalg;
 
+/// Return constant range span or nullopt, otherwise.
 static std::optional<int64_t> getConstantRange(const Range &range) {
   std::optional<int64_t> stride = getConstantIntValue(range.stride);
   if (!stride || *stride != 1)
@@ -39,6 +40,7 @@ static std::optional<int64_t> getConstantRange(const Range &range) {
   return (*size - *offset);
 }
 
+/// Return true if all dimensions are fully divisible by the respective tiles.
 static bool validateFullTilesOnDims(TilingInterface tileOp,
                                     ArrayRef<OpFoldResult> tiles,
                                     ArrayRef<int64_t> dims) {
@@ -71,9 +73,10 @@ static bool validateFullTilesOnDims(TilingInterface tileOp,
   return true;
 }
 
+/// Pack a matmul operation into blocked 4D layout.
 FailureOr<PackResult>
-linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-                     const ControlPackMatmulFn &controlPackMatmul) {
+linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+                          const ControlPackMatmulFn &controlPackMatmul) {
   if (!(isa<linalg::MatmulOp>(matmulOp) ||
         isa<linalg::BatchMatmulOp>(matmulOp) ||
         isa<linalg::MatmulTransposeAOp>(matmulOp) ||
@@ -171,15 +174,15 @@ linalg::packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
 
 namespace {
 template <typename OpTy>
-struct PackMatmul : public OpRewritePattern<OpTy> {
-  PackMatmul(MLIRContext *context, ControlPackMatmulFn fun,
-             PatternBenefit benefit = 1)
+struct BlockPackMatmul : public OpRewritePattern<OpTy> {
+  BlockPackMatmul(MLIRContext *context, ControlPackMatmulFn fun,
+                  PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
 
   LogicalResult matchAndRewrite(OpTy matmulOp,
                                 PatternRewriter &rewriter) const override {
     FailureOr<PackResult> packedMatmul =
-        packMatmulOp(rewriter, matmulOp, controlFn);
+        blockPackMatmulOp(rewriter, matmulOp, controlFn);
     if (failed(packedMatmul))
       return failure();
     return success();
@@ -189,13 +192,10 @@ struct PackMatmul : public OpRewritePattern<OpTy> {
   ControlPackMatmulFn controlFn;
 };
 
-// Entry point for packing matmul operations.
-// Pack MatmulOp as following:
-// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
-// Pack a BatchMatmulOp as following:
-// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
-struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
-  using LinalgPackMatmulBase::LinalgPackMatmulBase;
+/// Convert linalg matmul ops to block layout and back.
+struct LinalgBlockPackMatmul
+    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
+  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
 
   void runOnOperation() override {
     Operation *op = getOperation();
@@ -213,19 +213,20 @@ struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
       return options;
     };
 
-    linalg::populatePackMatmulPatterns(patterns, controlFn);
+    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
     if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
       return signalPassFailure();
   }
 };
 } // namespace
 
-void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
-                                        const ControlPackMatmulFn &controlFn) {
-  patterns.add<PackMatmul<linalg::MatmulOp>, PackMatmul<linalg::BatchMatmulOp>,
-               PackMatmul<linalg::MatmulTransposeAOp>,
-               PackMatmul<linalg::BatchMatmulTransposeAOp>,
-               PackMatmul<linalg::MatmulTransposeBOp>,
-               PackMatmul<linalg::BatchMatmulTransposeBOp>>(
+void linalg::populateBlockPackMatmulPatterns(
+    RewritePatternSet &patterns, const ControlPackMatmulFn &controlFn) {
+  patterns.add<BlockPackMatmul<linalg::MatmulOp>,
+               BlockPackMatmul<linalg::BatchMatmulOp>,
+               BlockPackMatmul<linalg::MatmulTransposeAOp>,
+               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
+               BlockPackMatmul<linalg::MatmulTransposeBOp>,
+               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
       patterns.getContext(), controlFn);
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index e9b104ea5aeb58..5a45859e27d626 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,7 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
-  PackMatmul.cpp
+  BlockPackMatmul.cpp
   Padding.cpp
   Promotion.cpp
   Specialize.cpp
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
similarity index 98%
rename from mlir/test/Dialect/Linalg/pack-matmul.mlir
rename to mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index c704d0fb7d4fa5..46062198b69930 100644
--- a/mlir/test/Dialect/Linalg/pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-block-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
 
 func.func @block_matmul(
     %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {

>From ce2238d0a1a4d6f4d133ee65ca9012b4cf820caa Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 3 May 2024 14:18:53 +0200
Subject: [PATCH 09/22] Add test cases

---
 .../Dialect/Linalg/block-pack-matmul.mlir     | 144 ++++++++++++++++++
 1 file changed, 144 insertions(+)

diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index 46062198b69930..481c018f5876a8 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -138,3 +138,147 @@ func.func @block_batch_matmul(
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
 // CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
+
+// -----
+
+func.func @block_matmul_transpose_a(
+    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
+                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul_transpose_a(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  inner_dims_pos = [1, 0] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
+
+// -----
+
+func.func @block_batch_matmul_transpose_a(
+    %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
+                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d1, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul_transpose_a(
+// CHECK-SAME:   %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  inner_dims_pos = [2, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
+
+// -----
+
+func.func @block_matmul_transpose_b(
+    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
+                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul_transpose_b(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [1, 0] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
+
+// -----
+
+func.func @block_batch_matmul_transpose_b(
+    %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
+                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul_transpose_b(
+// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [2, 1] inner_tiles = [64, 16]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>

>From 204f379114476452ca589a58c8a24dda2168e705 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 3 May 2024 18:00:12 +0200
Subject: [PATCH 10/22] Formatting

---
 mlir/include/mlir/Dialect/Linalg/Passes.td               | 4 ++--
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 5 ++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index b783f95fbb5538..fa0cdc70a97eb9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -159,11 +159,11 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
     The minor blocks remain unchanged.
     The final result is unpacked back to the original layout.
 
-    Given a matmul operation:
+    For example, given a matmul operation:
     ```mlir
       %res = linalg.matmul ins(%A, %B) outs(%C)
     ```
-    the traformation can be represented as:
+    the transformation result can be represented as:
     ```mlir
       %A_packed = pack %A : 2D -> 4D
       %B_packed = pack %B : 2D -> 4D #block_transposed
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 684183989fdfb3..719786f241a9ec 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1197,9 +1197,8 @@ using ControlPackMatmulFn =
 /// and the (mb, nb, kb) are the minor blocks of their respective
 /// original 2D dimensions (M, N, K).
 ///
-/// As a part of packing strategy, the RHS operand gets 'block transposed'
-/// i.e., the major blocks [KB][NB] get transposed to [NB][KB] layout.
-/// The minor blocks remain unchanged.
+/// The RHS operand gets 'block transposed' i.e., the major blocks [KB][NB]
+/// get transposed to [NB][KB] layout. The minor blocks remain unchanged.
 /// The final result is unpacked back to the original layout.
 FailureOr<PackResult>
 blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,

>From 6003bddcf572735925c5448a23c3cdcc510f1c9c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 3 May 2024 18:00:41 +0200
Subject: [PATCH 11/22] Fix transposition for matmul variants

---
 .../Linalg/Transforms/BlockPackMatmul.cpp     | 59 ++++++++++++++-----
 1 file changed, 43 insertions(+), 16 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index c38423b12288ba..7bad2d5c4b6a7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -120,6 +120,8 @@ linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
                                        "expect packing full tiles only");
   }
 
+  bool isTransposedLhs = isa<linalg::MatmulTransposeAOp>(matmulOp) ||
+                         isa<linalg::BatchMatmulTransposeAOp>(matmulOp);
   bool isTransposedRhs = isa<linalg::MatmulTransposeBOp>(matmulOp) ||
                          isa<linalg::BatchMatmulTransposeBOp>(matmulOp);
 
@@ -141,32 +143,57 @@ linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
          "failed matmul unpacking");
 
-  SmallVector<int64_t> innerPerm{1, 0};
-  SmallVector<int64_t> outerPerm{1, 0};
-  // No need to block transpose if the RHS matrix is already transposed.
-  if (isTransposedRhs)
-    outerPerm = {0, 1};
-
-  // Leave the batch dimension as is.
-  if (isBatchMatmulOp) {
+  auto applyBatchDim = [&](ArrayRef<int64_t> perms) -> SmallVector<int64_t> {
     // Account for the batch dimension.
-    SmallVector<int64_t> newOuterPerms{0};
+    SmallVector<int64_t> newPerms{0};
     // Offset all permutations.
-    for (auto perm : outerPerm)
-      newOuterPerms.push_back(++perm);
-    outerPerm = newOuterPerms;
+    for (auto perm : perms)
+      newPerms.push_back(++perm);
+    return newPerms;
+  };
+
+  // If needed, block transpose the packed matmul i.e., transpose the outer
+  // dimensions. The inner dimensions (minor blocks) remain unchanged.
+  if (isTransposedLhs) {
+    // The inner blocks' layout is already correctly enforced by the initial
+    // packing.
+    SmallVector<int64_t> lhsInnerPerm{0, 1};
+    // Only block transpose the outer dimensions for LHS matrix.
+    SmallVector<int64_t> lhsOuterPerm{1, 0};
+    // Leave the batch dimension as is.
+    if (isBatchMatmulOp)
+      lhsOuterPerm = applyBatchDim(lhsOuterPerm);
+
+    FailureOr<PackTransposeResult> packedMatmul =
+        packTranspose(rewriter, packedCanonicalMatmul->packOps[0],
+                      packedCanonicalMatmul->packedLinalgOp,
+                      /*maybeUnPackOp=*/nullptr, lhsOuterPerm, lhsInnerPerm);
+    if (failed(packedMatmul))
+      return failure();
+
+    packedCanonicalMatmul->packOps[0] = packedMatmul->transposedPackOp;
+    packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
   }
 
-  // Block transpose the packed matmul i.e., transpose the outer dimensions
-  // layout of the RHS matrix. The inner dimensions (minor blocks) remain
-  // unchanged.
+  // Transpose the layout of the inner dimension (minor blocks).
+  SmallVector<int64_t> rhsInnerPerm{1, 0};
+  // Block transpose the RHS matrix i.e., transpose the outer dimensions.
+  SmallVector<int64_t> rhsOuterPerm{1, 0};
+  // No need to block transpose if the RHS matrix is already transposed.
+  if (isTransposedRhs)
+    rhsOuterPerm = {0, 1};
+  // Leave the batch dimension as is.
+  if (isBatchMatmulOp)
+    rhsOuterPerm = applyBatchDim(rhsOuterPerm);
+
   FailureOr<PackTransposeResult> packedMatmul =
       packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
                     packedCanonicalMatmul->packedLinalgOp,
-                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+                    /*maybeUnPackOp=*/nullptr, rhsOuterPerm, rhsInnerPerm);
   if (failed(packedMatmul))
     return failure();
 
+  packedCanonicalMatmul->packOps[1] = packedMatmul->transposedPackOp;
   packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
 
   return packedCanonicalMatmul;

>From cf6ca57c3539f53ab8bac3512d3e767da53db2a8 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 6 May 2024 16:42:50 +0200
Subject: [PATCH 12/22] Cleanup test

---
 mlir/test/Dialect/Linalg/block-pack-matmul.mlir | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index 481c018f5876a8..6f134a494e8d0c 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -106,10 +106,9 @@ func.func @block_matmul_with_consumer(
 
 func.func @block_batch_matmul(
     %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
-  %0 = tensor.empty() : tensor<512x64x64xf32>
-  %1 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+  %0 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
                            outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
-  return %1 : tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
 }
 
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>

>From 22069e21fc94295f1017b0f2abc81f34b29d66fb Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 6 May 2024 16:43:09 +0200
Subject: [PATCH 13/22] WIP better options

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  22 ++-
 .../Dialect/Linalg/Transforms/Transforms.h    |  22 ++-
 .../Linalg/Transforms/BlockPackMatmul.cpp     | 186 ++++++++++++++----
 3 files changed, 185 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index fa0cdc70a97eb9..60d8b5bdec0947 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -179,13 +179,25 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
   let options = [
     ListOption<"blockFactors", "block-factors", "int64_t",
                "Block factors (mb, nb, kb) for relayout">,
-    ListOption<"mnkOrder", "mnk-order", "int64_t",
-               "Permutation of (mb, nb, kb) dimensions order">,
-    ListOption<"mnkPaddedSizesNextMultipleOf", "mnk-padded-multiples", "int64_t",
-               "Packing sizes next multiple">,
     Option<"allowPadding", "allow-padding", "bool",
            /*default=*/"true",
-           "Allow packing padding">
+           "Allow packing padding">,
+    ListOption<"mnkPaddedSizesNextMultipleOf", "mnk-padded-multiples", "int64_t",
+               "Next multiple of the packing sizes">,
+    ListOption<"mnkOrder", "mnk-order", "int64_t",
+               "Permutation of matmul (M, N, K) dimensions order">,
+    Option<"lhsTransposeOuterBlocks", "lhs-transpose-outer-blocks", "bool",
+           /*default=*/"false",
+           "Transpose LHS outer block layout [MB][KB] -> [KB][MB]">,
+    Option<"lhsTransposeInnerBlocks", "lhs-transpose-inner-blocks", "bool",
+           /*default=*/"false",
+           "Transpose LHS inner block layout [mb][kb] -> [kb][mb]">,
+    Option<"rhsTransposeOuterBlocks", "rhs-transpose-outer-blocks", "bool",
+           /*default=*/"true",
+           "Transpose RHS outer block layout [KB][NB] -> [NB][KB]">,
+    Option<"rhsTransposeInnerBlocks", "rhs-transpose-inner-blocks", "bool",
+           /*default=*/"true",
+           "Transpose RHS inner block layout [kb][nb] -> [nb][kb]">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 719786f241a9ec..c63422f2b14800 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1167,14 +1167,26 @@ struct PackMatmulOptions {
   /// the parallel dimensions and kb is the reduction dimension.
   SmallVector<int64_t, 3> blockFactors;
 
-  /// Order of the packed dimensions (mb, nb, kb).
-  SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
-
-  SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
-
   /// If true, allows packing of dimensions that only partially fit into the
   /// block factors.
   bool allowPadding = true;
+
+  SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
+
+  /// Permutation of matmul (M, N, K) dimensions order.
+  SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
+
+  /// Transpose LHS outer block layout [MB][KB] -> [KB][MB].
+  bool lhsTransposeOuterBlocks = false;
+
+  /// Transpose LHS inner block layout [mb][kb] -> [kb][mb].
+  bool lhsTransposeInnerBlocks = false;
+
+  /// Transpose RHS outer block layout [KB][NB] -> [NB][KB].
+  bool rhsTransposeOuterBlocks = true;
+
+  /// Transpose RHS inner block layout [kb][nb] -> [nb][kb].
+  bool rhsTransposeInnerBlocks = true;
 };
 
 /// Function type which is used to control matmul packing.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 7bad2d5c4b6a7a..d3074677ff59c2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -143,58 +143,170 @@ linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
          "failed matmul unpacking");
 
+  FailureOr<ContractionDimensions> maybeDimensions =
+      inferContractionDims(packedCanonicalMatmul->packedLinalgOp);
+  if (failed(maybeDimensions)) {
+    llvm::errs() << "Failed to infer contraction dims\n";
+  } else {
+    llvm::errs() << "batch: ";
+    for (auto dim : maybeDimensions->batch)
+      llvm::errs() << dim << " ";
+    llvm::errs() << "\n";
+    llvm::errs() << "m: ";
+    for (auto dim : maybeDimensions->m)
+      llvm::errs() << dim << " ";
+    llvm::errs() << "\n";
+    llvm::errs() << "n: ";
+    for (auto dim : maybeDimensions->n)
+      llvm::errs() << dim << " ";
+    llvm::errs() << "\n";
+    llvm::errs() << "k: ";
+    for (auto dim : maybeDimensions->k)
+      llvm::errs() << dim << " ";
+    llvm::errs() << "\n";
+  }
+
+  auto genericOp = dyn_cast<linalg::GenericOp>(
+      packedCanonicalMatmul->packedLinalgOp.getOperation());
+  SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
+
+  AffineMap lhsMap = maps[0];
+  llvm::errs() << "m pos:" << maybeDimensions->m.end()[-2] << "\n";
+  llvm::errs() << "A mat m map: "
+               << lhsMap.getDimPosition(0 + maybeDimensions->batch.size())
+               << "\n";
+  llvm::errs() << "k pos:" << maybeDimensions->k.end()[-2] << "\n";
+  llvm::errs() << "A mat k dim: "
+               << lhsMap.getDimPosition(1 + maybeDimensions->batch.size())
+               << "\n";
+
+  unsigned int batchOffset = maybeDimensions->batch.size();
+  bool isLhsOuterTransposed =
+      lhsMap.getDimPosition(0 + batchOffset) != maybeDimensions->m.end()[-2];
+  bool isLhsInnerTransposed =
+      lhsMap.getDimPosition(2 + batchOffset) != maybeDimensions->m.back();
+
   auto applyBatchDim = [&](ArrayRef<int64_t> perms) -> SmallVector<int64_t> {
     // Account for the batch dimension.
-    SmallVector<int64_t> newPerms{0};
+    SmallVector<int64_t> newPerms;
+    for (auto i : llvm::seq<unsigned>(0, batchOffset))
+      newPerms.push_back(0);
     // Offset all permutations.
     for (auto perm : perms)
-      newPerms.push_back(++perm);
+      newPerms.push_back(perm + batchOffset);
     return newPerms;
   };
 
   // If needed, block transpose the packed matmul i.e., transpose the outer
   // dimensions. The inner dimensions (minor blocks) remain unchanged.
-  if (isTransposedLhs) {
-    // The inner blocks' layout is already correctly enforced by the initial
-    // packing.
-    SmallVector<int64_t> lhsInnerPerm{0, 1};
-    // Only block transpose the outer dimensions for LHS matrix.
-    SmallVector<int64_t> lhsOuterPerm{1, 0};
-    // Leave the batch dimension as is.
-    if (isBatchMatmulOp)
-      lhsOuterPerm = applyBatchDim(lhsOuterPerm);
-
-    FailureOr<PackTransposeResult> packedMatmul =
-        packTranspose(rewriter, packedCanonicalMatmul->packOps[0],
-                      packedCanonicalMatmul->packedLinalgOp,
-                      /*maybeUnPackOp=*/nullptr, lhsOuterPerm, lhsInnerPerm);
-    if (failed(packedMatmul))
-      return failure();
+  // Transpose the LHS inner blocks if they differ from the requested
+  // layout.
+  SmallVector<int64_t> lhsInnerPerm{0, 1};
+  if (isLhsInnerTransposed != options->lhsTransposeInnerBlocks)
+    lhsInnerPerm = {1, 0};
+
+  // Transpose the LHS outer blocks if they differ from the requested layout.
+  SmallVector<int64_t> lhsOuterPerm{0, 1};
+  if (isLhsOuterTransposed != options->lhsTransposeOuterBlocks)
+    lhsOuterPerm = {1, 0};
+  // Leave the batch dimension as is.
+  if (isBatchMatmulOp)
+    lhsOuterPerm = applyBatchDim(lhsOuterPerm);
 
-    packedCanonicalMatmul->packOps[0] = packedMatmul->transposedPackOp;
-    packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
-  }
+  FailureOr<PackTransposeResult> packedLhs =
+      packTranspose(rewriter, packedCanonicalMatmul->packOps[0],
+                    packedCanonicalMatmul->packedLinalgOp,
+                    /*maybeUnPackOp=*/nullptr, lhsOuterPerm, lhsInnerPerm);
+  if (failed(packedLhs))
+    return failure();
 
-  // Transpose the layout of the inner dimension (minor blocks).
-  SmallVector<int64_t> rhsInnerPerm{1, 0};
-  // Block transpose the RHS matrix i.e., transpose the outer dimensions.
-  SmallVector<int64_t> rhsOuterPerm{1, 0};
-  // No need to block transpose if the RHS matrix is already transposed.
-  if (isTransposedRhs)
-    rhsOuterPerm = {0, 1};
+  packedCanonicalMatmul->packOps[0] = packedLhs->transposedPackOp;
+  packedCanonicalMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
+
+  AffineMap rhsMap = maps[1];
+  bool isRhsOuterTransposed =
+      rhsMap.getDimPosition(0 + batchOffset) != maybeDimensions->k.end()[-2];
+  bool isRhsInnerTransposed =
+      rhsMap.getDimPosition(2 + batchOffset) != maybeDimensions->k.back();
+
+  SmallVector<int64_t> rhsInnerPerm{0, 1};
+  if (isRhsInnerTransposed != options->rhsTransposeInnerBlocks)
+    rhsInnerPerm = {1, 0};
+
+  // Transpose the RHS outer blocks if they differ from the requested layout.
+  SmallVector<int64_t> rhsOuterPerm{0, 1};
+  if (isRhsOuterTransposed != options->rhsTransposeOuterBlocks)
+    rhsOuterPerm = {1, 0};
   // Leave the batch dimension as is.
   if (isBatchMatmulOp)
     rhsOuterPerm = applyBatchDim(rhsOuterPerm);
 
-  FailureOr<PackTransposeResult> packedMatmul =
+  FailureOr<PackTransposeResult> packedRhs =
       packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
                     packedCanonicalMatmul->packedLinalgOp,
                     /*maybeUnPackOp=*/nullptr, rhsOuterPerm, rhsInnerPerm);
-  if (failed(packedMatmul))
+  if (failed(packedRhs))
     return failure();
 
-  packedCanonicalMatmul->packOps[1] = packedMatmul->transposedPackOp;
-  packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
+  packedCanonicalMatmul->packOps[1] = packedRhs->transposedPackOp;
+  packedCanonicalMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
+
+  // auto applyBatchDim = [&](ArrayRef<int64_t> perms) -> SmallVector<int64_t>
+  // {
+  //   // Account for the batch dimension.
+  //   SmallVector<int64_t> newPerms{0};
+  //   // Offset all permutations.
+  //   for (auto perm : perms)
+  //     newPerms.push_back(++perm);
+  //   return newPerms;
+  // };
+
+  // // If needed, block transpose the packed matmul i.e., transpose the outer
+  // // dimensions. The inner dimensions (minor blocks) remain unchanged.
+  // if (isTransposedLhs) {
+  //   // The inner blocks' layout is already correctly enforced by the
+  //   initial
+  //   // packing.
+  //   SmallVector<int64_t> lhsInnerPerm{0, 1};
+  //   // Only block transpose the outer dimensions for LHS matrix.
+  //   SmallVector<int64_t> lhsOuterPerm{1, 0};
+  //   // Leave the batch dimension as is.
+  //   if (isBatchMatmulOp)
+  //     lhsOuterPerm = applyBatchDim(lhsOuterPerm);
+
+  //   FailureOr<PackTransposeResult> packedMatmul =
+  //       packTranspose(rewriter, packedCanonicalMatmul->packOps[0],
+  //                     packedCanonicalMatmul->packedLinalgOp,
+  //                     /*maybeUnPackOp=*/nullptr, lhsOuterPerm,
+  //                     lhsInnerPerm);
+  //   if (failed(packedMatmul))
+  //     return failure();
+
+  //   packedCanonicalMatmul->packOps[0] = packedMatmul->transposedPackOp;
+  //   packedCanonicalMatmul->packedLinalgOp =
+  //   packedMatmul->transposedLinalgOp;
+  // }
+
+  // // Transpose the layout of the inner dimension (minor blocks).
+  // SmallVector<int64_t> rhsInnerPerm{1, 0};
+  // // Block transpose the RHS matrix i.e., transpose the outer dimensions.
+  // SmallVector<int64_t> rhsOuterPerm{1, 0};
+  // // No need to block transpose if the RHS matrix is already transposed.
+  // if (isTransposedRhs)
+  //   rhsOuterPerm = {0, 1};
+  // // Leave the batch dimension as is.
+  // if (isBatchMatmulOp)
+  //   rhsOuterPerm = applyBatchDim(rhsOuterPerm);
+
+  // FailureOr<PackTransposeResult> packedMatmul =
+  //     packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
+  //                   packedCanonicalMatmul->packedLinalgOp,
+  //                   /*maybeUnPackOp=*/nullptr, rhsOuterPerm, rhsInnerPerm);
+  // if (failed(packedMatmul))
+  //   return failure();
+
+  // packedCanonicalMatmul->packOps[1] = packedMatmul->transposedPackOp;
+  // packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
 
   return packedCanonicalMatmul;
 }
@@ -232,11 +344,15 @@ struct LinalgBlockPackMatmul
         [&](linalg::LinalgOp op) -> PackMatmulOptions {
       PackMatmulOptions options;
       options.blockFactors = SmallVector<int64_t>{*blockFactors};
-      if (!mnkOrder.empty())
-        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
+      options.allowPadding = allowPadding;
       options.mnkPaddedSizesNextMultipleOf =
           SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
-      options.allowPadding = allowPadding;
+      if (!mnkOrder.empty())
+        options.mnkOrder = SmallVector<int64_t>{*mnkOrder};
+      options.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
+      options.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
+      options.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
+      options.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
       return options;
     };
 

>From 825228a41e2d315f419ff1fee11356044f860806 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 6 May 2024 21:06:28 +0200
Subject: [PATCH 14/22] WIP better options 2

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   2 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +-
 .../Linalg/Transforms/BlockPackMatmul.cpp     | 301 ++++++------------
 3 files changed, 108 insertions(+), 200 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 60d8b5bdec0947..cb942570caa533 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -183,7 +183,7 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
            /*default=*/"true",
            "Allow packing padding">,
     ListOption<"mnkPaddedSizesNextMultipleOf", "mnk-padded-multiples", "int64_t",
-               "Next multiple of the packing sizes">,
+               "Next multiples of the packing sizes">,
     ListOption<"mnkOrder", "mnk-order", "int64_t",
                "Permutation of matmul (M, N, K) dimensions order">,
     Option<"lhsTransposeOuterBlocks", "lhs-transpose-outer-blocks", "bool",
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index c63422f2b14800..beb904066940f4 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1171,6 +1171,7 @@ struct PackMatmulOptions {
   /// block factors.
   bool allowPadding = true;
 
+  /// Next multiples of the packing sizes.
   SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
 
   /// Permutation of matmul (M, N, K) dimensions order.
@@ -1213,8 +1214,8 @@ using ControlPackMatmulFn =
 /// get transposed to [NB][KB] layout. The minor blocks remain unchanged.
 /// The final result is unpacked back to the original layout.
 FailureOr<PackResult>
-blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-                  const ControlPackMatmulFn &controlPackMatmul);
+blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+                const ControlPackMatmulFn &controlPackMatmul);
 
 /// Rewrite tensor.from_elements to linalg.generic.
 FailureOr<Operation *>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index d3074677ff59c2..33073ace3fa0b5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -41,18 +41,30 @@ static std::optional<int64_t> getConstantRange(const Range &range) {
 }
 
 /// Return true if all dimensions are fully divisible by the respective tiles.
-static bool validateFullTilesOnDims(TilingInterface tileOp,
+static bool validateFullTilesOnDims(linalg::LinalgOp matmulOp,
                                     ArrayRef<OpFoldResult> tiles,
                                     ArrayRef<int64_t> dims) {
   if (dims.size() != tiles.size() || tiles.empty())
     return false;
 
+  FailureOr<ContractionDimensions> contractDims =
+      inferContractionDims(matmulOp);
+  if (failed(contractDims))
+    return false;
+  unsigned batchDimsOffset = contractDims->batch.size();
+
+  // Skip the batch dimension if present.
+  // Offset all dimensions accordingly.
+  SmallVector<int64_t, 3> offsetDims{dims};
+  for (size_t i = 0; i < offsetDims.size(); i++)
+    offsetDims[i] += batchDimsOffset;
+
+  auto tileOp = cast<TilingInterface>(matmulOp.getOperation());
   OpBuilder builder(tileOp);
   OpBuilder::InsertionGuard guard(builder);
-  SmallVector<Range> iterationDomain =
-      cast<TilingInterface>(tileOp.getOperation()).getIterationDomain(builder);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
 
-  for (auto dim : llvm::enumerate(dims)) {
+  for (auto dim : llvm::enumerate(offsetDims)) {
     if (dim.value() >= static_cast<int64_t>(iterationDomain.size()))
       return false;
 
@@ -73,19 +85,60 @@ static bool validateFullTilesOnDims(TilingInterface tileOp,
   return true;
 }
 
+/// Return failure or the packed matmul with one of its operands transposed.
+static FailureOr<PackTransposeResult>
+transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+                      tensor::PackOp packOp, AffineMap operandMap,
+                      ArrayRef<unsigned> blocksStartDimPos,
+                      bool transposeOuterBlocks, bool transposeInnerBlocks,
+                      unsigned outerDimsOffset = 0) {
+  assert(operandMap.getNumDims() >= 4 &&
+         "expected at least 4D prepacked matmul");
+  assert(blocksStartDimPos.size() == 2 &&
+         "expected starting outer and inner block positions");
+
+  // Base result positions of the outer and inner block dims in the 4D map.
+  unsigned outerBlockPos = 0;
+  unsigned innerBlockPos = 2;
+
+  // Transpose control options define the desired block and element layout.
+  // Block transposition (outer dimensions) or element transposition (inner
+  // dimensions) may not be necessary depending on the original matmul data
+  // layout.
+  bool isOuterTransposed =
+      operandMap.getDimPosition(outerBlockPos + outerDimsOffset) !=
+      blocksStartDimPos.end()[-2];
+  bool isInnerTransposed =
+      operandMap.getDimPosition(innerBlockPos + outerDimsOffset) !=
+      blocksStartDimPos.back();
+
+  // Transpose only the dimensions that need it to conform to the requested
+  // transposition settings.
+  SmallVector<int64_t> innerPerm{0, 1};
+  if (isInnerTransposed != transposeInnerBlocks)
+    innerPerm = {1, 0};
+  SmallVector<int64_t> outerPerm{0, 1};
+  if (isOuterTransposed != transposeOuterBlocks)
+    outerPerm = {1, 0};
+
+  // Leading outer dims (e.g. batch) keep identity order; offset the rest.
+  SmallVector<int64_t> offsetPerms =
+      llvm::to_vector(llvm::seq<int64_t>(0, outerDimsOffset));
+  for (auto perm : outerPerm)
+    offsetPerms.push_back(perm + outerDimsOffset);
+  outerPerm = offsetPerms;
+
+  FailureOr<PackTransposeResult> packTransposedMatmul =
+      packTranspose(rewriter, packOp, matmulOp,
+                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+
+  return packTransposedMatmul;
+}
+
 /// Pack a matmul operation into blocked 4D layout.
 FailureOr<PackResult>
-linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-                          const ControlPackMatmulFn &controlPackMatmul) {
-  if (!(isa<linalg::MatmulOp>(matmulOp) ||
-        isa<linalg::BatchMatmulOp>(matmulOp) ||
-        isa<linalg::MatmulTransposeAOp>(matmulOp) ||
-        isa<linalg::MatmulTransposeBOp>(matmulOp) ||
-        isa<linalg::BatchMatmulTransposeAOp>(matmulOp) ||
-        isa<linalg::BatchMatmulTransposeBOp>(matmulOp))) {
-    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
-  }
-
+linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+                        const ControlPackMatmulFn &controlPackMatmul) {
   if (matmulOp.hasDynamicShape())
     return rewriter.notifyMatchFailure(matmulOp, "require static shape");
 
@@ -102,29 +155,13 @@ linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   SmallVector<OpFoldResult> mnkTiles =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
 
-  SmallVector<int64_t, 3> dims{options->mnkOrder};
-  // Skip the batch dimension if present.
-  bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp) ||
-                         isa<linalg::BatchMatmulTransposeAOp>(matmulOp) ||
-                         isa<linalg::BatchMatmulTransposeBOp>(matmulOp);
-  if (isBatchMatmulOp) {
-    // Offset all dimensions.
-    for (size_t i = 0; i < dims.size(); i++)
-      ++dims[i];
-  }
-
+  // If padding is disabled, make sure that dimensions can be packed cleanly.
   if (!options->allowPadding &&
-      !validateFullTilesOnDims(cast<TilingInterface>(matmulOp.getOperation()),
-                               mnkTiles, dims)) {
+      !validateFullTilesOnDims(matmulOp, mnkTiles, options->mnkOrder)) {
     return rewriter.notifyMatchFailure(matmulOp,
                                        "expect packing full tiles only");
   }
 
-  bool isTransposedLhs = isa<linalg::MatmulTransposeAOp>(matmulOp) ||
-                         isa<linalg::BatchMatmulTransposeAOp>(matmulOp);
-  bool isTransposedRhs = isa<linalg::MatmulTransposeBOp>(matmulOp) ||
-                         isa<linalg::BatchMatmulTransposeBOp>(matmulOp);
-
   OpBuilder::InsertionGuard guard(rewriter);
   // The op is replaced, we need to set the insertion point after it.
   rewriter.setInsertionPointAfter(matmulOp);
@@ -133,182 +170,52 @@ linalg::blockPackMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   // subdivision:
   //   - major 2D blocks - outer dimensions, consist of minor blocks
   //   - minor 2D blocks - inner dimensions, consist of scalar elements
-  FailureOr<PackResult> packedCanonicalMatmul = packMatmulGreedily(
+  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
       rewriter, matmulOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
       options->mnkOrder);
-  if (failed(packedCanonicalMatmul))
+  if (failed(packedMatmul))
     return failure();
 
-  assert(packedCanonicalMatmul->packOps.size() == 3 && "failed matmul packing");
-  assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
-         "failed matmul unpacking");
-
-  FailureOr<ContractionDimensions> maybeDimensions =
-      inferContractionDims(packedCanonicalMatmul->packedLinalgOp);
-  if (failed(maybeDimensions)) {
-    llvm::errs() << "Failed to infer contraction dims\n";
-  } else {
-    llvm::errs() << "batch: ";
-    for (auto dim : maybeDimensions->batch)
-      llvm::errs() << dim << " ";
-    llvm::errs() << "\n";
-    llvm::errs() << "m: ";
-    for (auto dim : maybeDimensions->m)
-      llvm::errs() << dim << " ";
-    llvm::errs() << "\n";
-    llvm::errs() << "n: ";
-    for (auto dim : maybeDimensions->n)
-      llvm::errs() << dim << " ";
-    llvm::errs() << "\n";
-    llvm::errs() << "k: ";
-    for (auto dim : maybeDimensions->k)
-      llvm::errs() << dim << " ";
-    llvm::errs() << "\n";
-  }
+  assert(packedMatmul->packOps.size() == 3 &&
+         "invalid number of pack ops after matmul packing");
+  assert(packedMatmul->unPackOps.size() == 1 &&
+         "invalid number of unpack ops after matmul packing");
 
-  auto genericOp = dyn_cast<linalg::GenericOp>(
-      packedCanonicalMatmul->packedLinalgOp.getOperation());
+  FailureOr<ContractionDimensions> contractDims =
+      inferContractionDims(packedMatmul->packedLinalgOp);
+  if (failed(contractDims))
+    return failure();
+  unsigned batchDimsOffset = contractDims->batch.size();
+
+  auto genericOp =
+      dyn_cast<linalg::GenericOp>(packedMatmul->packedLinalgOp.getOperation());
   SmallVector<AffineMap> maps = genericOp.getIndexingMapsArray();
 
-  AffineMap lhsMap = maps[0];
-  llvm::errs() << "m pos:" << maybeDimensions->m.end()[-2] << "\n";
-  llvm::errs() << "A mat m map: "
-               << lhsMap.getDimPosition(0 + maybeDimensions->batch.size())
-               << "\n";
-  llvm::errs() << "k pos:" << maybeDimensions->k.end()[-2] << "\n";
-  llvm::errs() << "A mat k dim: "
-               << lhsMap.getDimPosition(1 + maybeDimensions->batch.size())
-               << "\n";
-
-  unsigned int batchOffset = maybeDimensions->batch.size();
-  bool isLhsOuterTransposed =
-      lhsMap.getDimPosition(0 + batchOffset) != maybeDimensions->m.end()[-2];
-  bool isLhsInnerTransposed =
-      lhsMap.getDimPosition(2 + batchOffset) != maybeDimensions->m.back();
-
-  auto applyBatchDim = [&](ArrayRef<int64_t> perms) -> SmallVector<int64_t> {
-    // Account for the batch dimension.
-    SmallVector<int64_t> newPerms;
-    for (auto i : llvm::seq<unsigned>(0, batchOffset))
-      newPerms.push_back(0);
-    // Offset all permutations.
-    for (auto perm : perms)
-      newPerms.push_back(perm + batchOffset);
-    return newPerms;
-  };
-
-  // If needed, block transpose the packed matmul i.e., transpose the outer
-  // dimensions. The inner dimensions (minor blocks) remain unchanged.
-  // The inner blocks' layout is already correctly enforced by the initial
-  // packing.
-  SmallVector<int64_t> lhsInnerPerm{0, 1};
-  if (isLhsInnerTransposed != options->lhsTransposeInnerBlocks)
-    lhsInnerPerm = {1, 0};
-
-  // Only block transpose the outer dimensions for LHS matrix.
-  SmallVector<int64_t> lhsOuterPerm{0, 1};
-  if (isLhsOuterTransposed != options->lhsTransposeOuterBlocks)
-    lhsOuterPerm = {1, 0};
-  // Leave the batch dimension as is.
-  if (isBatchMatmulOp)
-    lhsOuterPerm = applyBatchDim(lhsOuterPerm);
-
-  FailureOr<PackTransposeResult> packedLhs =
-      packTranspose(rewriter, packedCanonicalMatmul->packOps[0],
-                    packedCanonicalMatmul->packedLinalgOp,
-                    /*maybeUnPackOp=*/nullptr, lhsOuterPerm, lhsInnerPerm);
+  // Transpose LHS matrix according to the options.
+  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
+      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
+      contractDims->m, options->lhsTransposeOuterBlocks,
+      options->lhsTransposeInnerBlocks, batchDimsOffset);
   if (failed(packedLhs))
     return failure();
 
-  packedCanonicalMatmul->packOps[0] = packedLhs->transposedPackOp;
-  packedCanonicalMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
-
-  AffineMap rhsMap = maps[1];
-  bool isRhsOuterTransposed =
-      rhsMap.getDimPosition(0 + batchOffset) != maybeDimensions->k.end()[-2];
-  bool isRhsInnerTransposed =
-      rhsMap.getDimPosition(2 + batchOffset) != maybeDimensions->k.back();
-
-  SmallVector<int64_t> rhsInnerPerm{0, 1};
-  if (isRhsInnerTransposed != options->rhsTransposeInnerBlocks)
-    rhsInnerPerm = {1, 0};
-
-  // Only block transpose the outer dimensions for LHS matrix.
-  SmallVector<int64_t> rhsOuterPerm{0, 1};
-  if (isRhsOuterTransposed != options->rhsTransposeOuterBlocks)
-    rhsOuterPerm = {1, 0};
-  // Leave the batch dimension as is.
-  if (isBatchMatmulOp)
-    rhsOuterPerm = applyBatchDim(rhsOuterPerm);
-
-  FailureOr<PackTransposeResult> packedRhs =
-      packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
-                    packedCanonicalMatmul->packedLinalgOp,
-                    /*maybeUnPackOp=*/nullptr, rhsOuterPerm, rhsInnerPerm);
+  // Update results.
+  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
+  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
+
+  // Transpose RHS matrix according to the options.
+  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
+      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
+      contractDims->k, options->rhsTransposeOuterBlocks,
+      options->rhsTransposeInnerBlocks, batchDimsOffset);
   if (failed(packedRhs))
     return failure();
 
-  packedCanonicalMatmul->packOps[1] = packedRhs->transposedPackOp;
-  packedCanonicalMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
-
-  // auto applyBatchDim = [&](ArrayRef<int64_t> perms) -> SmallVector<int64_t>
-  // {
-  //   // Account for the batch dimension.
-  //   SmallVector<int64_t> newPerms{0};
-  //   // Offset all permutations.
-  //   for (auto perm : perms)
-  //     newPerms.push_back(++perm);
-  //   return newPerms;
-  // };
-
-  // // If needed, block transpose the packed matmul i.e., transpose the outer
-  // // dimensions. The inner dimensions (minor blocks) remain unchanged.
-  // if (isTransposedLhs) {
-  //   // The inner blocks' layout is already correctly enforced by the
-  //   initial
-  //   // packing.
-  //   SmallVector<int64_t> lhsInnerPerm{0, 1};
-  //   // Only block transpose the outer dimensions for LHS matrix.
-  //   SmallVector<int64_t> lhsOuterPerm{1, 0};
-  //   // Leave the batch dimension as is.
-  //   if (isBatchMatmulOp)
-  //     lhsOuterPerm = applyBatchDim(lhsOuterPerm);
-
-  //   FailureOr<PackTransposeResult> packedMatmul =
-  //       packTranspose(rewriter, packedCanonicalMatmul->packOps[0],
-  //                     packedCanonicalMatmul->packedLinalgOp,
-  //                     /*maybeUnPackOp=*/nullptr, lhsOuterPerm,
-  //                     lhsInnerPerm);
-  //   if (failed(packedMatmul))
-  //     return failure();
-
-  //   packedCanonicalMatmul->packOps[0] = packedMatmul->transposedPackOp;
-  //   packedCanonicalMatmul->packedLinalgOp =
-  //   packedMatmul->transposedLinalgOp;
-  // }
-
-  // // Transpose the layout of the inner dimension (minor blocks).
-  // SmallVector<int64_t> rhsInnerPerm{1, 0};
-  // // Block transpose the RHS matrix i.e., transpose the outer dimensions.
-  // SmallVector<int64_t> rhsOuterPerm{1, 0};
-  // // No need to block transpose if the RHS matrix is already transposed.
-  // if (isTransposedRhs)
-  //   rhsOuterPerm = {0, 1};
-  // // Leave the batch dimension as is.
-  // if (isBatchMatmulOp)
-  //   rhsOuterPerm = applyBatchDim(rhsOuterPerm);
-
-  // FailureOr<PackTransposeResult> packedMatmul =
-  //     packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
-  //                   packedCanonicalMatmul->packedLinalgOp,
-  //                   /*maybeUnPackOp=*/nullptr, rhsOuterPerm, rhsInnerPerm);
-  // if (failed(packedMatmul))
-  //   return failure();
-
-  // packedCanonicalMatmul->packOps[1] = packedMatmul->transposedPackOp;
-  // packedCanonicalMatmul->packedLinalgOp = packedMatmul->transposedLinalgOp;
-
-  return packedCanonicalMatmul;
+  // Update results.
+  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
+  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
+
+  return packedMatmul;
 }
 
 namespace {
@@ -321,7 +228,7 @@ struct BlockPackMatmul : public OpRewritePattern<OpTy> {
   LogicalResult matchAndRewrite(OpTy matmulOp,
                                 PatternRewriter &rewriter) const override {
     FailureOr<PackResult> packedMatmul =
-        blockPackMatmulOp(rewriter, matmulOp, controlFn);
+        blockPackMatmul(rewriter, matmulOp, controlFn);
     if (failed(packedMatmul))
       return failure();
     return success();

>From ac9b8a3a8cb834dbcdc8fd84bc6106e556e52b1b Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 6 May 2024 22:57:58 +0200
Subject: [PATCH 15/22] Update tests

---
 .../Dialect/Linalg/block-pack-matmul.mlir     | 80 +++++++++----------
 1 file changed, 40 insertions(+), 40 deletions(-)

diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index 6f134a494e8d0c..e855b42f732558 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -8,19 +8,19 @@ func.func @block_matmul(
 }
 
 // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
 
 // CHECK-LABEL: func @block_matmul(
 // CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
 // CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
 // CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
-// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
 // CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x16x64xf32>
 // CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
 // CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
 // CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
@@ -28,7 +28,7 @@ func.func @block_matmul(
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
@@ -49,7 +49,7 @@ func.func @block_matmul_with_constant(
 // CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
 // CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
@@ -72,7 +72,7 @@ func.func @block_matmul_with_producer(
 // CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
 // CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
-// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
@@ -112,19 +112,19 @@ func.func @block_batch_matmul(
 }
 
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
 
 // CHECK-LABEL: func @block_batch_matmul(
 // CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
 // CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
 // CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
-// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
 // CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
 // CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
 // CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
 // CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
@@ -132,7 +132,7 @@ func.func @block_batch_matmul(
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
@@ -147,20 +147,20 @@ func.func @block_matmul_transpose_a(
   return %0 : tensor<64x64xf32>
 }
 
-// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
 
 // CHECK-LABEL: func @block_matmul_transpose_a(
 // CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
 // CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
 // CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
-// CHECK-SAME:  inner_dims_pos = [1, 0] inner_tiles = [32, 64]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
 // CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x64x16xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x16x64xf32>
 // CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x64x16xf32>
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
 // CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
 // CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
@@ -168,7 +168,7 @@ func.func @block_matmul_transpose_a(
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
@@ -183,20 +183,20 @@ func.func @block_batch_matmul_transpose_a(
   return %0 : tensor<512x64x64xf32>
 }
 
-// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d1, d4, d6)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
 
 // CHECK-LABEL: func @block_batch_matmul_transpose_a(
 // CHECK-SAME:   %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
 // CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
 // CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
-// CHECK-SAME:  inner_dims_pos = [2, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
 // CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
 // CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
 // CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
 // CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
@@ -204,7 +204,7 @@ func.func @block_batch_matmul_transpose_a(
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
@@ -220,19 +220,19 @@ func.func @block_matmul_transpose_b(
 }
 
 // CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
 // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
 
 // CHECK-LABEL: func @block_matmul_transpose_b(
 // CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
 // CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
 // CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
-// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
 // CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x64x16xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x16x64xf32>
 // CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [1, 0] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x64x16xf32>
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
 // CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
 // CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
@@ -240,7 +240,7 @@ func.func @block_matmul_transpose_b(
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
@@ -256,19 +256,19 @@ func.func @block_batch_matmul_transpose_b(
 }
 
 // CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
 // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
 
 // CHECK-LABEL: func @block_batch_matmul_transpose_b(
 // CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
 // CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
 // CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
-// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
 // CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
 // CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
-// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [2, 1] inner_tiles = [64, 16]
-// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
 // CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
 // CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
@@ -276,7 +276,7 @@ func.func @block_batch_matmul_transpose_b(
 // CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
 // CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
 // CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>

>From 91be9091a4f2634ffba248ec667307603846aa52 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 6 May 2024 23:19:01 +0200
Subject: [PATCH 16/22] Update descriptions

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    | 27 +++++++++----------
 .../Dialect/Linalg/Transforms/Transforms.h    | 23 +++++++++-------
 .../Linalg/Transforms/BlockPackMatmul.cpp     | 10 +++----
 3 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index cb942570caa533..ccc6baa8256404 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -149,31 +149,30 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
     - minor 2D blocks - inner dimensions, consist of scalar elements
 
     A 2D matmul MxNxK gets reshaped into blocked 4D representation
-    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
     where the (MB, NB, KB) dimensions represent the major blocks,
     and the (mb, nb, kb) are the minor blocks of their respective
     original 2D dimensions (M, N, K).
 
-    As a part of packing strategy, the RHS operand gets 'block transposed'
-    i.e., the major blocks [KB][NB] get transposed to [NB][KB] layout.
-    The minor blocks remain unchanged.
-    The final result is unpacked back to the original layout.
+    Depending on the initial operands' data layout and the specified
+    packing options, both the major blocks dimensions might get transposed
+    e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
+    e.g., [mb][kb] -> [kb][mb].
+    Any present batch dimensions remain unchanged.
+    The final result is unpacked back to the original shape.
 
     For example, given a matmul operation:
     ```mlir
       %res = linalg.matmul ins(%A, %B) outs(%C)
     ```
-    the transformation result can be represented as:
+    the default transformation result can be represented as:
     ```mlir
-      %A_packed = pack %A : 2D -> 4D
-      %B_packed = pack %B : 2D -> 4D #block_transposed
-      %C_packed = pack %C : 2D -> 4D
+      %A_packed = pack %A : 2D <MxK> -> 4D <MBxKBxmbxkb>
+      %B_packed = pack %B : 2D <KxN> -> 4D <NBxKBxnbxkb>
+      %C_packed = pack %C : 2D <MxN> -> 4D <MBxNBxmbxnb>
       %res_packed = linalg.mmt4d ins(%A_packed, %B_packed) outs(%C_packed)
-      %res = unpack %res_packed : 4D -> 2D
+      %res = unpack %res_packed : 4D <MBxNBxmbxnb> -> 2D <MxN>
     ```
-
-    This packed data arrangement minimizes distance between consecutive
-    blocks which improves spacial locality and cache behavior.
   }];
   let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
   let options = [
@@ -197,7 +196,7 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
            "Transpose RHS outer block layout [KB][NB] -> [NB][KB]">,
     Option<"rhsTransposeInnerBlocks", "rhs-transpose-inner-blocks", "bool",
            /*default=*/"true",
-           "Transpose RHS inner block layout [kb][nb] -> [nb][kb]">,
+           "Transpose RHS inner block layout [kb][nb] -> [nb][kb]">
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index beb904066940f4..9e0a5cb01a8604 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1162,7 +1162,7 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                    ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                    ArrayRef<int64_t> mnkOrder);
 
-struct PackMatmulOptions {
+struct BlockPackMatmulOptions {
   /// Minor block factors (mb, nb, kb) for packing relayout where mb, mn are
   /// the parallel dimensions and kb is the reduction dimension.
   SmallVector<int64_t, 3> blockFactors;
@@ -1194,8 +1194,8 @@ struct PackMatmulOptions {
 /// It is expected to return valid packing configuration for each operation.
 /// Lack of packing options indicates that no valid configuration could be
 /// assigned and the operation will not be packed.
-using ControlPackMatmulFn =
-    std::function<std::optional<PackMatmulOptions>(linalg::LinalgOp)>;
+using ControlBlockPackMatmulFn =
+    std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>;
 
 /// Pack a matmul operation into blocked 4D layout.
 ///
@@ -1205,17 +1205,22 @@ using ControlPackMatmulFn =
 ///   - minor 2D blocks - inner dimensions, consist of scalar elements
 ///
 /// A 2D matmul MxNxK gets reshaped into blocked 4D representation
-/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
 /// where the (MB, NB, KB) dimensions represent the major blocks,
 /// and the (mb, nb, kb) are the minor blocks of their respective
 /// original 2D dimensions (M, N, K).
 ///
-/// The RHS operand gets 'block transposed' i.e., the major blocks [KB][NB]
-/// get transposed to [NB][KB] layout. The minor blocks remain unchanged.
-/// The final result is unpacked back to the original layout.
+/// Depending on the initial operands' data layout and the specified
+/// packing options, both the major blocks dimensions might get transposed
+/// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
+/// e.g., [mb][kb] -> [kb][mb].
+/// Any present batch dimensions remain unchanged.
+/// The final result is unpacked back to the original shape.
+///
+/// Return failure if no valid packing options are provided.
 FailureOr<PackResult>
 blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-                const ControlPackMatmulFn &controlPackMatmul);
+                const ControlBlockPackMatmulFn &controlPackMatmul);
 
 /// Rewrite tensor.from_elements to linalg.generic.
 FailureOr<Operation *>
@@ -1685,7 +1690,7 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 
 /// Patterns to block pack Linalg matmul ops.
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
-                                     const ControlPackMatmulFn &controlFn);
+                                     const ControlBlockPackMatmulFn &controlFn);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 33073ace3fa0b5..23a25e1d23f88d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -145,7 +145,7 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   if (matmulOp.hasPureBufferSemantics())
     return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
 
-  std::optional<PackMatmulOptions> options = controlPackMatmul(matmulOp);
+  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(matmulOp);
   if (!options)
     return rewriter.notifyMatchFailure(matmulOp, "invalid packing options");
 
@@ -247,9 +247,9 @@ struct LinalgBlockPackMatmul
     Operation *op = getOperation();
     RewritePatternSet patterns(&getContext());
 
-    ControlPackMatmulFn controlFn =
-        [&](linalg::LinalgOp op) -> PackMatmulOptions {
-      PackMatmulOptions options;
+    ControlBlockPackMatmulFn controlFn =
+        [&](linalg::LinalgOp op) -> BlockPackMatmulOptions {
+      BlockPackMatmulOptions options;
       options.blockFactors = SmallVector<int64_t>{*blockFactors};
       options.allowPadding = allowPadding;
       options.mnkPaddedSizesNextMultipleOf =
@@ -271,7 +271,7 @@ struct LinalgBlockPackMatmul
 } // namespace
 
 void linalg::populateBlockPackMatmulPatterns(
-    RewritePatternSet &patterns, const ControlPackMatmulFn &controlFn) {
+    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
   patterns.add<BlockPackMatmul<linalg::MatmulOp>,
                BlockPackMatmul<linalg::BatchMatmulOp>,
                BlockPackMatmul<linalg::MatmulTransposeAOp>,

>From 2c4752144388b9ff2da7ea577fe6a81364aa0318 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 6 May 2024 23:29:47 +0200
Subject: [PATCH 17/22] Refactor

---
 mlir/include/mlir/Dialect/Linalg/Passes.td               | 2 +-
 mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h | 2 +-
 mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp   | 6 +++---
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index ccc6baa8256404..0a4ce8953136dd 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -155,7 +155,7 @@ def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
     original 2D dimensions (M, N, K).
 
     Depending on the initial operands' data layout and the specified
-    packing options, both the major blocks dimensions might get transposed
+    packing options, the major blocks dimensions might get transposed
     e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
     e.g., [mb][kb] -> [kb][mb].
     Any present batch dimensions remain unchanged.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 9e0a5cb01a8604..472524c99e7f7e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1211,7 +1211,7 @@ using ControlBlockPackMatmulFn =
 /// original 2D dimensions (M, N, K).
 ///
 /// Depending on the initial operands' data layout and the specified
-/// packing options, both the major blocks dimensions might get transposed
+/// packing options, the major blocks dimensions might get transposed
 /// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
 /// e.g., [mb][kb] -> [kb][mb].
 /// Any present batch dimensions remain unchanged.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 23a25e1d23f88d..41f58a2cc47f21 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -138,7 +138,7 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
 /// Pack a matmul operation into blocked 4D layout.
 FailureOr<PackResult>
 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
-                        const ControlPackMatmulFn &controlPackMatmul) {
+                        const ControlBlockPackMatmulFn &controlPackMatmul) {
   if (matmulOp.hasDynamicShape())
     return rewriter.notifyMatchFailure(matmulOp, "require static shape");
 
@@ -221,7 +221,7 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
 namespace {
 template <typename OpTy>
 struct BlockPackMatmul : public OpRewritePattern<OpTy> {
-  BlockPackMatmul(MLIRContext *context, ControlPackMatmulFn fun,
+  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
                   PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
 
@@ -235,7 +235,7 @@ struct BlockPackMatmul : public OpRewritePattern<OpTy> {
   }
 
 private:
-  ControlPackMatmulFn controlFn;
+  ControlBlockPackMatmulFn controlFn;
 };
 
 /// Convert linalg matmul ops to block layout and back.

>From 9bf77ead15f7cb7392faf88935f6a975eeff26e6 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 May 2024 10:19:17 +0200
Subject: [PATCH 18/22] Refactor

---
 .../Dialect/Linalg/Transforms/Transforms.h    |  2 +-
 .../Linalg/Transforms/BlockPackMatmul.cpp     | 38 +++++++++----------
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 472524c99e7f7e..f77c19ed0fcce9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1219,7 +1219,7 @@ using ControlBlockPackMatmulFn =
 ///
 /// Return failure if no valid packing options are provided.
 FailureOr<PackResult>
-blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                 const ControlBlockPackMatmulFn &controlPackMatmul);
 
 /// Rewrite tensor.from_elements to linalg.generic.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 41f58a2cc47f21..076da215b15dd6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -41,14 +41,14 @@ static std::optional<int64_t> getConstantRange(const Range &range) {
 }
 
 /// Return true if all dimensions are fully divisible by the respective tiles.
-static bool validateFullTilesOnDims(linalg::LinalgOp matmulOp,
+static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
                                     ArrayRef<OpFoldResult> tiles,
                                     ArrayRef<int64_t> dims) {
   if (dims.size() != tiles.size() || tiles.empty())
     return false;
 
   FailureOr<ContractionDimensions> contractDims =
-      inferContractionDims(matmulOp);
+      inferContractionDims(linalgOp);
   if (failed(contractDims))
     return false;
   unsigned batchDimsOffset = contractDims->batch.size();
@@ -59,7 +59,7 @@ static bool validateFullTilesOnDims(linalg::LinalgOp matmulOp,
   for (size_t i = 0; i < offsetDims.size(); i++)
     offsetDims[i] += batchDimsOffset;
 
-  auto tileOp = cast<TilingInterface>(matmulOp.getOperation());
+  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
   OpBuilder builder(tileOp);
   OpBuilder::InsertionGuard guard(builder);
   SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
@@ -87,7 +87,7 @@ static bool validateFullTilesOnDims(linalg::LinalgOp matmulOp,
 
 /// Return failure or packed matmul with one of its operands tranposed.
 static FailureOr<PackTransposeResult>
-transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                       tensor::PackOp packOp, AffineMap operandMap,
                       ArrayRef<unsigned> blocksStartDimPos,
                       bool transposeOuterBlocks, bool transposeInnerBlocks,
@@ -129,7 +129,7 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
   outerPerm = offsetPerms;
 
   FailureOr<PackTransposeResult> packTransposedMatmul =
-      packTranspose(rewriter, packOp, matmulOp,
+      packTranspose(rewriter, packOp, linalgOp,
                     /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
 
   return packTransposedMatmul;
@@ -137,41 +137,41 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
 
 /// Pack a matmul operation into blocked 4D layout.
 FailureOr<PackResult>
-linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                         const ControlBlockPackMatmulFn &controlPackMatmul) {
-  if (matmulOp.hasDynamicShape())
-    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+  if (linalgOp.hasDynamicShape())
+    return rewriter.notifyMatchFailure(linalgOp, "require static shape");
 
-  if (matmulOp.hasPureBufferSemantics())
-    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+  if (linalgOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
 
-  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(matmulOp);
+  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
   if (!options)
-    return rewriter.notifyMatchFailure(matmulOp, "invalid packing options");
+    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
 
   if (options->blockFactors.size() != 3)
-    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
 
   SmallVector<OpFoldResult> mnkTiles =
       getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
 
   // If padding is disabled, make sure that dimensions can be packed cleanly.
   if (!options->allowPadding &&
-      !validateFullTilesOnDims(matmulOp, mnkTiles, options->mnkOrder)) {
-    return rewriter.notifyMatchFailure(matmulOp,
+      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
+    return rewriter.notifyMatchFailure(linalgOp,
                                        "expect packing full tiles only");
   }
 
   OpBuilder::InsertionGuard guard(rewriter);
   // The op is replaced, we need to set the insertion point after it.
-  rewriter.setInsertionPointAfter(matmulOp);
+  rewriter.setInsertionPointAfter(linalgOp);
 
   // Pack the matmul operation into blocked layout with two levels of
   // subdivision:
   //   - major 2D blocks - outer dimensions, consist of minor blocks
   //   - minor 2D blocks - inner dimensions, consist of scalar elements
   FailureOr<PackResult> packedMatmul = packMatmulGreedily(
-      rewriter, matmulOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
+      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
       options->mnkOrder);
   if (failed(packedMatmul))
     return failure();
@@ -225,10 +225,10 @@ struct BlockPackMatmul : public OpRewritePattern<OpTy> {
                   PatternBenefit benefit = 1)
       : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
 
-  LogicalResult matchAndRewrite(OpTy matmulOp,
+  LogicalResult matchAndRewrite(OpTy linalgOp,
                                 PatternRewriter &rewriter) const override {
     FailureOr<PackResult> packedMatmul =
-        blockPackMatmul(rewriter, matmulOp, controlFn);
+        blockPackMatmul(rewriter, linalgOp, controlFn);
     if (failed(packedMatmul))
       return failure();
     return success();

>From 8c97c6ace69069328932be7389a172f332bf44e9 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 May 2024 13:11:29 +0200
Subject: [PATCH 19/22] More tests

---
 .../Linalg/block-pack-matmul-layout.mlir      | 101 ++++++++++++++++++
 .../Linalg/block-pack-matmul-padding.mlir     |  82 ++++++++++++++
 2 files changed, 183 insertions(+)
 create mode 100644 mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
 create mode 100644 mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir

diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
new file mode 100644
index 00000000000000..f740dbf31255c7
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
+// RUN: lhs-transpose-outer-blocks=false lhs-transpose-inner-blocks=false \
+// RUN: rhs-transpose-outer-blocks=true rhs-transpose-inner-blocks=true" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=MMT4D
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
+// RUN: lhs-transpose-outer-blocks=false lhs-transpose-inner-blocks=false \
+// RUN: rhs-transpose-outer-blocks=false rhs-transpose-inner-blocks=false" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=MM4D
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
+// RUN: lhs-transpose-outer-blocks=true lhs-transpose-inner-blocks=true \
+// RUN: rhs-transpose-outer-blocks=false rhs-transpose-inner-blocks=false" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=MTM4D
+
+func.func @block_matmul(
+    %A: tensor<64x128xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<64x128xf32>, tensor<128x64xf32>)
+                      outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+func.func @block_matmul_transpose_a(
+    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
+                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+func.func @block_matmul_transpose_b(
+    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
+                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// MMT4D-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// MMT4D-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// MMT4D-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// MMT4D-LABEL: func @block_matmul
+// MMT4D-COUNT-3: tensor.pack
+// MMT4D: linalg.generic
+// MMT4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MMT4D-COUNT-1: tensor.unpack
+// MMT4D-LABEL: func @block_matmul_transpose_a
+// MMT4D-COUNT-3: tensor.pack
+// MMT4D: linalg.generic
+// MMT4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MMT4D-COUNT-1: tensor.unpack
+// MMT4D-LABEL: func @block_matmul_transpose_b
+// MMT4D-COUNT-3: tensor.pack
+// MMT4D: linalg.generic
+// MMT4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MMT4D-COUNT-1: tensor.unpack
+
+// MM4D-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// MM4D-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+// MM4D-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// MM4D-LABEL: func @block_matmul
+// MM4D-COUNT-3: tensor.pack
+// MM4D: linalg.generic
+// MM4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MM4D-COUNT-1: tensor.unpack
+// MM4D-LABEL: func @block_matmul_transpose_a
+// MM4D-COUNT-3: tensor.pack
+// MM4D: linalg.generic
+// MM4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MM4D-COUNT-1: tensor.unpack
+// MM4D-LABEL: func @block_matmul_transpose_b
+// MM4D-COUNT-3: tensor.pack
+// MM4D: linalg.generic
+// MM4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MM4D-COUNT-1: tensor.unpack
+
+// MTM4D-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5, d3)>
+// MTM4D-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+// MTM4D-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// MTM4D-LABEL: func @block_matmul
+// MTM4D-COUNT-3: tensor.pack
+// MTM4D: linalg.generic
+// MTM4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MTM4D-COUNT-1: tensor.unpack
+// MTM4D-LABEL: func @block_matmul_transpose_a
+// MTM4D-COUNT-3: tensor.pack
+// MTM4D: linalg.generic
+// MTM4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MTM4D-COUNT-1: tensor.unpack
+// MTM4D-LABEL: func @block_matmul_transpose_b
+// MTM4D-COUNT-3: tensor.pack
+// MTM4D: linalg.generic
+// MTM4D-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// MTM4D-COUNT-1: tensor.unpack
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir
new file mode 100644
index 00000000000000..75defd8b884192
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=1" \
+// RUN: -canonicalize | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=0" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=NOPAD
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=1 mnk-padded-multiples=256,512,384" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=PAD-MULT
+
+// M=123, N=124 and K=125 are not multiples of the 32x16x64 (MxNxK) block
+// factors, so packing needs padding and is skipped when padding is off.
+func.func @block_matmul_padding(
+    %A: tensor<123x125xf32>, %B: tensor<125x124xf32>, %C: tensor<123x124xf32>) -> tensor<123x124xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<123x125xf32>, tensor<125x124xf32>)
+                      outs(%C : tensor<123x124xf32>) -> tensor<123x124xf32>
+  return %0 : tensor<123x124xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// CHECK-LABEL: func @block_matmul_padding(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<123x125xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<125x124xf32> -> tensor<8x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<123x124xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<123x124xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<123x124xf32>
+
+// NOPAD-LABEL: func @block_matmul_padding(
+// NOPAD-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
+// NOPAD-NOT: tensor.pack
+// NOPAD: linalg.matmul ins(%[[A]], %[[B]] : tensor<123x125xf32>, tensor<125x124xf32>)
+// NOPAD-SAME: outs(%[[C]] : tensor<123x124xf32>) -> tensor<123x124xf32>
+// NOPAD-NOT: tensor.unpack
+
+// PAD-MULT-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// PAD-MULT-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// PAD-MULT-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// PAD-MULT-LABEL: func @block_matmul_padding(
+// PAD-MULT-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
+// PAD-MULT-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// PAD-MULT: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<1x1x256x384xf32>
+// PAD-MULT: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
+// PAD-MULT-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [256, 384]
+// PAD-MULT-SAME:  into %[[PACK_DST_0]] : tensor<123x125xf32> -> tensor<1x1x256x384xf32>
+// PAD-MULT: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<1x1x512x384xf32>
+// PAD-MULT: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
+// PAD-MULT-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [512, 384]
+// PAD-MULT-SAME:  into %[[PACK_DST_1]] : tensor<125x124xf32> -> tensor<1x1x512x384xf32>
+// PAD-MULT: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<1x1x256x512xf32>
+// PAD-MULT: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
+// PAD-MULT-SAME:  inner_dims_pos = [0, 1] inner_tiles = [256, 512]
+// PAD-MULT-SAME:  into %[[PACK_DST_2]] : tensor<123x124xf32> -> tensor<1x1x256x512xf32>
+// PAD-MULT: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// PAD-MULT-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// PAD-MULT-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// PAD-MULT-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<1x1x256x384xf32>, tensor<1x1x512x384xf32>) outs(%[[C_PACKED]] : tensor<1x1x256x512xf32>)
+// PAD-MULT: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// PAD-MULT-SAME:  inner_dims_pos = [0, 1] inner_tiles = [256, 512]
+// PAD-MULT-SAME:  into %[[C]] : tensor<1x1x256x512xf32> -> tensor<123x124xf32>
+// PAD-MULT: return %[[RES_UNPACKED]] : tensor<123x124xf32>

>From 5d8dead6764fe5474d1b19f28970d98e5e034bf7 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 May 2024 13:27:48 +0200
Subject: [PATCH 20/22] Allow dynamic shapes

---
 .../Linalg/Transforms/BlockPackMatmul.cpp     |  3 -
 .../Dialect/Linalg/block-pack-matmul.mlir     | 57 +++++++++++++++++++
 2 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 076da215b15dd6..d3f3f99b92196b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -139,9 +139,6 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
 FailureOr<PackResult>
 linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                         const ControlBlockPackMatmulFn &controlPackMatmul) {
-  if (linalgOp.hasDynamicShape())
-    return rewriter.notifyMatchFailure(linalgOp, "require static shape");
-
   if (linalgOp.hasPureBufferSemantics())
     return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
 
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index e855b42f732558..b9e5f1c774df80 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -36,6 +36,63 @@ func.func @block_matmul(
 
 // -----
 
+// Fully dynamic shapes: the expected packing computes the outer tile counts
+// at runtime (ceildiv of tensor.dim) and pads the operands with zero.
+func.func @block_matmul_dynamic(
+    %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                      outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP_M:.*]] = affine_map<()[s0] -> (s0 ceildiv 32)>
+// CHECK-DAG: #[[MAP_K:.*]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+// CHECK-DAG: #[[MAP_N:.*]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul_dynamic(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>, %[[B:[0-9a-z]+]]: tensor<?x?xf32>, %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[A_M:.+]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[A_K:.+]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[A_OUTER_TILE_M:.+]] = affine.apply #[[MAP_M]]()[%[[A_M]]]
+// CHECK-DAG: %[[A_OUTER_TILE_K:.+]] = affine.apply #[[MAP_K]]()[%[[A_K]]]
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty(%[[A_OUTER_TILE_M]], %[[A_OUTER_TILE_K]]) : tensor<?x?x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<?x?xf32> -> tensor<?x?x32x64xf32>
+// CHECK-DAG: %[[B_K:.+]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[B_N:.+]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[B_OUTER_TILE_K:.+]] = affine.apply #[[MAP_K]]()[%[[B_K]]]
+// CHECK-DAG: %[[B_OUTER_TILE_N:.+]] = affine.apply #[[MAP_N]]()[%[[B_N]]]
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty(%[[B_OUTER_TILE_N]], %[[B_OUTER_TILE_K]]) : tensor<?x?x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<?x?xf32> -> tensor<?x?x16x64xf32>
+// CHECK-DAG: %[[C_M:.+]] = tensor.dim %[[C]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[C_N:.+]] = tensor.dim %[[C]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[C_OUTER_TILE_M:.+]] = affine.apply #[[MAP_M]]()[%[[C_M]]]
+// CHECK-DAG: %[[C_OUTER_TILE_N:.+]] = affine.apply #[[MAP_N]]()[%[[C_N]]]
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty(%[[C_OUTER_TILE_M]], %[[C_OUTER_TILE_N]]) : tensor<?x?x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<?x?xf32> -> tensor<?x?x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<?x?x32x64xf32>, tensor<?x?x16x64xf32>) outs(%[[C_PACKED]] : tensor<?x?x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<?x?x32x16xf32> -> tensor<?x?xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<?x?xf32>
+
+// -----
+
 func.func @block_matmul_with_constant(
     %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
   %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>

>From 94e4fb77b9e45609e4286117d3b6e0b753cf6146 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 May 2024 14:21:25 +0200
Subject: [PATCH 21/22] Support simple generics

---
 .../Linalg/Transforms/BlockPackMatmul.cpp     |  46 +++++-
 .../Dialect/Linalg/block-pack-matmul.mlir     | 138 ++++++++++++++++++
 2 files changed, 183 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index d3f3f99b92196b..654d4a0129be44 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -235,6 +235,49 @@ struct BlockPackMatmul : public OpRewritePattern<OpTy> {
   ControlBlockPackMatmulFn controlFn;
 };
 
+/// Block pack specialization for linalg.generic. Only generics whose
+/// maps describe a plain or single-side transposed 2D matmul are packed.
+template <>
+struct BlockPackMatmul<linalg::GenericOp>
+    : public OpRewritePattern<linalg::GenericOp> {
+  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::GenericOp>(context, benefit),
+        controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    // Match suitable generics. Use the side-effect free predicate: the
+    // interface verifier emits error diagnostics on every non-contraction.
+    if (!linalg::isaContractionOpInterface(linalgOp))
+      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
+
+    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+    auto infer = [&](MapList m) {
+      return AffineMap::inferFromExprList(m, linalgOp.getContext());
+    };
+
+    AffineExpr i, j, k;
+    bindDims(linalgOp->getContext(), i, j, k);
+    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
+
+    // For now, only match simple matmuls into C[i, j]: A[i, k] * B[k, j],
+    // A[k, i] * B[k, j] (transposed A) and A[i, k] * B[j, k] (transposed B).
+    if (!(maps == infer({{i, k}, {k, j}, {i, j}}) ||
+          maps == infer({{k, i}, {k, j}, {i, j}}) ||
+          maps == infer({{i, k}, {j, k}, {i, j}}))) {
+      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
+    }
+
+    if (failed(blockPackMatmul(rewriter, linalgOp, controlFn)))
+      return failure();
+    return success();
+  }
+
+private:
+  ControlBlockPackMatmulFn controlFn;
+};
+
 /// Convert linalg matmul ops to block layout and back.
 struct LinalgBlockPackMatmul
     : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
@@ -269,7 +312,8 @@ struct LinalgBlockPackMatmul
 
 void linalg::populateBlockPackMatmulPatterns(
     RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
-  patterns.add<BlockPackMatmul<linalg::MatmulOp>,
+  patterns.add<BlockPackMatmul<linalg::GenericOp>,
+               BlockPackMatmul<linalg::MatmulOp>,
                BlockPackMatmul<linalg::BatchMatmulOp>,
                BlockPackMatmul<linalg::MatmulTransposeAOp>,
                BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
diff --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
index b9e5f1c774df80..769bf71d3ef304 100644
--- a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -338,3 +338,141 @@ func.func @block_batch_matmul_transpose_b(
 // CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
 // CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
 // CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
+
+// -----
+
+// linalg.generic spelling of a plain matmul: C[i, j] += A[i, k] * B[k, j].
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @block_generic_matmul(
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%C : tensor<128x128xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_generic_matmul(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
+
+// -----
+
+// linalg.generic spelling of a matmul with transposed A:
+// C[i, j] += A[k, i] * B[k, j].
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @block_generic_matmul_transpose_a(
+    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
+    outs(%C : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_generic_matmul_transpose_a(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
+
+// -----
+
+// linalg.generic spelling of a matmul with transposed B:
+// C[i, j] += A[i, k] * B[j, k].
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @block_generic_matmul_transpose_b(
+    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
+    outs(%C : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_generic_matmul_transpose_b(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>

>From 6598369fc782981f71ffe5eeb38b737548b3f8d8 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 7 May 2024 15:05:23 +0200
Subject: [PATCH 22/22] Typo

---
 mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 654d4a0129be44..e1e254dcb25a03 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -85,7 +85,7 @@ static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
   return true;
 }
 
-/// Return failure or packed matmul with one of its operands tranposed.
+/// Return failure or packed matmul with one of its operands transposed.
 static FailureOr<PackTransposeResult>
 transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
                       tensor::PackOp packOp, AffineMap operandMap,



More information about the Mlir-commits mailing list