[Mlir-commits] [mlir] 4c3db25 - [mlir][linalg] Block pack matmul pass (#89782)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 9 10:02:48 PDT 2024


Author: Adam Siemieniuk
Date: 2024-05-09T18:02:44+01:00
New Revision: 4c3db2588e8b38f75744def6e2dd17c556950e46

URL: https://github.com/llvm/llvm-project/commit/4c3db2588e8b38f75744def6e2dd17c556950e46
DIFF: https://github.com/llvm/llvm-project/commit/4c3db2588e8b38f75744def6e2dd17c556950e46.diff

LOG: [mlir][linalg] Block pack matmul pass (#89782)

Pack a matmul MxNxK operation into 4D blocked layout. Any present batch
dimensions remain unchanged and the result is unpacked back to the
original layout.

Matmul block packing splits the operands into major blocks (outer
dimensions) and minor blocks (inner dimensions). The desired block
layout can be controlled through packing options.

Added: 
    mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
    mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
    mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir
    mlir/test/Dialect/Linalg/block-pack-matmul.mlir

Modified: 
    mlir/include/mlir/Dialect/Linalg/Passes.td
    mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
    mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a7..0a4ce8953136d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,63 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgBlockPackMatmul : Pass<"linalg-block-pack-matmul"> {
+  let summary = "Convert linalg matmul ops to block layout and back";
+  let description = [{
+    Pack a matmul operation into blocked layout with two levels of subdivision:
+    - major 2D blocks - outer dimensions, consist of minor blocks
+    - minor 2D blocks - inner dimensions, consist of scalar elements
+
+    A 2D matmul MxNxK gets reshaped into blocked 4D representation
+    as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
+    where the (MB, NB, KB) dimensions represent the major blocks,
+    and the (mb, nb, kb) are the minor blocks of their respective
+    original 2D dimensions (M, N, K).
+
+    Depending on the initial operands' data layout and the specified
+    packing options, the major block dimensions might get transposed
+    e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
+    e.g., [mb][kb] -> [kb][mb].
+    Any present batch dimensions remain unchanged.
+    The final result is unpacked back to the original shape.
+
+    For example, given a matmul operation:
+    ```mlir
+      %res = linalg.matmul ins(%A, %B) outs(%C)
+    ```
+    the default transformation result can be represented as:
+    ```mlir
+      %A_packed = pack %A : 2D <MxK> -> 4D <MBxKBxmbxkb>
+      %B_packed = pack %B : 2D <KxN> -> 4D <NBxKBxnbxkb>
+      %C_packed = pack %C : 2D <MxN> -> 4D <MBxNBxmbxnb>
+      %res_packed = linalg.mmt4d ins(%A_packed, %B_packed) outs(%C_packed)
+      %res = unpack %res_packed : 4D <MBxNBxmbxnb> -> 2D <MxN>
+    ```
+  }];
+  let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
+  let options = [
+    ListOption<"blockFactors", "block-factors", "int64_t",
+               "Block factors (mb, nb, kb) for relayout">,
+    Option<"allowPadding", "allow-padding", "bool",
+           /*default=*/"true",
+           "Allow padding dims that only partially fit the block factors">,
+    ListOption<"mnkPaddedSizesNextMultipleOf", "mnk-padded-multiples", "int64_t",
+               "Round (M, N, K) packing sizes up to these next multiples">,
+    ListOption<"mnkOrder", "mnk-order", "int64_t",
+               "Permutation of matmul (M, N, K) dimensions order">,
+    Option<"lhsTransposeOuterBlocks", "lhs-transpose-outer-blocks", "bool",
+           /*default=*/"false",
+           "Transpose LHS outer block layout [MB][KB] -> [KB][MB]">,
+    Option<"lhsTransposeInnerBlocks", "lhs-transpose-inner-blocks", "bool",
+           /*default=*/"false",
+           "Transpose LHS inner block layout [mb][kb] -> [kb][mb]">,
+    Option<"rhsTransposeOuterBlocks", "rhs-transpose-outer-blocks", "bool",
+           /*default=*/"true",
+           "Transpose RHS outer block layout [KB][NB] -> [NB][KB]">,
+    Option<"rhsTransposeInnerBlocks", "rhs-transpose-inner-blocks", "bool",
+           /*default=*/"true",
+           "Transpose RHS inner block layout [kb][nb] -> [nb][kb]">
+  ];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES

diff  --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 5ecf84fa9c701..f77c19ed0fcce 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1162,6 +1162,66 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                    ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                    ArrayRef<int64_t> mnkOrder);
 
+struct BlockPackMatmulOptions {
+  /// Minor block factors (mb, nb, kb) for packing relayout where mb, nb are
+  /// the parallel dimensions and kb is the reduction dimension.
+  SmallVector<int64_t, 3> blockFactors;
+
+  /// If true, allows packing of dimensions that only partially fit into the
+  /// block factors.
+  bool allowPadding = true;
+
+  /// If non-empty, round the (M, N, K) packing sizes up to these multiples.
+  SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf;
+
+  /// Permutation of matmul (M, N, K) dimensions order.
+  SmallVector<int64_t, 3> mnkOrder = {0, 1, 2};
+
+  /// Transpose LHS outer block layout [MB][KB] -> [KB][MB].
+  bool lhsTransposeOuterBlocks = false;
+
+  /// Transpose LHS inner block layout [mb][kb] -> [kb][mb].
+  bool lhsTransposeInnerBlocks = false;
+
+  /// Transpose RHS outer block layout [KB][NB] -> [NB][KB].
+  bool rhsTransposeOuterBlocks = true;
+
+  /// Transpose RHS inner block layout [kb][nb] -> [nb][kb].
+  bool rhsTransposeInnerBlocks = true;
+};
+
+/// Function type which is used to control matmul packing.
+/// It is expected to return valid packing configuration for each operation.
+/// The candidate op is passed in, so the configuration may differ per op.
+/// Lack of packing options (std::nullopt) indicates that no valid
+/// configuration could be assigned and the operation will not be packed.
+using ControlBlockPackMatmulFn =
+    std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>;
+
+/// Pack a matmul operation into blocked 4D layout.
+///
+/// Relayout a matmul operation into blocked layout with two levels of
+/// subdivision:
+///   - major 2D blocks - outer dimensions, consist of minor blocks
+///   - minor 2D blocks - inner dimensions, consist of scalar elements
+///
+/// A 2D matmul MxNxK gets reshaped into blocked 4D representation
+/// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb]
+/// where the (MB, NB, KB) dimensions represent the major blocks,
+/// and the (mb, nb, kb) are the minor blocks of their respective
+/// original 2D dimensions (M, N, K).
+///
+/// Depending on the initial operands' data layout and the specified
+/// packing options, the major block dimensions might get transposed
+/// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed
+/// e.g., [mb][kb] -> [kb][mb].
+/// Any present batch dimensions remain unchanged.
+/// The final result is unpacked back to the original shape.
+///
+/// `controlPackMatmul` is queried for `linalgOp` to obtain the packing options.
+///
+/// Return failure if the op has pure buffer semantics, if no valid packing
+/// options are provided, if padding is disallowed but the dimensions are not
+/// divisible by the block factors, or if the packing itself fails.
+FailureOr<PackResult>
+blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+                const ControlBlockPackMatmulFn &controlPackMatmul);
+
 /// Rewrite tensor.from_elements to linalg.generic.
 FailureOr<Operation *>
 rewriteInDestinationPassingStyle(RewriterBase &rewriter,
@@ -1628,6 +1688,10 @@ void populateSplitReductionPattern(
 void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                      bool transposeLHS = true);
 
+/// Patterns to block pack Linalg matmul ops.
+///
+/// Covers linalg.matmul, linalg.batch_matmul, their transpose_a/transpose_b
+/// variants, and linalg.generic contractions with plain 2D matmul indexing
+/// maps. `controlFn` supplies the packing options for each matched op.
+void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
+                                     const ControlBlockPackMatmulFn &controlFn);
+
 } // namespace linalg
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
new file mode 100644
index 0000000000000..c07d1387ec753
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -0,0 +1,321 @@
+//===- BlockPackMatmul.cpp - Linalg matmul block packing ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+#include <optional>
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGBLOCKPACKMATMUL
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Compute the constant span (size - offset) of a unit-stride range.
+///
+/// Yields nullopt when the stride is not the constant 1 or when either the
+/// offset or the size is not a compile-time constant.
+static std::optional<int64_t> getConstantRange(const Range &range) {
+  std::optional<int64_t> stride = getConstantIntValue(range.stride);
+  std::optional<int64_t> offset = getConstantIntValue(range.offset);
+  std::optional<int64_t> size = getConstantIntValue(range.size);
+  if (stride != 1 || !offset || !size)
+    return std::nullopt;
+  return *size - *offset;
+}
+
+/// Return true if all dimensions are fully divisible by the respective tiles.
+///
+/// `tiles[i]` is the tile size for the non-batch iteration dimension
+/// `dims[i]`; batch dimensions inferred from the contraction are skipped by
+/// offsetting every dimension index. Non-constant tiles or ranges,
+/// non-positive tiles and out-of-range dimensions all fail the check.
+static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp,
+                                    ArrayRef<OpFoldResult> tiles,
+                                    ArrayRef<int64_t> dims) {
+  if (dims.size() != tiles.size() || tiles.empty())
+    return false;
+
+  FailureOr<ContractionDimensions> contractDims =
+      inferContractionDims(linalgOp);
+  if (failed(contractDims))
+    return false;
+  int64_t batchDimsOffset = contractDims->batch.size();
+
+  auto tileOp = cast<TilingInterface>(linalgOp.getOperation());
+  OpBuilder builder(tileOp);
+  SmallVector<Range> iterationDomain = tileOp.getIterationDomain(builder);
+  int64_t numLoops = iterationDomain.size();
+
+  for (auto [tile, dim] : llvm::zip_equal(tiles, dims)) {
+    // Skip the batch dimensions, if present, by offsetting the dimension.
+    // Negative indices (e.g. from a malformed mnk order) are rejected too,
+    // as they would index the iteration domain out of bounds.
+    int64_t loopDim = dim + batchDimsOffset;
+    if (loopDim < 0 || loopDim >= numLoops)
+      return false;
+
+    std::optional<int64_t> tileSize = getConstantIntValue(tile);
+    std::optional<int64_t> rangeOnDim =
+        getConstantRange(iterationDomain[loopDim]);
+
+    // If the tile factor or the range are non-constant, the tile size is
+    // considered to be invalid.
+    if (!tileSize || !rangeOnDim)
+      return false;
+
+    // A non-positive tile cannot form full tiles; rejecting it also keeps the
+    // modulo below from dividing by zero.
+    if (*tileSize <= 0)
+      return false;
+
+    // The dimension must be fully divisible by the tile.
+    if (*rangeOnDim % *tileSize != 0)
+      return false;
+  }
+
+  return true;
+}
+
+/// Return failure or packed matmul with one of its operands transposed.
+///
+/// `operandMap` is the indexing map of the operand packed by `packOp`; its
+/// last four results are taken as the outer block pair followed by the inner
+/// block pair. The last two entries of `blocksStartDimPos` are the iteration
+/// dimensions expected first in the outer and inner pair, respectively, when
+/// that pair is not transposed.
+static FailureOr<PackTransposeResult>
+transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+                      tensor::PackOp packOp, AffineMap operandMap,
+                      ArrayRef<unsigned> blocksStartDimPos,
+                      bool transposeOuterBlocks, bool transposeInnerBlocks) {
+  // The block positions below are derived from the number of map results, so
+  // check that count (not the number of dims) to rule out unsigned underflow.
+  assert(operandMap.getNumResults() >= 4 &&
+         "expected at least 4D prepacked matmul");
+  assert(blocksStartDimPos.size() >= 2 &&
+         "expected starting outer and inner block positions");
+
+  // Bias toward innermost dimensions.
+  unsigned outerBlockPos = operandMap.getNumResults() - 4;
+  unsigned innerBlockPos = operandMap.getNumResults() - 2;
+
+  // Transpose control options define the desired block and element layout.
+  // Block transposition (outer dimensions) or element transposition (inner
+  // dimensions) may not be necessary depending on the original matmul data
+  // layout.
+  bool isOuterTransposed =
+      operandMap.getDimPosition(outerBlockPos) != blocksStartDimPos.end()[-2];
+  bool isInnerTransposed =
+      operandMap.getDimPosition(innerBlockPos) != blocksStartDimPos.back();
+
+  // Transpose only the dimensions that need it to conform to the provided
+  // transposition settings.
+  SmallVector<int64_t> innerPerm{0, 1};
+  if (isInnerTransposed != transposeInnerBlocks)
+    innerPerm = {1, 0};
+
+  // Leave the leading dimensions, like batch, unchanged and only permute the
+  // outer block pair that follows them.
+  SmallVector<int64_t> outerPerm;
+  for (auto i : llvm::seq(0u, outerBlockPos))
+    outerPerm.push_back(i);
+  bool swapOuter = isOuterTransposed != transposeOuterBlocks;
+  outerPerm.push_back(outerBlockPos + (swapOuter ? 1 : 0));
+  outerPerm.push_back(outerBlockPos + (swapOuter ? 0 : 1));
+
+  return packTranspose(rewriter, packOp, linalgOp,
+                       /*maybeUnPackOp=*/nullptr, outerPerm, innerPerm);
+}
+
+/// Pack a matmul operation into blocked 4D layout.
+///
+/// The options returned by `controlPackMatmul` drive the relayout: the op is
+/// first packed with `packMatmulGreedily`, then the LHS and RHS packs are
+/// transposed to the requested outer/inner block layouts.
+///
+/// NOTE(review): if a transposition step fails after packing succeeded, the
+/// IR has already been rewritten when failure is returned.
+FailureOr<PackResult>
+linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
+                        const ControlBlockPackMatmulFn &controlPackMatmul) {
+  if (linalgOp.hasPureBufferSemantics())
+    return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics");
+
+  std::optional<BlockPackMatmulOptions> options = controlPackMatmul(linalgOp);
+  if (!options)
+    return rewriter.notifyMatchFailure(linalgOp, "invalid packing options");
+
+  if (options->blockFactors.size() != 3)
+    return rewriter.notifyMatchFailure(linalgOp, "require 3 tile factors");
+
+  // Options may come straight from user-facing pass flags, so reject malformed
+  // values as a match failure. A non-positive factor cannot produce the 4D
+  // blocked layout relied upon below.
+  // NOTE(review): the size and permutation checks mirror what
+  // packMatmulGreedily is assumed to assert on - confirm.
+  if (llvm::any_of(options->blockFactors,
+                   [](int64_t factor) { return factor <= 0; }))
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "require positive tile factors");
+  if (!options->mnkPaddedSizesNextMultipleOf.empty() &&
+      options->mnkPaddedSizesNextMultipleOf.size() != 3)
+    return rewriter.notifyMatchFailure(
+        linalgOp, "require 0 or 3 padded sizes next multiples");
+  SmallVector<int64_t, 3> sortedMnkOrder(options->mnkOrder);
+  llvm::sort(sortedMnkOrder);
+  if (sortedMnkOrder != SmallVector<int64_t, 3>{0, 1, 2})
+    return rewriter.notifyMatchFailure(
+        linalgOp, "require mnk order to be a permutation of (0, 1, 2)");
+
+  SmallVector<OpFoldResult> mnkTiles =
+      getAsOpFoldResult(rewriter.getI64ArrayAttr(options->blockFactors));
+
+  // If padding is disabled, make sure that dimensions can be packed cleanly.
+  if (!options->allowPadding &&
+      !validateFullTilesOnDims(linalgOp, mnkTiles, options->mnkOrder)) {
+    return rewriter.notifyMatchFailure(linalgOp,
+                                       "expect packing full tiles only");
+  }
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  // The op is replaced, we need to set the insertion point after it.
+  rewriter.setInsertionPointAfter(linalgOp);
+
+  // Pack the matmul operation into blocked layout with two levels of
+  // subdivision:
+  //   - major 2D blocks - outer dimensions, consist of minor blocks
+  //   - minor 2D blocks - inner dimensions, consist of scalar elements
+  FailureOr<PackResult> packedMatmul = packMatmulGreedily(
+      rewriter, linalgOp, mnkTiles, options->mnkPaddedSizesNextMultipleOf,
+      options->mnkOrder);
+  if (failed(packedMatmul))
+    return failure();
+
+  assert(packedMatmul->packOps.size() == 3 &&
+         "invalid number of pack ops after matmul packing");
+  assert(packedMatmul->unPackOps.size() == 1 &&
+         "invalid number of unpack ops after matmul packing");
+
+  FailureOr<ContractionDimensions> contractDims =
+      inferContractionDims(packedMatmul->packedLinalgOp);
+  if (failed(contractDims))
+    return failure();
+
+  // Query the maps through the LinalgOp interface instead of dereferencing an
+  // unchecked dyn_cast to a concrete op type.
+  SmallVector<AffineMap> maps =
+      packedMatmul->packedLinalgOp.getIndexingMapsArray();
+
+  // Transpose LHS matrix according to the options.
+  FailureOr<PackTransposeResult> packedLhs = transposePackedMatmul(
+      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[0], maps[0],
+      contractDims->m, options->lhsTransposeOuterBlocks,
+      options->lhsTransposeInnerBlocks);
+  if (failed(packedLhs))
+    return failure();
+
+  // Update results.
+  packedMatmul->packOps[0] = packedLhs->transposedPackOp;
+  packedMatmul->packedLinalgOp = packedLhs->transposedLinalgOp;
+
+  // Transpose RHS matrix according to the options.
+  FailureOr<PackTransposeResult> packedRhs = transposePackedMatmul(
+      rewriter, packedMatmul->packedLinalgOp, packedMatmul->packOps[1], maps[1],
+      contractDims->k, options->rhsTransposeOuterBlocks,
+      options->rhsTransposeInnerBlocks);
+  if (failed(packedRhs))
+    return failure();
+
+  // Update results.
+  packedMatmul->packOps[1] = packedRhs->transposedPackOp;
+  packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp;
+
+  return packedMatmul;
+}
+
+namespace {
+template <typename OpTy>
+struct BlockPackMatmul : public OpRewritePattern<OpTy> {
+  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern<OpTy>(context, benefit), controlFn(std::move(fun)) {}
+
+  /// Block pack `linalgOp`; the pattern matches exactly when
+  /// `blockPackMatmul` succeeds.
+  LogicalResult matchAndRewrite(OpTy linalgOp,
+                                PatternRewriter &rewriter) const override {
+    return success(succeeded(blockPackMatmul(rewriter, linalgOp, controlFn)));
+  }
+
+private:
+  /// Callback that supplies the packing options for each matched op.
+  ControlBlockPackMatmulFn controlFn;
+};
+
+template <>
+struct BlockPackMatmul<linalg::GenericOp>
+    : public OpRewritePattern<linalg::GenericOp> {
+  BlockPackMatmul(MLIRContext *context, ControlBlockPackMatmulFn fun,
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::GenericOp>(context, benefit),
+        controlFn(std::move(fun)) {}
+
+  /// Block pack `linalgOp` if it is a contraction whose indexing maps describe
+  /// a plain 2D matmul, possibly with one operand transposed.
+  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(linalg::detail::verifyContractionInterface(
+            linalgOp.getOperation())))
+      return rewriter.notifyMatchFailure(linalgOp, "not a contraction");
+
+    MLIRContext *ctx = linalgOp->getContext();
+    AffineExpr m, n, k;
+    bindDims(ctx, m, n, k);
+    auto mapsOf = [ctx](ArrayRef<ArrayRef<AffineExpr>> exprs) {
+      return AffineMap::inferFromExprList(exprs, ctx);
+    };
+
+    // For now, accept only these (A, B, C) layouts:
+    //   (m, k) x (k, n), (k, m) x (k, n) and (m, k) x (n, k) -> (m, n)
+    SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
+    bool isSimpleMatmul = maps == mapsOf({{m, k}, {k, n}, {m, n}}) ||
+                          maps == mapsOf({{k, m}, {k, n}, {m, n}}) ||
+                          maps == mapsOf({{m, k}, {n, k}, {m, n}});
+    if (!isSimpleMatmul)
+      return rewriter.notifyMatchFailure(linalgOp, "not a suitable matmul");
+
+    return success(succeeded(blockPackMatmul(rewriter, linalgOp, controlFn)));
+  }
+
+private:
+  /// Callback that supplies the packing options for each matched op.
+  ControlBlockPackMatmulFn controlFn;
+};
+
+/// Convert linalg matmul ops to block layout and back.
+struct LinalgBlockPackMatmul
+    : public impl::LinalgBlockPackMatmulBase<LinalgBlockPackMatmul> {
+  using LinalgBlockPackMatmulBase::LinalgBlockPackMatmulBase;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(&getContext());
+
+    // The pass flags apply uniformly to every candidate op, so translate them
+    // into packing options once and let the control function return a copy.
+    // The callback's op parameter is unused and left unnamed so it no longer
+    // shadows `op` above.
+    BlockPackMatmulOptions packOptions;
+    packOptions.blockFactors = SmallVector<int64_t>{*blockFactors};
+    packOptions.allowPadding = allowPadding;
+    packOptions.mnkPaddedSizesNextMultipleOf =
+        SmallVector<int64_t>{*mnkPaddedSizesNextMultipleOf};
+    // An empty flag keeps the default (M, N, K) order.
+    if (!mnkOrder.empty())
+      packOptions.mnkOrder = SmallVector<int64_t>{*mnkOrder};
+    packOptions.lhsTransposeOuterBlocks = lhsTransposeOuterBlocks;
+    packOptions.lhsTransposeInnerBlocks = lhsTransposeInnerBlocks;
+    packOptions.rhsTransposeOuterBlocks = rhsTransposeOuterBlocks;
+    packOptions.rhsTransposeInnerBlocks = rhsTransposeInnerBlocks;
+
+    ControlBlockPackMatmulFn controlFn =
+        [packOptions](linalg::LinalgOp) -> BlockPackMatmulOptions {
+      return packOptions;
+    };
+
+    linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
+    if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+void linalg::populateBlockPackMatmulPatterns(
+    RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
+  MLIRContext *ctx = patterns.getContext();
+  // Matmul-like linalg.generic contractions.
+  patterns.add<BlockPackMatmul<linalg::GenericOp>>(ctx, controlFn);
+  // Plain named matmuls.
+  patterns.add<BlockPackMatmul<linalg::MatmulOp>,
+               BlockPackMatmul<linalg::BatchMatmulOp>>(ctx, controlFn);
+  // Named matmuls with a transposed LHS.
+  patterns.add<BlockPackMatmul<linalg::MatmulTransposeAOp>,
+               BlockPackMatmul<linalg::BatchMatmulTransposeAOp>>(ctx,
+                                                                 controlFn);
+  // Named matmuls with a transposed RHS.
+  patterns.add<BlockPackMatmul<linalg::MatmulTransposeBOp>,
+               BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(ctx,
+                                                                 controlFn);
+}

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3b5282a09569d..ed9f40089282a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
+  BlockPackMatmul.cpp
   Padding.cpp
   Promotion.cpp
   RuntimeOpVerification.cpp

diff  --git a/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
new file mode 100644
index 0000000000000..01ca4374da046
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-layout.mlir
@@ -0,0 +1,101 @@
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
+// RUN: lhs-transpose-outer-blocks=false lhs-transpose-inner-blocks=false \
+// RUN: rhs-transpose-outer-blocks=true rhs-transpose-inner-blocks=true" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=MMT4D
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
+// RUN: lhs-transpose-outer-blocks=false lhs-transpose-inner-blocks=false \
+// RUN: rhs-transpose-outer-blocks=false rhs-transpose-inner-blocks=false" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=MM4D
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 \
+// RUN: lhs-transpose-outer-blocks=true lhs-transpose-inner-blocks=true \
+// RUN: rhs-transpose-outer-blocks=false rhs-transpose-inner-blocks=false" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=MTM4D
+
+func.func @block_matmul(
+    %A: tensor<64x128xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  // Plain 64x128 * 128x64 matmul.
+  %res = linalg.matmul ins(%A, %B : tensor<64x128xf32>, tensor<128x64xf32>)
+                       outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %res : tensor<64x64xf32>
+}
+
+func.func @block_matmul_transpose_a(
+    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  // LHS is stored K-major (128x64).
+  %res = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
+                                   outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %res : tensor<64x64xf32>
+}
+
+func.func @block_matmul_transpose_b(
+    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  // RHS is stored N-major (64x128).
+  %res = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
+                                   outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %res : tensor<64x64xf32>
+}
+
+// MMT4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// MMT4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// MMT4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// MMT4D-LABEL: func @block_matmul
+// MMT4D-COUNT-3: tensor.pack
+// MMT4D: linalg.generic
+// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MMT4D-COUNT-1: tensor.unpack
+// MMT4D-LABEL: func @block_matmul_transpose_a
+// MMT4D-COUNT-3: tensor.pack
+// MMT4D: linalg.generic
+// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MMT4D-COUNT-1: tensor.unpack
+// MMT4D-LABEL: func @block_matmul_transpose_b
+// MMT4D-COUNT-3: tensor.pack
+// MMT4D: linalg.generic
+// MMT4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MMT4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MMT4D-COUNT-1: tensor.unpack
+
+// MM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// MM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+// MM4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// MM4D-LABEL: func @block_matmul
+// MM4D-COUNT-3: tensor.pack
+// MM4D: linalg.generic
+// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MM4D-COUNT-1: tensor.unpack
+// MM4D-LABEL: func @block_matmul_transpose_a
+// MM4D-COUNT-3: tensor.pack
+// MM4D: linalg.generic
+// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MM4D-COUNT-1: tensor.unpack
+// MM4D-LABEL: func @block_matmul_transpose_b
+// MM4D-COUNT-3: tensor.pack
+// MM4D: linalg.generic
+// MM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MM4D-COUNT-1: tensor.unpack
+
+// MTM4D-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d5, d3)>
+// MTM4D-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d1, d5, d4)>
+// MTM4D-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// MTM4D-LABEL: func @block_matmul
+// MTM4D-COUNT-3: tensor.pack
+// MTM4D: linalg.generic
+// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MTM4D-COUNT-1: tensor.unpack
+// MTM4D-LABEL: func @block_matmul_transpose_a
+// MTM4D-COUNT-3: tensor.pack
+// MTM4D: linalg.generic
+// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MTM4D-COUNT-1: tensor.unpack
+// MTM4D-LABEL: func @block_matmul_transpose_b
+// MTM4D-COUNT-3: tensor.pack
+// MTM4D: linalg.generic
+// MTM4D-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// MTM4D-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// MTM4D-COUNT-1: tensor.unpack

diff  --git a/mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir
new file mode 100644
index 0000000000000..9e396ba08d246
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul-padding.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=1" \
+// RUN: -canonicalize | FileCheck %s
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=0" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=NOPAD
+
+// RUN: mlir-opt %s -linalg-block-pack-matmul="block-factors=32,16,64 allow-padding=1 mnk-padded-multiples=256,512,384" \
+// RUN: -canonicalize | FileCheck %s --check-prefix=PAD-MULT
+
+func.func @block_matmul_padding(
+    %A: tensor<123x125xf32>, %B: tensor<125x124xf32>, %C: tensor<123x124xf32>) -> tensor<123x124xf32> {
+  // No dimension is a multiple of the 32x16x64 block factors.
+  %res = linalg.matmul ins(%A, %B : tensor<123x125xf32>, tensor<125x124xf32>)
+                       outs(%C : tensor<123x124xf32>) -> tensor<123x124xf32>
+  return %res : tensor<123x124xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// CHECK-LABEL: func @block_matmul_padding(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<123x125xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<125x124xf32> -> tensor<8x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<123x124xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<123x124xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<123x124xf32>
+
+// NOPAD-LABEL: func @block_matmul_padding(
+// NOPAD-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
+// NOPAD-NOT: tensor.pack
+// NOPAD: linalg.matmul ins(%[[A]], %[[B]] : tensor<123x125xf32>, tensor<125x124xf32>)
+// NOPAD-SAME: outs(%[[C]] : tensor<123x124xf32>) -> tensor<123x124xf32>
+// NOPAD-NOT: tensor.unpack
+
+// PAD-MULT-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// PAD-MULT-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// PAD-MULT-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+// PAD-MULT-LABEL: func @block_matmul_padding(
+// PAD-MULT-SAME:    %[[A:[0-9a-z]+]]: tensor<123x125xf32>, %[[B:[0-9a-z]+]]: tensor<125x124xf32>, %[[C:[0-9a-z]+]]: tensor<123x124xf32>
+// PAD-MULT-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// PAD-MULT: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<1x1x256x384xf32>
+// PAD-MULT: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
+// PAD-MULT-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [256, 384]
+// PAD-MULT-SAME:  into %[[PACK_DST_0]] : tensor<123x125xf32> -> tensor<1x1x256x384xf32>
+// PAD-MULT: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<1x1x512x384xf32>
+// PAD-MULT: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
+// PAD-MULT-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [512, 384]
+// PAD-MULT-SAME:  into %[[PACK_DST_1]] : tensor<125x124xf32> -> tensor<1x1x512x384xf32>
+// PAD-MULT: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<1x1x256x512xf32>
+// PAD-MULT: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// PAD-MULT-SAME:  padding_value(%[[ZERO]] : f32)
+// PAD-MULT-SAME:  inner_dims_pos = [0, 1] inner_tiles = [256, 512]
+// PAD-MULT-SAME:  into %[[PACK_DST_2]] : tensor<123x124xf32> -> tensor<1x1x256x512xf32>
+// PAD-MULT: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// PAD-MULT-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// PAD-MULT-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// PAD-MULT-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<1x1x256x384xf32>, tensor<1x1x512x384xf32>) outs(%[[C_PACKED]] : tensor<1x1x256x512xf32>)
+// PAD-MULT: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// PAD-MULT-SAME:  inner_dims_pos = [0, 1] inner_tiles = [256, 512]
+// PAD-MULT-SAME:  into %[[C]] : tensor<1x1x256x512xf32> -> tensor<123x124xf32>
+// PAD-MULT: return %[[RES_UNPACKED]] : tensor<123x124xf32>

diff  --git a/mlir/test/Dialect/Linalg/block-pack-matmul.mlir b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
new file mode 100644
index 0000000000000..cc9af913ca15a
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/block-pack-matmul.mlir
@@ -0,0 +1,478 @@
+// RUN: mlir-opt %s -linalg-block-pack-matmul=block-factors=32,16,64 -canonicalize -split-input-file | FileCheck %s
+
+// Plain static 128x128 linalg.matmul. With block factors 32,16,64 from the
+// RUN line, A/B/C are expected to be packed into 4x2x32x64, 8x2x16x64 and
+// 4x8x32x16 blocks, and the result unpacked back into %C.
+func.func @block_matmul(
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
+
+// -----
+
+// Fully dynamic operands: packed outer sizes are expected to be computed at
+// runtime as ceildiv of the tensor.dim values, and every pack carries a zero
+// padding value since the dims are not known to divide the block factors.
+func.func @block_matmul_dynamic(
+    %A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.matmul  ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
+                      outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_M:.+]] = affine_map<()[s0] -> (s0 ceildiv 32)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<()[s0] -> (s0 ceildiv 64)>
+// CHECK-DAG: #[[$MAP_N:.+]] = affine_map<()[s0] -> (s0 ceildiv 16)>
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul_dynamic(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<?x?xf32>, %[[B:[0-9a-z]+]]: tensor<?x?xf32>, %[[C:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[A_M:.+]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[A_K:.+]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[A_OUTER_TILE_M:.+]] = affine.apply #[[$MAP_M]]()[%[[A_M]]]
+// CHECK-DAG: %[[A_OUTER_TILE_K:.+]] = affine.apply #[[$MAP_K]]()[%[[A_K]]]
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty(%[[A_OUTER_TILE_M]], %[[A_OUTER_TILE_K]]) : tensor<?x?x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<?x?xf32> -> tensor<?x?x32x64xf32>
+// CHECK-DAG: %[[B_K:.+]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[B_N:.+]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[B_OUTER_TILE_K:.+]] = affine.apply #[[$MAP_K]]()[%[[B_K]]]
+// CHECK-DAG: %[[B_OUTER_TILE_N:.+]] = affine.apply #[[$MAP_N]]()[%[[B_N]]]
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty(%[[B_OUTER_TILE_N]], %[[B_OUTER_TILE_K]]) : tensor<?x?x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<?x?xf32> -> tensor<?x?x16x64xf32>
+// CHECK-DAG: %[[C_M:.+]] = tensor.dim %[[C]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[C_N:.+]] = tensor.dim %[[C]], %[[C1]] : tensor<?x?xf32>
+// CHECK-DAG: %[[C_OUTER_TILE_M:.+]] = affine.apply #[[$MAP_M]]()[%[[C_M]]]
+// CHECK-DAG: %[[C_OUTER_TILE_N:.+]] = affine.apply #[[$MAP_N]]()[%[[C_N]]]
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty(%[[C_OUTER_TILE_M]], %[[C_OUTER_TILE_N]]) : tensor<?x?x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  padding_value(%[[ZERO]] : f32)
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<?x?xf32> -> tensor<?x?x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<?x?x32x64xf32>, tensor<?x?x16x64xf32>) outs(%[[C_PACKED]] : tensor<?x?x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<?x?x32x16xf32> -> tensor<?x?xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<?x?xf32>
+
+// -----
+
+// Accumulator is a dense zero constant: after -canonicalize the pack of the
+// constant is expected to fold into a packed constant of the blocked type.
+func.func @block_matmul_with_constant(
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst_acc = arith.constant dense<0.0> : tensor<128x128xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%cst_acc : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_constant(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[CST_ACC_PACKED:.+]] = arith.constant dense<0.000000e+00> : tensor<4x8x32x16xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = arith.constant dense<0.000000e+00> : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[CST_ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[RES_DST]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
+
+// -----
+
+// Accumulator is produced by linalg.fill: the pack of the filled tensor is
+// expected to become a fill of a packed tensor.empty.
+func.func @block_matmul_with_producer(
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %cst = arith.constant 0.0 : f32
+  %acc = linalg.fill ins(%cst : f32) outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                      outs(%acc : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %1 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_producer(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[FILL_DST_PACKED:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[ACC_PACKED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[FILL_DST_PACKED]] : tensor<4x8x32x16xf32>) -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  ins({{.*}} : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[ACC_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
+
+// -----
+
+// Matmul result feeds a linalg.add: only the matmul is blocked, and the add
+// is expected to consume the unpacked plain-layout result.
+func.func @block_matmul_with_consumer(
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>, %D: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %1 = linalg.matmul ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+                     outs(%C : tensor<128x128xf32>) -> tensor<128x128xf32>
+  %2 = linalg.add ins(%1, %D : tensor<128x128xf32>, tensor<128x128xf32>)
+                  outs(%0 : tensor<128x128xf32>) -> tensor<128x128xf32>
+  return %2 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @block_matmul_with_consumer(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>, %[[D:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK-DAG: %[[RES_DST:.+]] = tensor.empty() : tensor<128x128xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  outs({{.*}} : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: %[[ADD_RES:.+]] = linalg.add
+// CHECK-SAME:  ins(%[[RES_UNPACKED]], %[[D]] : tensor<128x128xf32>, tensor<128x128xf32>) outs(%[[RES_DST]] : tensor<128x128xf32>)
+// CHECK: return %[[ADD_RES]] : tensor<128x128xf32>
+
+// -----
+
+// Batch matmul: the batch dimension (512) is expected to stay outermost and
+// unblocked; only the M/N/K dimensions are packed.
+func.func @block_batch_matmul(
+    %A: tensor<512x64x128xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = linalg.batch_matmul ins(%A, %B : tensor<512x64x128xf32>, tensor<512x128x64xf32>)
+                           outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul(
+// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
+
+// -----
+
+// A is stored transposed (K x M = 128x64): its pack is expected to permute
+// the dims so the packed operands match the plain matmul blocked layout.
+func.func @block_matmul_transpose_a(
+    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul_transpose_a ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
+                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul_transpose_a(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
+
+// -----
+
+// Batched variant of the transposed-A case: the batch dim stays outermost and
+// A's two inner dims are permuted while packing.
+func.func @block_batch_matmul_transpose_a(
+    %A: tensor<512x128x64xf32>, %B: tensor<512x128x64xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = linalg.batch_matmul_transpose_a ins(%A, %B : tensor<512x128x64xf32>, tensor<512x128x64xf32>)
+                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul_transpose_a(
+// CHECK-SAME:   %[[A:.+]]: tensor<512x128x64xf32>, %[[B:.+]]: tensor<512x128x64xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x128x64xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 2, 1] inner_dims_pos = [2, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x128x64xf32> -> tensor<512x4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
+
+// -----
+
+// B is stored transposed (N x K = 64x128): it already matches the blocked B
+// layout, so its pack is expected to use the identity dim order.
+func.func @block_matmul_transpose_b(
+    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.matmul_transpose_b ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
+                                 outs(%C : tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_matmul_transpose_b(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
+
+// -----
+
+// Batched variant of the transposed-B case: the batch dim stays outermost and
+// B is packed without permuting its inner dims.
+func.func @block_batch_matmul_transpose_b(
+    %A: tensor<512x64x128xf32>, %B: tensor<512x64x128xf32>, %C: tensor<512x64x64xf32>) -> tensor<512x64x64xf32> {
+  %0 = linalg.batch_matmul_transpose_b ins(%A, %B : tensor<512x64x128xf32>, tensor<512x64x128xf32>)
+                                       outs(%C : tensor<512x64x64xf32>) -> tensor<512x64x64xf32>
+  return %0 : tensor<512x64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d3, d4, d6)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d5, d6)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d4, d5)>
+
+// CHECK-LABEL: func @block_batch_matmul_transpose_b(
+// CHECK-SAME:   %[[A:.+]]: tensor<512x64x128xf32>, %[[B:.+]]: tensor<512x64x128xf32>, %[[C:.+]]: tensor<512x64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<512x2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<512x64x128xf32> -> tensor<512x2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<512x4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<512x64x128xf32> -> tensor<512x4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<512x2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<512x64x64xf32> -> tensor<512x2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<512x2x2x32x64xf32>, tensor<512x4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<512x2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [1, 2] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<512x2x4x32x16xf32> -> tensor<512x64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<512x64x64xf32>
+
+// -----
+
+// linalg.generic spelling of a plain matmul (A indexed as (m, k), B as (k, n),
+// mul/add body): expected to be blocked exactly like linalg.matmul.
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @block_generic_matmul(
+    %A: tensor<128x128xf32>, %B: tensor<128x128xf32>, %C: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<128x128xf32>, tensor<128x128xf32>)
+    outs(%C : tensor<128x128xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<128x128xf32>
+  return %0 : tensor<128x128xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_generic_matmul(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x128xf32>, %[[B:[0-9a-z]+]]: tensor<128x128xf32>, %[[C:[0-9a-z]+]]: tensor<128x128xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<4x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x128xf32> -> tensor<4x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<8x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x128xf32> -> tensor<8x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<4x8x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<128x128xf32> -> tensor<4x8x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<4x2x32x64xf32>, tensor<8x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<4x8x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<4x8x32x16xf32> -> tensor<128x128xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<128x128xf32>
+
+// -----
+
+// linalg.generic matmul with A indexed as (k, m), i.e. transposed A: expected
+// to be blocked like linalg.matmul_transpose_a.
+#map = affine_map<(d0, d1, d2) -> (d2, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @block_generic_matmul_transpose_a(
+    %A: tensor<128x64xf32>, %B: tensor<128x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<128x64xf32>, tensor<128x64xf32>)
+    outs(%C : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_generic_matmul_transpose_a(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<128x64xf32>, %[[B:[0-9a-z]+]]: tensor<128x64xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<128x64xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<128x64xf32> -> tensor<4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>
+
+// -----
+
+// linalg.generic matmul with B indexed as (n, k), i.e. transposed B: expected
+// to be blocked like linalg.matmul_transpose_b.
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+func.func @block_generic_matmul_transpose_b(
+    %A: tensor<64x128xf32>, %B: tensor<64x128xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
+    ins(%A, %B : tensor<64x128xf32>, tensor<64x128xf32>)
+    outs(%C : tensor<64x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    %2 = arith.addf %out, %1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<64x64xf32>
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d4, d5)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d3, d4)>
+
+// CHECK-LABEL: func @block_generic_matmul_transpose_b(
+// CHECK-SAME:    %[[A:[0-9a-z]+]]: tensor<64x128xf32>, %[[B:[0-9a-z]+]]: tensor<64x128xf32>, %[[C:[0-9a-z]+]]: tensor<64x64xf32>
+// CHECK: %[[PACK_DST_0:.+]] = tensor.empty() : tensor<2x2x32x64xf32>
+// CHECK: %[[A_PACKED:.+]] = tensor.pack %[[A]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [32, 64]
+// CHECK-SAME:  into %[[PACK_DST_0]] : tensor<64x128xf32> -> tensor<2x2x32x64xf32>
+// CHECK: %[[PACK_DST_1:.+]] = tensor.empty() : tensor<4x2x16x64xf32>
+// CHECK: %[[B_PACKED:.+]] = tensor.pack %[[B]]
+// CHECK-SAME:  outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 64]
+// CHECK-SAME:  into %[[PACK_DST_1]] : tensor<64x128xf32> -> tensor<4x2x16x64xf32>
+// CHECK: %[[PACK_DST_2:.+]] = tensor.empty() : tensor<2x4x32x16xf32>
+// CHECK: %[[C_PACKED:.+]] = tensor.pack %[[C]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[PACK_DST_2]] : tensor<64x64xf32> -> tensor<2x4x32x16xf32>
+// CHECK: %[[GEMM_RES_PACKED:.+]] = linalg.generic
+// CHECK-SAME:  indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME:  iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]
+// CHECK-SAME:  ins(%[[A_PACKED]], %[[B_PACKED]] : tensor<2x2x32x64xf32>, tensor<4x2x16x64xf32>) outs(%[[C_PACKED]] : tensor<2x4x32x16xf32>)
+// CHECK: %[[RES_UNPACKED:.+]] = tensor.unpack %[[GEMM_RES_PACKED]]
+// CHECK-SAME:  inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME:  into %[[C]] : tensor<2x4x32x16xf32> -> tensor<64x64xf32>
+// CHECK: return %[[RES_UNPACKED]] : tensor<64x64xf32>


        


More information about the Mlir-commits mailing list