[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose A matrix of matmul op (PR #89075)

Cullen Rhodes llvmlistbot at llvm.org
Thu Apr 18 07:28:27 PDT 2024


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/89075

>From 3bd53f154a2fcab7f952d190923b41430830907c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 16 Apr 2024 09:16:33 +0000
Subject: [PATCH 1/3] [mlir][linalg] Add pass to transpose A matrix of matmul
 op

This patch introduces a pass `-linalg-matmul-to-matmul-transpose-a`,
which transposes the A matrix of a Linalg matmul operation, with the aim
of making memory accesses contiguous.

Our work enabling a lowering path from `linalg.matmul` to ArmSME has
revealed that the current lowering results in non-contiguous memory
accesses for the A matrix and very poor performance.

This pass provides a simple option to fix this.
---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  9 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |  3 +
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |  1 +
 .../Transforms/MatmulToMatmulTransposeA.cpp   | 92 +++++++++++++++++++
 .../Linalg/matmul-to-matmul-transpose-a.mlir  | 76 +++++++++++++++
 5 files changed, 181 insertions(+)
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
 create mode 100644 mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..38be6e49f574c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,13 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgMatmulToMatmulTransposeAPass
+    : Pass<"linalg-matmul-to-matmul-transpose-a"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let description = [{
+    Transposes the A matrix of a `linalg.matmul` for contiguous access.
+  }];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..0d354c666b1742 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,9 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
+void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..bca4954f959da3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MatmulToMatmulTransposeA.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
new file mode 100644
index 00000000000000..45551cd9167b60
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -0,0 +1,92 @@
+//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
+// with the aim of the memory accesses becoming contiguous.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with
+/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
+struct MatmulToMatmulTransposeA final
+    : public OpRewritePattern<linalg::MatmulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value a = matmulOp.getInputs()[0];
+    auto aType = cast<ShapedType>(a.getType());
+    if (aType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    SmallVector<Value> dynamicDims;
+    if (aType.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
+    if (aType.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
+
+    auto aShape = aType.getShape();
+    SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, aType.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeAOp =
+        rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
+    rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+        matmulOp, matmulOp.getResultTypes(),
+        ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
+}
+
+namespace {
+struct LinalgMatmulToMatmulTransposeAPass
+    : public impl::LinalgMatmulToMatmulTransposeAPassBase<
+          LinalgMatmulToMatmulTransposeAPass> {
+  using impl::LinalgMatmulToMatmulTransposeAPassBase<
+      LinalgMatmulToMatmulTransposeAPass>::
+      LinalgMatmulToMatmulTransposeAPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateMatmulToMatmulTransposeAPattern(patterns);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
new file mode 100644
index 00000000000000..c3c8dff98ba3c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @static(
+// CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME:                      %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           return %[[C]] : tensor<16x16xf32>
+// CHECK:         }
+func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<16x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @dynamic(
+// CHECK-SAME:                       %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                       %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           return %[[C]] : tensor<?x?xf32>
+// CHECK:         }
+func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @mixed(
+// CHECK-SAME:                     %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME:                     %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// CHECK:           %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           return %[[B0]] : tensor<?x16xf32>
+// CHECK:         }
+func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
+  %init = tensor.empty(%d0) : tensor<?x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
+  return %0 : tensor<?x16xf32>
+}

>From 2009cd03a86eaf5bf3f5d6e6c663282ca66e59dd Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 17 Apr 2024 14:20:34 +0000
Subject: [PATCH 2/3] run clang-format

---
 .../Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
index 45551cd9167b60..73a812ce4550b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -42,8 +42,8 @@ struct MatmulToMatmulTransposeA final
     Value a = matmulOp.getInputs()[0];
     auto aType = cast<ShapedType>(a.getType());
     if (aType.getRank() != 2)
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only 2-D matmul ops are supported");
+      return rewriter.notifyMatchFailure(matmulOp,
+                                         "only 2-D matmul ops are supported");
 
     Location loc = matmulOp.getLoc();
 

>From e407d6e017edea7e9ffc327fe8d86d9688e620d2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 18 Apr 2024 13:45:55 +0000
Subject: [PATCH 3/3] address review comments (rename to -linalg-transpose-matmul, add transpose-a option)

---
 mlir/include/mlir/Dialect/Linalg/Passes.td    |  13 +--
 .../Dialect/Linalg/Transforms/Transforms.h    |   6 +-
 .../Dialect/Linalg/Transforms/CMakeLists.txt  |   2 +-
 .../Transforms/MatmulToMatmulTransposeA.cpp   |  92 ----------------
 .../Linalg/Transforms/TransposeMatmul.cpp     | 102 ++++++++++++++++++
 ...transpose-a.mlir => transpose-matmul.mlir} |  33 ++++--
 6 files changed, 136 insertions(+), 112 deletions(-)
 delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
 create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
 rename mlir/test/Dialect/Linalg/{matmul-to-matmul-transpose-a.mlir => transpose-matmul.mlir} (65%)

diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 38be6e49f574c9..bd6580f7feef1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,13 +141,14 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
-def LinalgMatmulToMatmulTransposeAPass
-    : Pass<"linalg-matmul-to-matmul-transpose-a"> {
-  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+def LinalgTransposeMatmulPass : Pass<"linalg-transpose-matmul"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a` "
+                "(default) or `linalg.matmul_transpose_b`";
   let dependentDialects = ["linalg::LinalgDialect"];
-  let description = [{
-    Transposes the A matrix of a `linalg.matmul` for contiguous access.
-  }];
+  let options = [
+    Option<"transposeA", "transpose-a", "bool", /*default=*/"true",
+           "If true transpose A (default), otherwise transpose B.">
+  ];
 }
 
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0d354c666b1742..ac43029bb7e0b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,8 +1616,10 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
-/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
-void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+/// Pattern to convert `linalg.matmul` to `linalg.matmul_transpose_a` (default)
+/// or `linalg.matmul_transpose_b`.
+void populateTransposeMatmulPattern(RewritePatternSet &patterns,
+                                    bool transposeA = true);
 
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bca4954f959da3..ee6e391d0cc682 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,7 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
-  MatmulToMatmulTransposeA.cpp
+  TransposeMatmul.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
deleted file mode 100644
index 73a812ce4550b1..00000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
+++ /dev/null
@@ -1,92 +0,0 @@
-//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
-// with the aim of the memory accesses becoming contiguous.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
-#include "mlir/Dialect/Linalg/Passes.h.inc"
-} // namespace mlir
-
-#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-namespace {
-/// Pattern to replace `linalg.matmul(a, b)` with
-/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
-struct MatmulToMatmulTransposeA final
-    : public OpRewritePattern<linalg::MatmulOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(matmulOp))
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only matmul ops with tensors are supported");
-
-    Value a = matmulOp.getInputs()[0];
-    auto aType = cast<ShapedType>(a.getType());
-    if (aType.getRank() != 2)
-      return rewriter.notifyMatchFailure(matmulOp,
-                                         "only 2-D matmul ops are supported");
-
-    Location loc = matmulOp.getLoc();
-
-    SmallVector<Value> dynamicDims;
-    if (aType.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
-    if (aType.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
-
-    auto aShape = aType.getShape();
-    SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, transposedShape, aType.getElementType(), dynamicDims);
-    static constexpr std::array<int64_t, 2> perm = {1, 0};
-    auto transposeAOp =
-        rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
-    rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
-        matmulOp, matmulOp.getResultTypes(),
-        ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
-        matmulOp.getOutputs());
-
-    return success();
-  }
-};
-} // namespace
-
-void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
-    RewritePatternSet &patterns) {
-  patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
-}
-
-namespace {
-struct LinalgMatmulToMatmulTransposeAPass
-    : public impl::LinalgMatmulToMatmulTransposeAPassBase<
-          LinalgMatmulToMatmulTransposeAPass> {
-  using impl::LinalgMatmulToMatmulTransposeAPassBase<
-      LinalgMatmulToMatmulTransposeAPass>::
-      LinalgMatmulToMatmulTransposeAPassBase;
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    populateMatmulToMatmulTransposeAPattern(patterns);
-    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
-  }
-};
-} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
new file mode 100644
index 00000000000000..bc93df5bf7033a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -0,0 +1,102 @@
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed matmul ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSEMATMULPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-transpose-matmul"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with, when `transposeA` is true,
+///
+///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///
+/// or, when `transposeA` is false,
+///
+///   linalg.matmul_transpose_b(a, linalg.transpose(b))
+///
+/// Only matches ops with tensor semantics whose transposed operand is 2-D; the
+/// transpose writes into a new `tensor.empty` with the swapped shape.
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+  TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
+      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value input = matmulOp.getInputs()[transposeA ? 0 : 1];
+    auto type = cast<ShapedType>(input.getType());
+    if (type.getRank() != 2)
+      return rewriter.notifyMatchFailure(matmulOp,
+                                         "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    SmallVector<Value> dynamicDims; // Ordered like `transposedShape`.
+    if (type.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+    if (type.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    auto shape = type.getShape();
+    SmallVector<int64_t> transposedShape{shape[1], shape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, type.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeOp =
+        rewriter.create<linalg::TransposeOp>(loc, input, empty, perm);
+    if (transposeA)
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+          matmulOp.getOutputs());
+    else
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+          matmulOp.getOutputs());
+
+    return success();
+  }
+
+private:
+  bool transposeA; // If true transpose input 0 (A), otherwise input 1 (B).
+};
+} // namespace
+
+void mlir::linalg::populateTransposeMatmulPattern(RewritePatternSet &patterns,
+                                                  bool transposeA) {
+  patterns.add<TransposeMatmul>(patterns.getContext(), transposeA);
+}
+
+namespace {
+struct LinalgTransposeMatmulPass
+    : public impl::LinalgTransposeMatmulPassBase<LinalgTransposeMatmulPass> {
+  using impl::LinalgTransposeMatmulPassBase<
+      LinalgTransposeMatmulPass>::LinalgTransposeMatmulPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateTransposeMatmulPattern(patterns, transposeA);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
similarity index 65%
rename from mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
rename to mlir/test/Dialect/Linalg/transpose-matmul.mlir
index c3c8dff98ba3c9..470453c8ff78c1 100644
--- a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-transpose-matmul -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
+// RUN: mlir-opt -linalg-transpose-matmul=transpose-a=false -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
 
 // CHECK-LABEL:   func.func @static(
 // CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
@@ -6,9 +7,12 @@
 // CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
-// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
-// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
-// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
+// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
 // CHECK:           return %[[C]] : tensor<16x16xf32>
 // CHECK:         }
 func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
@@ -31,10 +35,14 @@ func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf
 // CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
 // CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
-// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
-// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-A:     %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// TRANSPOSE-A:     %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-B:     %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// TRANSPOSE-B:     %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK:           return %[[C]] : tensor<?x?xf32>
 // CHECK:         }
 func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
@@ -59,9 +67,12 @@ func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>
 // CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
 // CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
 // CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
-// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
-// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
-// CHECK:           %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-A:     %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// TRANSPOSE-A:     %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// TRANSPOSE-A:     %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-B:     %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
+// TRANSPOSE-B:     %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
+// TRANSPOSE-B:     %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
 // CHECK:           return %[[B0]] : tensor<?x16xf32>
 // CHECK:         }
 func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {



More information about the Mlir-commits mailing list