[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose matmul op (PR #89075)

Cullen Rhodes llvmlistbot at llvm.org
Mon Apr 22 01:42:08 PDT 2024


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/89075

>From 3bd53f154a2fcab7f952d190923b41430830907c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 16 Apr 2024 09:16:33 +0000
Subject: [PATCH 1/6] [mlir][linalg] Add pass to transpose A matrix of matmul
 op

This patch introduces a pass `-linalg-matmul-to-matmul-transpose-a`,
which transposes the A matrix of a Linalg matmul operation, with the aim
of making memory accesses contiguous.

Our work enabling a lowering path from `linalg.matmul` to ArmSME has
revealed that the current lowering results in non-contiguous memory accesses
for the A matrix and, consequently, very poor performance.

This pass provides a simple option to fix this.
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  9 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/MatmulToMatmulTransposeA.cpp   | 92 +++++++++++++++++++
 .../Linalg/matmul-to-matmul-transpose-a.mlir  | 76 +++++++++++++++
 5 files changed, 181 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
 create mode 100644 mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..38be6e49f574c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,13 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgMatmulToMatmulTransposeAPass
+    : Pass<"linalg-matmul-to-matmul-transpose-a"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let description = [{
+    Transposes the A matrix of a `linalg.matmul` for contiguous access.
+  }];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..0d354c666b1742 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,9 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
+void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..bca4954f959da3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MatmulToMatmulTransposeA.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
new file mode 100644
index 00000000000000..45551cd9167b60
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -0,0 +1,92 @@
+//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
+// with the aim of the memory accesses becoming contiguous.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with
+/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
+struct MatmulToMatmulTransposeA final
+    : public OpRewritePattern<linalg::MatmulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value a = matmulOp.getInputs()[0];
+    auto aType = cast<ShapedType>(a.getType());
+    if (aType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    SmallVector<Value> dynamicDims;
+    if (aType.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
+    if (aType.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
+
+    auto aShape = aType.getShape();
+    SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, aType.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeAOp =
+        rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
+    rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+        matmulOp, matmulOp.getResultTypes(),
+        ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
+}
+
+namespace {
+struct LinalgMatmulToMatmulTransposeAPass
+    : public impl::LinalgMatmulToMatmulTransposeAPassBase<
+          LinalgMatmulToMatmulTransposeAPass> {
+  using impl::LinalgMatmulToMatmulTransposeAPassBase<
+      LinalgMatmulToMatmulTransposeAPass>::
+      LinalgMatmulToMatmulTransposeAPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateMatmulToMatmulTransposeAPattern(patterns);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
new file mode 100644
index 00000000000000..c3c8dff98ba3c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @static(
+// CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME:                      %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           return %[[C]] : tensor<16x16xf32>
+// CHECK:         }
+func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<16x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+//-----
+
+// CHECK-LABEL:   func.func @dynamic(
+// CHECK-SAME:                       %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                       %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           return %[[C]] : tensor<?x?xf32>
+// CHECK:         }
+func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+//-----
+
+// CHECK-LABEL:   func.func @mixed(
+// CHECK-SAME:                     %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME:                     %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// CHECK:           %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           return %[[B0]] : tensor<?x16xf32>
+// CHECK:         }
+func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
+  %init = tensor.empty(%d0) : tensor<?x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
+  return %0 : tensor<?x16xf32>
+}

>From 2009cd03a86eaf5bf3f5d6e6c663282ca66e59dd Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 17 Apr 2024 14:20:34 +0000
Subject: [PATCH 2/6] run clang-format

---
 .../Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
index 45551cd9167b60..73a812ce4550b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -42,8 +42,8 @@ struct MatmulToMatmulTransposeA final
     Value a = matmulOp.getInputs()[0];
     auto aType = cast<ShapedType>(a.getType());
     if (aType.getRank() != 2)
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only 2-D matmul ops are supported");
+      return rewriter.notifyMatchFailure(matmulOp,
+                                         "only 2-D matmul ops are supported");
 
     Location loc = matmulOp.getLoc();
 

>From e407d6e017edea7e9ffc327fe8d86d9688e620d2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 18 Apr 2024 13:45:55 +0000
Subject: [PATCH 3/6] address comments

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  13 +--
 .../Dialect/Linalg/Transforms/Transforms.h    |   6 +-
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   2 +-
 .../Transforms/MatmulToMatmulTransposeA.cpp   |  92 ----------------
 .../Linalg/Transforms/TransposeMatmul.cpp     | 102 ++++++++++++++++++
 ...transpose-a.mlir => transpose-matmul.mlir} |  33 ++++--
 6 files changed, 136 insertions(+), 112 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
 rename mlir/test/Dialect/Linalg/{matmul-to-matmul-transpose-a.mlir => transpose-matmul.mlir} (65%)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 38be6e49f574c9..bd6580f7feef1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,13 +141,14 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
-def LinalgMatmulToMatmulTransposeAPass
-    : Pass<"linalg-matmul-to-matmul-transpose-a"> {
-  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+def LinalgTransposeMatmulPass : Pass<"linalg-transpose-matmul"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a` "
+                "(default) or `linalg.matmul_transpose_b`";
   let dependentDialects = ["linalg::LinalgDialect"];
-  let description = [{
-    Transposes the A matrix of a `linalg.matmul` for contiguous access.
-  }];
+  let options = [
+    Option<"transposeA", "transpose-a", "bool", /*default=*/"true",
+           "If true transpose A (default), otherwise transpose B.">
+  ];
 }
 
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0d354c666b1742..ac43029bb7e0b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,8 +1616,10 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
-/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
-void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+/// Pattern to convert `linalg.matmul` to `linalg.matmul_transpose_a` (default)
+/// or `linalg.matmul_transpose_b`.
+void populateTransposeMatmulPattern(RewritePatternSet &patterns,
+                                    bool transposeA = true);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bca4954f959da3..ee6e391d0cc682 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,7 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
-  MatmulToMatmulTransposeA.cpp
+  TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
deleted file mode 100644
index 73a812ce4550b1..00000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
+++ /dev/null
@@ -1,92 +0,0 @@
-//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
-// with the aim of the memory accesses becoming contiguous.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
-#include "mlir/Dialect/Linalg/Passes.h.inc"
-} // namespace mlir
-
-#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-namespace {
-/// Pattern to replace `linalg.matmul(a, b)` with
-/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
-struct MatmulToMatmulTransposeA final
-    : public OpRewritePattern<linalg::MatmulOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(matmulOp))
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only matmul ops with tensors are supported");
-
-    Value a = matmulOp.getInputs()[0];
-    auto aType = cast<ShapedType>(a.getType());
-    if (aType.getRank() != 2)
-      return rewriter.notifyMatchFailure(matmulOp,
-                                         "only 2-D matmul ops are supported");
-
-    Location loc = matmulOp.getLoc();
-
-    SmallVector<Value> dynamicDims;
-    if (aType.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
-    if (aType.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
-
-    auto aShape = aType.getShape();
-    SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, transposedShape, aType.getElementType(), dynamicDims);
-    static constexpr std::array<int64_t, 2> perm = {1, 0};
-    auto transposeAOp =
-        rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
-    rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
-        matmulOp, matmulOp.getResultTypes(),
-        ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
-        matmulOp.getOutputs());
-
-    return success();
-  }
-};
-} // namespace
-
-void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
-    RewritePatternSet &patterns) {
-  patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
-}
-
-namespace {
-struct LinalgMatmulToMatmulTransposeAPass
-    : public impl::LinalgMatmulToMatmulTransposeAPassBase<
-          LinalgMatmulToMatmulTransposeAPass> {
-  using impl::LinalgMatmulToMatmulTransposeAPassBase<
-      LinalgMatmulToMatmulTransposeAPass>::
-      LinalgMatmulToMatmulTransposeAPassBase;
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    populateMatmulToMatmulTransposeAPattern(patterns);
-    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
-  }
-};
-} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
new file mode 100644
index 00000000000000..bc93df5bf7033a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -0,0 +1,102 @@
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed matmul ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSEMATMULPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-transpose-matmul"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace
+///
+///   linalg.matmul(a, b)
+///
+/// with
+///
+///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///
+/// By default A is transposed. If `transposeA` is set to false then B is
+/// transposed. Note: only 2-D matrices are currently supported.
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+  TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
+      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value input = matmulOp.getInputs()[transposeA ? 0 : 1];
+    auto type = cast<ShapedType>(input.getType());
+    if (type.getRank() != 2)
+      return rewriter.notifyMatchFailure(matmulOp,
+                                         "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    SmallVector<Value> dynamicDims;
+    if (type.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+    if (type.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    auto shape = type.getShape();
+    SmallVector<int64_t> transposedShape{shape[1], shape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, type.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeOp =
+        rewriter.create<linalg::TransposeOp>(loc, input, empty, perm);
+    if (transposeA)
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+          matmulOp.getOutputs());
+    else
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+          matmulOp.getOutputs());
+
+    return success();
+  }
+
+private:
+  bool transposeA;
+};
+} // namespace
+
+void mlir::linalg::populateTransposeMatmulPattern(RewritePatternSet &patterns,
+                                                  bool transposeA) {
+  patterns.add<TransposeMatmul>(patterns.getContext(), transposeA);
+}
+
+namespace {
+struct LinalgTransposeMatmulPass
+    : public impl::LinalgTransposeMatmulPassBase<LinalgTransposeMatmulPass> {
+  using impl::LinalgTransposeMatmulPassBase<
+      LinalgTransposeMatmulPass>::LinalgTransposeMatmulPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateTransposeMatmulPattern(patterns, transposeA);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
similarity index 65%
rename from mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
rename to mlir/test/Dialect/Linalg/transpose-matmul.mlir
index c3c8dff98ba3c9..470453c8ff78c1 100644
--- a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-transpose-matmul -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
+// RUN: mlir-opt -linalg-transpose-matmul=transpose-a=false -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
 
 // CHECK-LABEL:   func.func @static(
 // CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
@@ -6,9 +7,12 @@
 // CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
-// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
-// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
-// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
+// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
 // CHECK:           return %[[C]] : tensor<16x16xf32>
 // CHECK:         }
 func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
@@ -31,10 +35,14 @@ func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf
 // CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
 // CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
-// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
-// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-A:     %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:           return %[[C]] : tensor<?x?xf32>
 // CHECK:         }
 func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
@@ -59,9 +67,12 @@ func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>
 // CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
 // CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
-// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
-// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
-// CHECK:           %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// TRANSPOSE-A:     %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
+// TRANSPOSE-B:     %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
 // CHECK:           return %[[B0]] : tensor<?x16xf32>
 // CHECK:         }
 func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {

>From 329f94596ffe938c77065ba2f2f299507918c151 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 19 Apr 2024 07:27:34 +0000
Subject: [PATCH 4/6] clarify expectations in pass

---
 mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index bc93df5bf7033a..2d6c9cd59ab298 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -5,6 +5,12 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
+// This pass converts `linalg.matmul` operations to the transposed
+// `linalg.matmul_transpose_a` or `linalg.matmul_transpose_b` variants.
+//
+// This is intended to be a simple high-level (target-agnostic) matmul
+// transposition transformation.
+//===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"

>From 49cbd5dcc3b4648784c92a66d8a1065199d44087 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 19 Apr 2024 10:36:06 +0000
Subject: [PATCH 5/6] add batch_matmul

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |   3 +-
 .../Dialect/Linalg/Transforms/Transforms.h    |   7 +-
 .../Linalg/Transforms/TransposeMatmul.cpp     |  81 +++++++++---
 .../test/Dialect/Linalg/transpose-matmul.mlir | 116 ++++++++++++++++--
 4 files changed, 175 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index bd6580f7feef1d..3704f3a7d64711 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -142,8 +142,7 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
 }
 
 def LinalgTransposeMatmulPass : Pass<"linalg-transpose-matmul"> {
-  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a` "
-                "(default) or `linalg.matmul_transpose_b`";
+  let summary = "Converts Linalg matmul ops to transposed variants.";
   let dependentDialects = ["linalg::LinalgDialect"];
   let options = [
     Option<"transposeA", "transpose-a", "bool", /*default=*/"true",
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ac43029bb7e0b9..9a606d558afc67 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,10 +1616,9 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
-/// Pattern to convert `linalg.matmul` to `linalg.matmul_transpose_a` (default)
-/// or `linalg.matmul_transpose_b`.
-void populateTransposeMatmulPattern(RewritePatternSet &patterns,
-                                    bool transposeA = true);
+/// Patterns to convert Linalg matmul ops to transposed variants.
+void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
+                                     bool transposeA = true);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 2d6c9cd59ab298..1b807ea559d0be 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -1,13 +1,10 @@
-//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed matmul ---===//
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-// This pass converts `linalg.matmul` operations to the transposed
-// `linalg.matmul_transpose_a` or `linalg.matmul_transpose_b` variants.
-//
 // This is intended to be a simple high-level (target-agnostic) matmul
 // transposition transformation.
 //===----------------------------------------------------------------------===//
@@ -37,7 +34,7 @@ namespace {
 ///   linalg.matmul_transpose_a(linalg.transpose(a), b)
 ///
 /// By default A is transposed. If `transposeA` is set to false then B is
-/// transposed. Note: only 2-D matrices are currently supported.
+/// transposed.
 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
   TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
       : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
@@ -48,13 +45,9 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
       return rewriter.notifyMatchFailure(
           matmulOp, "only matmul ops with tensors are supported");
 
+    Location loc = matmulOp.getLoc();
     Value input = matmulOp.getInputs()[transposeA ? 0 : 1];
     auto type = cast<ShapedType>(input.getType());
-    if (type.getRank() != 2)
-      return rewriter.notifyMatchFailure(matmulOp,
-                                         "only 2-D matmul ops are supported");
-
-    Location loc = matmulOp.getLoc();
 
     SmallVector<Value> dynamicDims;
     if (type.isDynamicDim(1))
@@ -83,14 +76,74 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
     return success();
   }
 
+private:
+  bool transposeA;
+};
+
+/// Pattern to replace
+///
+///   linalg.batch_matmul(a, b)
+///
+/// with
+///
+///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
+///
+/// Only the non-batch dimensions are transposed. By default A is transposed. If
+/// `transposeA` is set to false then B is transposed.
+struct TransposeBatchMatmul final
+    : public OpRewritePattern<linalg::BatchMatmulOp> {
+  TransposeBatchMatmul(MLIRContext *ctx, bool transposeA,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+
+  LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(batchMatmulOp))
+      return rewriter.notifyMatchFailure(
+          batchMatmulOp, "only matmul ops with tensors are supported");
+
+    Location loc = batchMatmulOp.getLoc();
+    Value input = batchMatmulOp.getInputs()[transposeA ? 0 : 1];
+    auto type = cast<ShapedType>(input.getType());
+
+    SmallVector<Value> dynamicDims;
+    if (type.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+    if (type.isDynamicDim(2))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+    if (type.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+
+    auto shape = type.getShape();
+    SmallVector<int64_t> transposedShape{shape[0], shape[2], shape[1]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, type.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
+    auto transposeOp =
+        rewriter.create<linalg::TransposeOp>(loc, input, empty, perm);
+    if (transposeA)
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
+          batchMatmulOp, batchMatmulOp.getResultTypes(),
+          ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+          batchMatmulOp.getOutputs());
+    else
+      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
+          batchMatmulOp, batchMatmulOp.getResultTypes(),
+          ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+          batchMatmulOp.getOutputs());
+
+    return success();
+  }
+
 private:
   bool transposeA;
 };
 } // namespace
 
-void mlir::linalg::populateTransposeMatmulPattern(RewritePatternSet &patterns,
-                                                  bool transposeA) {
-  patterns.add<TransposeMatmul>(patterns.getContext(), transposeA);
+void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
+                                                   bool transposeA) {
+  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
+                                                      transposeA);
 }
 
 namespace {
@@ -101,7 +154,7 @@ struct LinalgTransposeMatmulPass
   void runOnOperation() override {
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
-    populateTransposeMatmulPattern(patterns, transposeA);
+    populateTransposeMatmulPatterns(patterns, transposeA);
     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
index 470453c8ff78c1..0df5aefd1f7bfd 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,9 +1,9 @@
 // RUN: mlir-opt -linalg-transpose-matmul -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
 // RUN: mlir-opt -linalg-transpose-matmul=transpose-a=false -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
 
-// CHECK-LABEL:   func.func @static(
-// CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
-// CHECK-SAME:                      %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK-LABEL:   func.func @matmul_static(
+// CHECK-SAME:                             %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME:                             %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
 // CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
@@ -15,7 +15,7 @@
 // TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
 // CHECK:           return %[[C]] : tensor<16x16xf32>
 // CHECK:         }
-func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+func.func @matmul_static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
   %cst = arith.constant 0.0 : f32
   %init = tensor.empty() : tensor<16x16xf32>
   %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
@@ -25,9 +25,9 @@ func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf
 
 //-----
 
-// CHECK-LABEL:   func.func @dynamic(
-// CHECK-SAME:                       %[[A:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:                       %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-LABEL:   func.func @matmul_dynamic(
+// CHECK-SAME:                              %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                              %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
 // CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[C1:.*]] = arith.constant 1 : index
@@ -45,7 +45,7 @@ func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf
 // TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:           return %[[C]] : tensor<?x?xf32>
 // CHECK:         }
-func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
   %cst = arith.constant 0.0 : f32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -59,9 +59,9 @@ func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>
 
 //-----
 
-// CHECK-LABEL:   func.func @mixed(
-// CHECK-SAME:                     %[[A:.*]]: tensor<?x8xf32>,
-// CHECK-SAME:                     %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK-LABEL:   func.func @matmul_mixed(
+// CHECK-SAME:                            %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME:                            %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
 // CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
@@ -75,7 +75,7 @@ func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>
 // TRANSPOSE-B:     %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
 // CHECK:           return %[[B0]] : tensor<?x16xf32>
 // CHECK:         }
-func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+func.func @matmul_mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
   %cst = arith.constant 0.0 : f32
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -85,3 +85,95 @@ func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>
   %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
   return %0 : tensor<?x16xf32>
 }
+
+//-----
+
+// CHECK-LABEL:   func.func @batch_matmul_static(
+// CHECK-SAME:                                   %[[A:.*]]: tensor<2x16x8xf32>,
+// CHECK-SAME:                                   %[[B:.*]]: tensor<2x8x16xf32>) -> tensor<2x16x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<2x16x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x8x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x16xf32>) permutation = [0, 2, 1]
+// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x16xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
+// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x16x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+// CHECK:           return %[[C]] : tensor<2x16x16xf32>
+// CHECK:         }
+func.func @batch_matmul_static(%A: tensor<2x16x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x16x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<2x16x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+  %0 = linalg.batch_matmul ins(%A, %B : tensor<2x16x8xf32>, tensor<2x8x16xf32>) outs(%C : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>
+  return %0 : tensor<2x16x16xf32>
+}
+
+//-----
+
+// CHECK-LABEL:   func.func @batch_matmul_dynamic(
+// CHECK-SAME:                                    %[[A:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                                    %[[B:.*]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?x?xf32>
+// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?x?xf32>
+// CHECK:           %[[B_DIM2:.*]] = tensor.dim %[[B]], %[[C2]] : tensor<?x?x?xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM1]], %[[B_DIM2]]) : tensor<?x?x?xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-A:     %[[A_DIM2:.*]] = tensor.dim %[[A]], %[[C2]] : tensor<?x?x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[A_DIM2]], %[[A_DIM1]]) : tensor<?x?x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
+// TRANSPOSE-A:     %[[C:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?x?xf32>
+// TRANSPOSE-B:     %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?x?xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM0]], %[[B_DIM2]], %[[B_DIM1]]) : tensor<?x?x?xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?x?xf32>) permutation = [0, 2, 1]
+// TRANSPOSE-B:     %[[C:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK:           return %[[C]] : tensor<?x?x?xf32>
+// CHECK:         }
+func.func @batch_matmul_dynamic(%A: tensor<?x?x?xf32>, %B: tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x?x?xf32>
+  %d1 = tensor.dim %A, %c1 : tensor<?x?x?xf32>
+  %d2 = tensor.dim %B, %c2 : tensor<?x?x?xf32>
+  %init = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %0 = linalg.batch_matmul ins(%A, %B : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%C : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+
+//-----
+
+// CHECK-LABEL:   func.func @batch_matmul_mixed(
+// CHECK-SAME:                                  %[[A:.*]]: tensor<2x?x8xf32>,
+// CHECK-SAME:                                  %[[B:.*]]: tensor<2x8x16xf32>) -> tensor<2x?x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<2x?x8xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x?x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]]) : tensor<2x8x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<2x?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<2x8x?xf32>) permutation = [0, 2, 1]
+// TRANSPOSE-A:     %[[B0:.*]] = linalg.batch_matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<2x8x?xf32>, tensor<2x8x16xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<2x16x8xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<2x8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<2x16x8xf32>) permutation = [0, 2, 1]
+// TRANSPOSE-B:     %[[B0:.*]] = linalg.batch_matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<2x?x8xf32>, tensor<2x16x8xf32>) outs(%[[C_ZERO]] : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+// CHECK:           return %[[B0]] : tensor<2x?x16xf32>
+// CHECK:         }
+func.func @batch_matmul_mixed(%A: tensor<2x?x8xf32>, %B: tensor<2x8x16xf32>) -> (tensor<2x?x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %A, %c1 : tensor<2x?x8xf32>
+  %init = tensor.empty(%d1) : tensor<2x?x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+  %0 = linalg.batch_matmul ins(%A, %B : tensor<2x?x8xf32>, tensor<2x8x16xf32>) outs(%C : tensor<2x?x16xf32>) -> tensor<2x?x16xf32>
+  return %0 : tensor<2x?x16xf32>
+}

>From 7891d7db692c501034917106fbf315f902393b49 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 22 Apr 2024 08:29:35 +0000
Subject: [PATCH 6/6] address comments

- address Ben's nits.
- replace pass with transform op.
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  9 ----
 .../Linalg/TransformOps/LinalgTransformOps.td | 15 ++++++
 .../TransformOps/LinalgTransformOps.cpp       |  5 ++
 .../Linalg/Transforms/TransposeMatmul.cpp     | 47 +++++--------------
 .../Dialect/Linalg/transpose-matmul-a.mlir    | 13 +++++
 .../Dialect/Linalg/transpose-matmul-b.mlir    | 13 +++++
 .../test/Dialect/Linalg/transpose-matmul.mlir |  4 +-
 7 files changed, 60 insertions(+), 46 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
 create mode 100644 mlir/test/Dialect/Linalg/transpose-matmul-b.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 3704f3a7d64711..85f11c66d29a73 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,13 +141,4 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
-def LinalgTransposeMatmulPass : Pass<"linalg-transpose-matmul"> {
-  let summary = "Converts Linalg matmul ops to transposed variants.";
-  let dependentDialects = ["linalg::LinalgDialect"];
-  let options = [
-    Option<"transposeA", "transpose-a", "bool", /*default=*/"true",
-           "If true transpose A (default), otherwise transpose B.">
-  ];
-}
-
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8edaa7db6cef3b..3cc8ed4c6fd57e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -73,6 +73,21 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.linalg.transpose_matmul",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns to convert Linalg matmul ops to their transposed variants.
+
+    By default the A matrix is transposed; set `transpose_a=false` to transpose
+    the B matrix instead.
+  }];
+
+  let arguments = (ins DefaultValuedAttr<BoolAttr, "true">:$transpose_a);
+
+  let assemblyFormat = "(`transpose_a` `=` $transpose_a^)? attr-dict";
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 7e7cf1d0244613..d7c956819c2602 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -199,6 +199,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
+void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  linalg::populateTransposeMatmulPatterns(patterns, getTransposeA());
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index 1b807ea559d0be..f5d7f0befc2fb2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -9,16 +9,10 @@
 // transposition transformation.
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-namespace mlir {
-#define GEN_PASS_DEF_LINALGTRANSPOSEMATMULPASS
-#include "mlir/Dialect/Linalg/Passes.h.inc"
-} // namespace mlir
-
 #define DEBUG_TYPE "linalg-transpose-matmul"
 
 using namespace mlir;
@@ -36,8 +30,8 @@ namespace {
 /// By default A is transposed. If `transposeA` is set to false then B is
 /// transposed.
 struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
-  TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
-      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+  TransposeMatmul(MLIRContext *ctx, bool transposeA)
+      : OpRewritePattern(ctx), transposeA(transposeA) {}
 
   LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
                                 PatternRewriter &rewriter) const override {
@@ -56,12 +50,11 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
       dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
 
     auto shape = type.getShape();
-    SmallVector<int64_t> transposedShape{shape[1], shape[0]};
     Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, transposedShape, type.getElementType(), dynamicDims);
-    static constexpr std::array<int64_t, 2> perm = {1, 0};
-    auto transposeOp =
-        rewriter.create<linalg::TransposeOp>(loc, input, empty, perm);
+        loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
+        dynamicDims);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, input, empty, ArrayRef<int64_t>{1, 0});
     if (transposeA)
       rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
           matmulOp, matmulOp.getResultTypes(),
@@ -92,9 +85,8 @@ struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
 /// `transposeA` is set to false then B is transposed.
 struct TransposeBatchMatmul final
     : public OpRewritePattern<linalg::BatchMatmulOp> {
-  TransposeBatchMatmul(MLIRContext *ctx, bool transposeA,
-                       PatternBenefit benefit = 1)
-      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+  TransposeBatchMatmul(MLIRContext *ctx, bool transposeA)
+      : OpRewritePattern(ctx), transposeA(transposeA) {}
 
   LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
                                 PatternRewriter &rewriter) const override {
@@ -115,12 +107,11 @@ struct TransposeBatchMatmul final
       dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
 
     auto shape = type.getShape();
-    SmallVector<int64_t> transposedShape{shape[0], shape[2], shape[1]};
     Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, transposedShape, type.getElementType(), dynamicDims);
-    static constexpr std::array<int64_t, 3> perm = {0, 2, 1};
-    auto transposeOp =
-        rewriter.create<linalg::TransposeOp>(loc, input, empty, perm);
+        loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+        type.getElementType(), dynamicDims);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
     if (transposeA)
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
           batchMatmulOp, batchMatmulOp.getResultTypes(),
@@ -145,17 +136,3 @@ void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
   patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
                                                       transposeA);
 }
