[Mlir-commits] [mlir] [mlir][linalg] Pack matmul pass (PR #89782)
Adam Siemieniuk
llvmlistbot at llvm.org
Thu Apr 25 07:56:02 PDT 2024
https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/89782
>From ad5de2b01b364733483475a3238f22da7a2a2707 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 22 Apr 2024 18:21:34 +0200
Subject: [PATCH 1/3] [mlir][linalg] Pack matmul pass
Pack a matmul MxNxK operation into blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
The result is unpacked back to the original layout.
Matmul packing splits the operands into smaller blocks (inner dimensions)
and then block-transposes the block sub-groups (outer dimensions).
This data arrangement minimizes distance between consecutive blocks
which improves spatial locality and cache behavior.
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 20 ++
.../Dialect/Linalg/Transforms/Transforms.h | 4 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Dialect/Linalg/Transforms/PackMatmul.cpp | 177 ++++++++++++++++++
mlir/test/Dialect/Linalg/pack-matmul.mlir | 140 ++++++++++++++
5 files changed, 342 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
create mode 100644 mlir/test/Dialect/Linalg/pack-matmul.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..d4361c70468bdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,24 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}
+def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+ let summary = "Convert linalg matmul ops to block layout and back";
+ let description = [{
+ Pack a matmul MxNxK operation into blocked layout
+ as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
+ The result is unpacked back to the original layout.
+
+ Matmul packing splits the operands into smaller blocks (inner dimensions)
+ and then block-transposes the block sub-groups (outer dimensions).
+
+ This data arrangement minimizes distance between consecutive blocks
+ which improves spacial locality and cache behavior.
+ }];
+ let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+ let options = [
+ ListOption<"blockFactors", "block-factors", "int64_t",
+ "Block factors (mb, nb, kb) for relayout">
+ ];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c7012..2bb9277cc7b27e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1628,6 +1628,10 @@ void populateSplitReductionPattern(
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
bool transposeLHS = true);
+/// Patterns to pack Linalg matmul ops.
+void populatePackMatmulPatterns(RewritePatternSet &patterns,
+ ArrayRef<int64_t> blockingFactors);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..e9b104ea5aeb58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ PackMatmul.cpp
Padding.cpp
Promotion.cpp
Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
new file mode 100644
index 00000000000000..304de03a343fdc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns the constant value of `range.size - range.offset` for a
+/// unit-stride range. Yields std::nullopt when the stride is not the constant
+/// 1, or when the offset or the size is not a constant integer.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  // Only constant unit-stride ranges are accepted.
+  std::optional<int64_t> constStride = getConstantIntValue(range.stride);
+  if (constStride.value_or(0) != 1)
+    return std::nullopt;
+
+  // Both ends of the range have to be known statically.
+  std::optional<int64_t> constOffset = getConstantIntValue(range.offset);
+  std::optional<int64_t> constSize = getConstantIntValue(range.size);
+  if (!constOffset || !constSize)
+    return std::nullopt;
+
+  return *constSize - *constOffset;
+}
+
+/// Returns true if every iteration-domain dimension listed in `dims` can be
+/// split into whole tiles of the matching size in `tiles`.
+///
+/// `tiles[i]` is the tile size checked against dimension `dims[i]` of the
+/// iteration domain of `tileOp`. Validation fails when the two lists differ in
+/// length or are empty, when a dimension index is out of bounds, when a tile
+/// size or a range is not constant, when a tile size is not strictly positive,
+/// or when the range is not evenly divisible by the tile size.
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<size_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  // The builder is only needed to query the iteration domain.
+  OpBuilder builder(tileOp);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+
+  for (auto dim : llvm::enumerate(dims)) {
+    if (dim.value() >= iterationDomain.size())
+      return false;
+
+    std::optional<int64_t> tileSize = getConstantIntValue(tiles[dim.index()]);
+    std::optional<int64_t> rangeOnDim =
+        getConstantRange(iterationDomain[dim.value()]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // A zero tile would make the modulo below undefined behavior, and a
+    // negative tile is not a meaningful block size, so reject both up front.
+    if (*tileSize <= 0)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+/// Packs a linalg.matmul or linalg.batch_matmul into a blocked layout and
+/// returns the resulting packed linalg op.
+///
+/// `mnkTiles` holds the three block factors forwarded to the greedy matmul
+/// packing utility. The op is rejected unless it has static shapes, does not
+/// have pure buffer semantics, and passes the full-tile validation on its
+/// three matmul dimensions. After the greedy packing, the outer and inner
+/// dimensions of the second pack op (`packOps[1]`) are additionally permuted;
+/// for a batch matmul the leading batch dimension is left in place.
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             ArrayRef<OpFoldResult> mnkTiles) {
+  const bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
+  if (!isBatchMatmulOp && !isa<linalg::MatmulOp>(matmulOp))
+    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+
+  if (mnkTiles.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+  if (matmulOp.hasDynamicShape())
+    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+
+  if (matmulOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+
+  // Iteration-domain positions to validate; the batch dimension, if present,
+  // is skipped.
+  SmallVector<size_t, 3> mnkDims;
+  if (isBatchMatmulOp)
+    mnkDims = {1, 2, 3};
+  else
+    mnkDims = {0, 1, 2};
+
+  auto tilingOp = cast<TilingInterface>(matmulOp.getOperation());
+  if (!validateFullTilesOnDims(tilingOp, mnkTiles, mnkDims))
+    return rewriter.notifyMatchFailure(matmulOp,
+                                       "expect packing full tiles only");
+
+  // The op is replaced, so new IR has to be created right after it. The guard
+  // restores the rewriter's insertion point when this function returns.
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointAfter(matmulOp);
+
+  auto packedCanonicalMatmul = packMatmulGreedily(
+      rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
+      /*mnkOrder=*/{0, 1, 2});
+  if (failed(packedCanonicalMatmul))
+    return failure();
+
+  assert(packedCanonicalMatmul->packOps.size() == 3 && "failed matmul packing");
+  assert(packedCanonicalMatmul->unPackOps.size() == 1 &&
+         "failed matmul unpacking");
+
+  // Leave the batch dimension as is.
+  SmallVector<int64_t> outerPerm = isBatchMatmulOp
+                                       ? SmallVector<int64_t>{0, 2, 1}
+                                       : SmallVector<int64_t>{1, 0};
+  SmallVector<int64_t> innerPerm = {1, 0};
+
+  auto packedMatmul =
+      packTranspose(rewriter, packedCanonicalMatmul->packOps[1],
+                    packedCanonicalMatmul->packedLinalgOp,
+                    /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+  if (failed(packedMatmul))
+    return failure();
+
+  return packedMatmul->transposedLinalgOp;
+}
+
+namespace {
+/// Rewrite pattern that packs ops of type `OpTy` into a blocked layout using
+/// the list of block factors captured at construction time.
+template <typename OpTy>
+struct PackMatmul : public OpRewritePattern<OpTy> {
+  PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+             PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), blockFactors(blockFactors) {}
+
+  /// Succeeds when `matmulOp` was packed; fails when no block factors were
+  /// provided or when the packing helper reports a failure.
+  LogicalResult matchAndRewrite(OpTy matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (blockFactors.empty())
+      return failure();
+
+    // Turn the integer block factors into attribute-backed OpFoldResults.
+    auto tiles = getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors));
+    return failure(failed(packMatmulOp(rewriter, matmulOp, tiles)));
+  }
+
+private:
+  /// Block factors handed to the packing helper on every match.
+  SmallVector<int64_t> blockFactors;
+};
+
+// Pass driver for packing matmul operations.
+// A MatmulOp is packed as:
+// [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+// A BatchMatmulOp is packed as:
+// [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
+struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
+  using LinalgPackMatmulBase::LinalgPackMatmulBase;
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    linalg::populatePackMatmulPatterns(patterns, blockFactors);
+
+    // Run the packing patterns with the greedy driver; a driver failure is
+    // reported as a pass failure.
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Adds the packing patterns for linalg.matmul and linalg.batch_matmul to
+/// `patterns`, each configured with `blockFactors`.
+void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                        ArrayRef<int64_t> blockFactors) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<PackMatmul<linalg::MatmulOp>>(ctx, blockFactors);
+  patterns.add<PackMatmul<linalg::BatchMatmulOp>>(ctx, blockFactors);
+}
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
new file mode 100644
index 00000000000000..d7023cfc30559b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+func.func @block_matmul(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME: into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME: into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_constant(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant dense<0.0> : tensor<128x128xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_producer(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_matmul_with_consumer(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[OUT:.+]] = linalg.add
+// CHECK-SAME: ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+func.func @block_batch_matmul(
+ %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+ %0 = tensor.empty() : tensor<512x64x64xf32>
+ %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+ outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+ return %1 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME: into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME: into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[OUT]] : tensor<512x64x64xf32>
>From b2f7ab4158dbe5b2364de3ea9dc914ddaaa6055f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 13:06:58 +0200
Subject: [PATCH 2/3] Improve test naming
---
mlir/test/Dialect/Linalg/pack-matmul.mlir | 144 +++++++++++-----------
1 file changed, 72 insertions(+), 72 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
index d7023cfc30559b..c704d0fb7d4fa5 100644
--- a/mlir/test/Dialect/Linalg/pack-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -1,9 +1,9 @@
// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
func.func @block_matmul(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}
@@ -12,103 +12,103 @@ func.func @block_matmul(
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
// CHECK-LABEL: func @block_matmul(
-// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
-// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
-// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
-// CHECK-SAME: into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
-// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
-// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
-// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
// -----
func.func @block_matmul_with_constant(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
- %cst = arith.constant dense<0.0> : tensor<128x128xf32>
- %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%cst_acc : tensor<128x128xf32>) -> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}
// CHECK-LABEL: func @block_matmul_with_constant(
-// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
-// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME: into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
// -----
func.func @block_matmul_with_producer(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
%cst = arith.constant 0.0 : f32
- %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
- %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %acc = linalg.fill ins(%cst : f32) outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%acc : tensor<128x128xf32>) -> tensor<128x128xf32>
return %1 : tensor<128x128xf32>
}
// CHECK-LABEL: func @block_matmul_with_producer(
-// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
-// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
-// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
// -----
func.func @block_matmul_with_consumer(
- %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>, %D: tensor<128x128xf32>) -> tensor<128x128xf32> {
%0 = tensor.empty() : tensor<128x128xf32>
- %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
- outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
- %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+ %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %2 = linalg.add ins(%1, %D : tensor<128x128xf32>, tensor<128x128xf32>)
outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
return %2 : tensor<128x128xf32>
}
// CHECK-LABEL: func @block_matmul_with_consumer(
-// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
-// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>, %[[D:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: outs({{.*}} : tensor<4x8x32x16xf32>)
-// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
-// CHECK: %[[OUT:.+]] = linalg.add
-// CHECK-SAME: ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
-// CHECK: return %[[OUT]] : tensor<128x128xf32>
+// CHECK-SAME: into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[ADD_RES:.+]] = linalg.add
+// CHECK-SAME: ins(%[[RES_UNPACKED]], %[[D]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[RES_DST]] : tensor<128x128xf32>)
+// CHECK: return %[[ADD_RES]] : tensor<128x128xf32>
// -----
func.func @block_batch_matmul(
- %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+ %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
%0 = tensor.empty() : tensor<512x64x64xf32>
- %1 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
- outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+ %1 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+ outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
return %1 : tensor<512x64x64xf32>
}
@@ -117,24 +117,24 @@ func.func @block_batch_matmul(
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
// CHECK-LABEL: func @block_batch_matmul(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
-// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
-// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 64]
-// CHECK-SAME: into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
-// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
-// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
-// CHECK-SAME: into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
-// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
-// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
-// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
-// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
-// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
-// CHECK-SAME: into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
-// CHECK: return %[[OUT]] : tensor<512x64x64xf32>
+// CHECK-SAME: into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
>From e2ad091d54a879fd042d9b67bc453fc8e533104a Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Thu, 25 Apr 2024 13:58:56 +0200
Subject: [PATCH 3/3] Improve description
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 39 +++++++++++++++++-----
1 file changed, 30 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index d4361c70468bdb..907ebb2f9c4be1 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -144,15 +144,36 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
let summary = "Convert linalg matmul ops to block layout and back";
let description = [{
- Pack a matmul MxNxK operation into blocked layout
- as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
- The result is unpacked back to the original layout.
-
- Matmul packing splits the operands into smaller blocks (inner dimensions)
- and then block-transposes the block sub-groups (outer dimensions).
-
- This data arrangement minimizes distance between consecutive blocks
- which improves spacial locality and cache behavior.
+ Pack a matmul operation into blocked layout with two levels of subdivision:
+ - major 2D blocks - outer dimensions, consist of minor blocks
+ - minor 2D blocks - inner dimensions, consist of scalar elements
+
+ A 2D matmul MxNxK gets reshaped into blocked 4D representation
+ as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+ where the (MB, NB, KB) dimensions represent the major blocks,
+ and the (mb, nb, kb) are the minor blocks of their respective
+ original 2D dimensions (M, N, K).
+
+ As a part of packing strategy, the RHS operand gets 'block transposed'
+ i.e., the major blocks [KB][NB] get transposed to [NB][KB] layout.
+ The minor blocks remain unchanged.
+ The final result is unpacked back to the original layout.
+
+ Given a matmul operation:
+ ```mlir
+ %res = linalg.matmul ins(%A, %B) outs(%C)
+ ```
+ the transformation can be represented as:
+ ```mlir
+ %A_packed = pack %A : 2D -> 4D
+ %B_packed = pack %B : 2D -> 4D #block_transposed
+ %C_packed = pack %C : 2D -> 4D
+ %res_packed = linalg.mmt4d ins(%A_packed, %B_packed) outs(%C_packed)
+ %res = unpack %res_packed : 4D -> 2D
+ ```
+
+ This packed data arrangement minimizes distance between consecutive
+ blocks which improves spatial locality and cache behavior.
}];
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
let options = [
More information about the Mlir-commits
mailing list