[Mlir-commits] [mlir] [mlir][linalg] Pack matmul pass (PR #89782)
Adam Siemieniuk
llvmlistbot at llvm.org
Tue Apr 23 08:42:09 PDT 2024
https://github.com/adam-smnk created https://github.com/llvm/llvm-project/pull/89782
Pack a matmul MxNxK operation into blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]. The result is unpacked back to the original layout.
Matmul packing splits the operands into smaller blocks (inner dimensions) and then block-transposes the block sub-groups (outer dimensions).
This data arrangement minimizes the distance between consecutive blocks, which improves spatial locality and cache behavior.
From ad5de2b01b364733483475a3238f22da7a2a2707 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Mon, 22 Apr 2024 18:21:34 +0200
Subject: [PATCH] [mlir][linalg] Pack matmul pass
Pack a matmul MxNxK operation into blocked layout
as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
The result is unpacked back to the original layout.
Matmul packing splits the operands into smaller blocks (inner dimensions)
and then block-transposes the block sub-groups (outer dimensions).
This data arrangement minimizes the distance between consecutive blocks
which improves spatial locality and cache behavior.
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 20 ++
.../Dialect/Linalg/Transforms/Transforms.h | 4 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Dialect/Linalg/Transforms/PackMatmul.cpp | 177 ++++++++++++++++++
mlir/test/Dialect/Linalg/pack-matmul.mlir | 140 ++++++++++++++
5 files changed, 342 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
create mode 100644 mlir/test/Dialect/Linalg/pack-matmul.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..d4361c70468bdb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,24 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}
+def LinalgPackMatmul : Pass<"linalg-pack-matmul"> {
+  let summary = "Convert linalg matmul ops to block layout and back";
+  let description = [{
+    Pack a matmul MxNxK operation into blocked layout
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb].
+    The result is unpacked back to the original layout.
+
+    Matmul packing splits the operands into smaller blocks (inner dimensions)
+    and then block-transposes the block sub-groups (outer dimensions).
+
+    This data arrangement minimizes distance between consecutive blocks
+    which improves spatial locality and cache behavior.
+  }];
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+  let options = [
+    ListOption<"blockFactors", "block-factors", "int64_t",
+               "Block factors (mb, nb, kb) for relayout">
+  ];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c7012..2bb9277cc7b27e 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1628,6 +1628,10 @@ void populateSplitReductionPattern(
void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
bool transposeLHS = true);
+/// Adds patterns to `patterns` that pack linalg.matmul and
+/// linalg.batch_matmul ops into a blocked layout. `blockFactors` holds the
+/// (mb, nb, kb) block sizes used for the relayout.
+void populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                ArrayRef<int64_t> blockFactors);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index ee6e391d0cc682..e9b104ea5aeb58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
+ PackMatmul.cpp
Padding.cpp
Promotion.cpp
Specialize.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
new file mode 100644
index 00000000000000..304de03a343fdc
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackMatmul.cpp
@@ -0,0 +1,177 @@
+//===- PackMatmul.cpp - Linalg matmul packing -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns `size - offset` of `range` when its stride is the constant 1 and
+/// both its offset and size are compile-time constants. Returns std::nullopt
+/// in every other case.
+// NOTE(review): the result subtracts the offset from `Range::size`; confirm
+// that this matches the intended meaning of `size` for non-zero offsets.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  // Only unit-stride ranges are supported.
+  const std::optional<int64_t> step = getConstantIntValue(range.stride);
+  const bool isUnitStride = step.has_value() && step.value() == 1;
+  if (!isUnitStride)
+    return std::nullopt;
+
+  // Both bounds must be statically known.
+  const std::optional<int64_t> begin = getConstantIntValue(range.offset);
+  if (!begin.has_value())
+    return std::nullopt;
+  const std::optional<int64_t> extent = getConstantIntValue(range.size);
+  if (!extent.has_value())
+    return std::nullopt;
+
+  return extent.value() - begin.value();
+}
+
+/// Returns true if every iteration-domain dimension listed in `dims` can be
+/// split into full tiles, where `tiles[i]` is the tile size for `dims[i]`.
+///
+/// Returns false when:
+///   - `dims` and `tiles` differ in length, or are empty;
+///   - a dimension index is out of bounds for the iteration domain;
+///   - a tile size or a dimension range is not a compile-time constant;
+///   - a tile size is not strictly positive;
+///   - a dimension range is not evenly divisible by its tile size.
+static bool validateFullTilesOnDims(TilingInterface tileOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<size_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  // The builder is only needed to query the iteration domain.
+  OpBuilder builder(tileOp);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+
+  for (auto dim : llvm::enumerate(dims)) {
+    if (dim.value() >= iterationDomain.size())
+      return false;
+
+    auto tileSize = getConstantIntValue(tiles[dim.index()]);
+    auto rangeOnDim = getConstantRange(iterationDomain[dim.value()]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // A zero tile would make the divisibility check below divide by zero,
+    // and a negative tile is meaningless as a block size.
+    if (*tileSize <= 0)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+/// Packs `matmulOp` into a blocked layout using the three block sizes in
+/// `mnkTiles` and returns the linalg op produced by the final transposition
+/// step.
+///
+/// A match failure is reported through `rewriter` unless `matmulOp` is a
+/// linalg.matmul or linalg.batch_matmul, exactly three tile factors are
+/// given, the op has a static shape, the op does not have pure buffer
+/// semantics, and the tiled iteration dimensions split into full tiles.
+static FailureOr<linalg::LinalgOp>
+packMatmulOp(RewriterBase &rewriter, linalg::LinalgOp matmulOp,
+             ArrayRef<OpFoldResult> mnkTiles) {
+  const bool isBatchMatmulOp = isa<linalg::BatchMatmulOp>(matmulOp);
+  if (!isBatchMatmulOp && !isa<linalg::MatmulOp>(matmulOp))
+    return rewriter.notifyMatchFailure(matmulOp, "not a matmul-like operation");
+
+  if (mnkTiles.size() != 3)
+    return rewriter.notifyMatchFailure(matmulOp, "require 3 tile factors");
+
+  if (matmulOp.hasDynamicShape())
+    return rewriter.notifyMatchFailure(matmulOp, "require static shape");
+
+  if (matmulOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics");
+
+  // Iteration dimensions that get tiled; a leading batch dimension, when
+  // present, is skipped.
+  const size_t firstTiledDim = isBatchMatmulOp ? 1 : 0;
+  SmallVector<size_t, 3> tiledDims{firstTiledDim, firstTiledDim + 1,
+                                   firstTiledDim + 2};
+
+  auto tilingOp = cast<TilingInterface>(matmulOp.getOperation());
+  if (!validateFullTilesOnDims(tilingOp, mnkTiles, tiledDims))
+    return rewriter.notifyMatchFailure(matmulOp,
+                                       "expect packing full tiles only");
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // The matmul gets replaced, so new ops are created right after it.
+  rewriter.setInsertionPointAfter(matmulOp);
+
+  auto canonicalPacking = packMatmulGreedily(
+      rewriter, matmulOp, mnkTiles, /*mnkPaddedSizesNextMultipleOf=*/{},
+      /*mnkOrder=*/{0, 1, 2});
+  if (failed(canonicalPacking))
+    return failure();
+
+  assert(canonicalPacking->packOps.size() == 3 && "failed matmul packing");
+  assert(canonicalPacking->unPackOps.size() == 1 &&
+         "failed matmul unpacking");
+
+  // Swap the two inner tile dimensions and the two outer block dimensions of
+  // the second pack op; a batch dimension stays in front.
+  SmallVector<int64_t> innerPerm = {1, 0};
+  SmallVector<int64_t> outerPerm = isBatchMatmulOp
+                                       ? SmallVector<int64_t>{0, 2, 1}
+                                       : SmallVector<int64_t>{1, 0};
+
+  auto transposed = packTranspose(rewriter, canonicalPacking->packOps[1],
+                                  canonicalPacking->packedLinalgOp,
+                                  /*maybeUnPackOp=*/nullptr, outerPerm,
+                                  innerPerm);
+  if (failed(transposed))
+    return failure();
+
+  return transposed->transposedLinalgOp;
+}
+
+namespace {
+/// Rewrite pattern that packs ops of type `OpTy` through `packMatmulOp`,
+/// using the block factors given at construction time.
+template <typename OpTy>
+struct PackMatmul : public OpRewritePattern<OpTy> {
+  PackMatmul(MLIRContext *context, ArrayRef<int64_t> blockFactors,
+             PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit),
+        blockFactors(blockFactors.begin(), blockFactors.end()) {}
+
+  /// Fails when no block factors were provided or when `packMatmulOp`
+  /// fails; succeeds otherwise.
+  LogicalResult matchAndRewrite(OpTy matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (blockFactors.empty())
+      return failure();
+
+    // Turn the integer factors into OpFoldResults for the packing helper.
+    SmallVector<OpFoldResult> tileSizes =
+        getAsOpFoldResult(rewriter.getI64ArrayAttr(blockFactors));
+    auto packingResult = packMatmulOp(rewriter, matmulOp, tileSizes);
+    return failure(failed(packingResult));
+  }
+
+private:
+  /// Block factors handed to `packMatmulOp` on every rewrite.
+  SmallVector<int64_t> blockFactors;
+};
+
+// Pass that packs matmul operations into a blocked layout.
+// A MatmulOp is packed as:
+//   [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][kb][nb]
+// A BatchMatmulOp is packed as:
+//   [B][MB][NB][mb][nb] += [B][MB][KB][mb][kb] * [B][NB][KB][kb][nb]
+struct LinalgPackMatmul : public impl::LinalgPackMatmulBase<LinalgPackMatmul> {
+  using LinalgPackMatmulBase::LinalgPackMatmulBase;
+
+  // Collects the packing patterns for the `blockFactors` pass option and
+  // applies them greedily; the pass is marked failed if the driver fails.
+  void runOnOperation() override {
+    RewritePatternSet packingPatterns(&getContext());
+    linalg::populatePackMatmulPatterns(packingPatterns, blockFactors);
+
+    LogicalResult status =
+        applyPatternsAndFoldGreedily(getOperation(), std::move(packingPatterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Registers the packing pattern for linalg.matmul and linalg.batch_matmul,
+/// both configured with the same `blockFactors`.
+void linalg::populatePackMatmulPatterns(RewritePatternSet &patterns,
+                                        ArrayRef<int64_t> blockFactors) {
+  MLIRContext *ctx = patterns.getContext();
+  patterns.add<PackMatmul<linalg::MatmulOp>>(ctx, blockFactors);
+  patterns.add<PackMatmul<linalg::BatchMatmulOp>>(ctx, blockFactors);
+}
diff --git a/mlir/test/Dialect/Linalg/pack-matmul.mlir b/mlir/test/Dialect/Linalg/pack-matmul.mlir
new file mode 100644
index 00000000000000..d7023cfc30559b
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/pack-matmul.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -linalg-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+// Basic case: a static 128x128 linalg.matmul whose inputs and accumulator are
+// all function arguments.
+func.func @block_matmul(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME: into %[[BUF0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[BUF1:.*]] = tensor.empty() : tensor<8x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [64, 16]
+// CHECK-SAME: into %[[BUF1]] : tensor<128x128xf32> -> tensor<8x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]],
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[PACK2]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+// The matmul accumulator is a dense constant tensor instead of a function
+// argument.
+func.func @block_matmul_with_constant(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant dense<0.0> : tensor<128x128xf32>
+ %0 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%cst : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF_RES:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[BUF_OUT:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF_OUT]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+// The matmul accumulator is produced by a linalg.fill of %arg2 with zero.
+func.func @block_matmul_with_producer(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %cst = arith.constant 0.0 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[BUF_RES:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[BUF_RES]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x64x16xf32>) outs(%[[FILL]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+// The matmul result is consumed by a linalg.add with %arg3 before being
+// returned.
+func.func @block_matmul_with_consumer(
+ %arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>, %arg3: tensor<128x128xf32>) -> tensor<128x128xf32> {
+ %0 = tensor.empty() : tensor<128x128xf32>
+ %1 = linalg.matmul ins(%arg0, %arg1 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%arg2 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ %2 = linalg.add ins(%1, %arg3 : tensor<128x128xf32>, tensor<128x128xf32>)
+ outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+ return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG1:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG2:[0-9a-z]+]]: tensor<128x128xf32>, %[[ARG3:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[BUF:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[OUT:.+]] = linalg.add
+// CHECK-SAME: ins(%[[UNPACK]], %[[ARG3]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[BUF]] : tensor<128x128xf32>)
+// CHECK: return %[[OUT]] : tensor<128x128xf32>
+
+// -----
+
+// Batched case: a static linalg.batch_matmul with a leading batch dimension
+// of 512; all operands are function arguments.
+func.func @block_batch_matmul(
+    %arg0: tensor<512x64x128xf32>, %arg1: tensor<512x128x64xf32>, %arg2: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%arg2 : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d6, d5)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<512x64x128xf32>, %[[ARG1:.+]]: tensor<512x128x64xf32>, %[[ARG2:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[BUF0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME: into %[[BUF0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[BUF1:.+]] = tensor.empty() : tensor<512x4x2x64x16xf32>
+// CHECK: %[[PACK1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [64, 16]
+// CHECK-SAME: into %[[BUF1]] : tensor<512x128x64xf32> -> tensor<512x4x2x64x16xf32>
+// CHECK: %[[BUF2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[PACK2:.+]] = tensor.pack %[[ARG2]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[BUF2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME: ins(%[[PACK0]], %[[PACK1]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x64x16xf32>) outs(%[[PACK2]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[OUT:.+]] = tensor.unpack %[[VAL]]
+// CHECK-SAME: inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG2]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[OUT]] : tensor<512x64x64xf32>
More information about the Mlir-commits
mailing list