-
-namespace {
-struct LinalgTransposeMatmulPass
-    : public impl::LinalgTransposeMatmulPassBase<LinalgTransposeMatmulPass> {
-  using impl::LinalgTransposeMatmulPassBase<
-      LinalgTransposeMatmulPass>::LinalgTransposeMatmulPassBase;
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    populateTransposeMatmulPatterns(patterns, transposeA);
-    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
-  }
-};
-} // namespace
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
new file mode 100644
index 00000000000000..4ff8c1c64871b4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-a.mlir
@@ -0,0 +1,13 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.linalg.transpose_matmul
+    } : !transform.any_op
+    transform.apply_cse to %0 : !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
new file mode 100644
index 00000000000000..b7bba835ad4b04
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transpose-matmul-b.mlir
@@ -0,0 +1,13 @@
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.linalg.transpose_matmul transpose_a=false
+    } : !transform.any_op
+    transform.apply_cse to %0 : !transform.any_op
+    transform.apply_patterns to %0 {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transpose-matmul.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
index 0df5aefd1f7bfd..d2b7e9f7f1992c 100644
--- a/mlir/test/Dialect/Linalg/transpose-matmul.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt -linalg-transpose-matmul -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
-// RUN: mlir-opt -linalg-transpose-matmul=transpose-a=false -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
+// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-a.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
+// RUN: mlir-opt -transform-preload-library='transform-library-paths=%p/transpose-matmul-b.mlir' -transform-interpreter -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
 
 // CHECK-LABEL:   func.func @matmul_static(
 // CHECK-SAME:                             %[[A:.*]]: tensor<16x8xf32>,



More information about the Mlir-commits mailing list