[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose matmul op (PR #89075)
Cullen Rhodes
llvmlistbot at llvm.org
Fri Apr 19 00:29:14 PDT 2024
https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/89075
>From 3bd53f154a2fcab7f952d190923b41430830907c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 16 Apr 2024 09:16:33 +0000
Subject: [PATCH 1/4] [mlir][linalg] Add pass to transpose A matrix of matmul
op
This patch introduces a pass `-linalg-matmul-to-matmul-transpose-a`,
which transposes the A matrix of a Linalg matmul operation with the aim
of making memory accesses contiguous.
Our work on enabling a lowering path from `linalg.matmul` to ArmSME has
revealed that the current lowering results in non-contiguous memory
accesses for the A matrix, and consequently very poor performance.
This pass provides a simple option to fix this.
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 9 ++
.../Dialect/Linalg/Transforms/Transforms.h | 3 +
.../Dialect/Linalg/Transforms/CMakeLists.txt | 1 +
.../Transforms/MatmulToMatmulTransposeA.cpp | 92 +++++++++++++++++++
.../Linalg/matmul-to-matmul-transpose-a.mlir | 76 +++++++++++++++
5 files changed, 181 insertions(+)
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
create mode 100644 mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..38be6e49f574c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,13 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}
+def LinalgMatmulToMatmulTransposeAPass
+ : Pass<"linalg-matmul-to-matmul-transpose-a"> {
+ let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let description = [{
+ Transposes the A matrix of a `linalg.matmul` for contiguous access.
+ }];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..0d354c666b1742 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,9 @@ void populateSplitReductionPattern(
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);
+/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
+void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..bca4954f959da3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MatmulToMatmulTransposeA.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
new file mode 100644
index 00000000000000..45551cd9167b60
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -0,0 +1,92 @@
+//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
+// with the aim of the memory accesses becoming contiguous.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with
+/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
+struct MatmulToMatmulTransposeA final
+ : public OpRewritePattern<linalg::MatmulOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+ PatternRewriter &rewriter) const override {
+ if (!bufferization::hasTensorSemantics(matmulOp))
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only matmul ops with tensors are supported");
+
+ Value a = matmulOp.getInputs()[0];
+ auto aType = cast<ShapedType>(a.getType());
+ if (aType.getRank() != 2)
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only 2-D matmul ops are supported");
+
+ Location loc = matmulOp.getLoc();
+
+ SmallVector<Value> dynamicDims;
+ if (aType.isDynamicDim(1))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
+ if (aType.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
+
+ auto aShape = aType.getShape();
+ SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transposedShape, aType.getElementType(), dynamicDims);
+ static constexpr std::array<int64_t, 2> perm = {1, 0};
+ auto transposeAOp =
+ rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
+ rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+ matmulOp, matmulOp.getResultTypes(),
+ ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
+ matmulOp.getOutputs());
+
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
+}
+
+namespace {
+struct LinalgMatmulToMatmulTransposeAPass
+ : public impl::LinalgMatmulToMatmulTransposeAPassBase<
+ LinalgMatmulToMatmulTransposeAPass> {
+ using impl::LinalgMatmulToMatmulTransposeAPassBase<
+ LinalgMatmulToMatmulTransposeAPass>::
+ LinalgMatmulToMatmulTransposeAPassBase;
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ populateMatmulToMatmulTransposeAPattern(patterns);
+ (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+ }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
new file mode 100644
index 00000000000000..c3c8dff98ba3c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @static(
+// CHECK-SAME: %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
+// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// CHECK: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK: return %[[C]] : tensor<16x16xf32>
+// CHECK: }
+func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %init = tensor.empty() : tensor<16x16xf32>
+ %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
+// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// CHECK: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: return %[[C]] : tensor<?x?xf32>
+// CHECK: }
+func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+ %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+ %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @mixed(
+// CHECK-SAME: %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
+// CHECK: %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
+// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// CHECK: %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK: return %[[B0]] : tensor<?x16xf32>
+// CHECK: }
+func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
+ %init = tensor.empty(%d0) : tensor<?x16xf32>
+ %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
+ return %0 : tensor<?x16xf32>
+}
>From 2009cd03a86eaf5bf3f5d6e6c663282ca66e59dd Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Wed, 17 Apr 2024 14:20:34 +0000
Subject: [PATCH 2/4] run clang-format
---
.../Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
index 45551cd9167b60..73a812ce4550b1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -42,8 +42,8 @@ struct MatmulToMatmulTransposeA final
Value a = matmulOp.getInputs()[0];
auto aType = cast<ShapedType>(a.getType());
if (aType.getRank() != 2)
- return rewriter.notifyMatchFailure(
- matmulOp, "only 2-D matmul ops are supported");
+ return rewriter.notifyMatchFailure(matmulOp,
+ "only 2-D matmul ops are supported");
Location loc = matmulOp.getLoc();
>From e407d6e017edea7e9ffc327fe8d86d9688e620d2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Thu, 18 Apr 2024 13:45:55 +0000
Subject: [PATCH 3/4] address comments
---
mlir/include/mlir/Dialect/Linalg/Passes.td | 13 +--
.../Dialect/Linalg/Transforms/Transforms.h | 6 +-
.../Dialect/Linalg/Transforms/CMakeLists.txt | 2 +-
.../Transforms/MatmulToMatmulTransposeA.cpp | 92 ----------------
.../Linalg/Transforms/TransposeMatmul.cpp | 102 ++++++++++++++++++
...transpose-a.mlir => transpose-matmul.mlir} | 33 ++++--
6 files changed, 136 insertions(+), 112 deletions(-)
delete mode 100644 mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
create mode 100644 mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
rename mlir/test/Dialect/Linalg/{matmul-to-matmul-transpose-a.mlir => transpose-matmul.mlir} (65%)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 38be6e49f574c9..bd6580f7feef1d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,13 +141,19 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}
-def LinalgMatmulToMatmulTransposeAPass
- : Pass<"linalg-matmul-to-matmul-transpose-a"> {
- let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+def LinalgTransposeMatmulPass : Pass<"linalg-transpose-matmul"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a` "
+                "(default) or `linalg.matmul_transpose_b`.";
let dependentDialects = ["linalg::LinalgDialect"];
- let description = [{
- Transposes the A matrix of a `linalg.matmul` for contiguous access.
- }];
+  let description = [{
+    Transposes the A (default) or B matrix of a tensor-semantics
+    `linalg.matmul` with a `linalg.transpose` and rewrites the op to
+    `linalg.matmul_transpose_a` or `linalg.matmul_transpose_b` respectively,
+    with the aim of making memory accesses contiguous.
+  }];
+  let options = [
+    Option<"transposeA", "transpose-a", "bool", /*default=*/"true",
+           "If true transpose A (default), otherwise transpose B.">
+  ];
}
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0d354c666b1742..ac43029bb7e0b9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,8 +1616,10 @@ void populateSplitReductionPattern(
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);
-/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
-void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+/// Pattern to convert `linalg.matmul` to `linalg.matmul_transpose_a` when
+/// `transposeA` is true (the default), otherwise to `linalg.matmul_transpose_b`.
+void populateTransposeMatmulPattern(RewritePatternSet &patterns,
+ bool transposeA = true);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index bca4954f959da3..ee6e391d0cc682 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,7 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
- MatmulToMatmulTransposeA.cpp
+ TransposeMatmul.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
deleted file mode 100644
index 73a812ce4550b1..00000000000000
--- a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
+++ /dev/null
@@ -1,92 +0,0 @@
-//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
-// with the aim of the memory accesses becoming contiguous.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-namespace mlir {
-#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
-#include "mlir/Dialect/Linalg/Passes.h.inc"
-} // namespace mlir
-
-#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
-
-using namespace mlir;
-using namespace mlir::linalg;
-
-namespace {
-/// Pattern to replace `linalg.matmul(a, b)` with
-/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
-struct MatmulToMatmulTransposeA final
- : public OpRewritePattern<linalg::MatmulOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
- PatternRewriter &rewriter) const override {
- if (!bufferization::hasTensorSemantics(matmulOp))
- return rewriter.notifyMatchFailure(
- matmulOp, "only matmul ops with tensors are supported");
-
- Value a = matmulOp.getInputs()[0];
- auto aType = cast<ShapedType>(a.getType());
- if (aType.getRank() != 2)
- return rewriter.notifyMatchFailure(matmulOp,
- "only 2-D matmul ops are supported");
-
- Location loc = matmulOp.getLoc();
-
- SmallVector<Value> dynamicDims;
- if (aType.isDynamicDim(1))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
- if (aType.isDynamicDim(0))
- dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
-
- auto aShape = aType.getShape();
- SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
- Value empty = rewriter.create<tensor::EmptyOp>(
- loc, transposedShape, aType.getElementType(), dynamicDims);
- static constexpr std::array<int64_t, 2> perm = {1, 0};
- auto transposeAOp =
- rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
- rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
- matmulOp, matmulOp.getResultTypes(),
- ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
- matmulOp.getOutputs());
-
- return success();
- }
-};
-} // namespace
-
-void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
- RewritePatternSet &patterns) {
- patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
-}
-
-namespace {
-struct LinalgMatmulToMatmulTransposeAPass
- : public impl::LinalgMatmulToMatmulTransposeAPassBase<
- LinalgMatmulToMatmulTransposeAPass> {
- using impl::LinalgMatmulToMatmulTransposeAPassBase<
- LinalgMatmulToMatmulTransposeAPass>::
- LinalgMatmulToMatmulTransposeAPassBase;
- void runOnOperation() override {
- Operation *op = getOperation();
- RewritePatternSet patterns(op->getContext());
- populateMatmulToMatmulTransposeAPattern(patterns);
- (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
- }
-};
-} // namespace
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
new file mode 100644
index 00000000000000..bc93df5bf7033a
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -0,0 +1,117 @@
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed matmul ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSEMATMULPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-transpose-matmul"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace
+///
+///   linalg.matmul(a, b)
+///
+/// with
+///
+///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///
+/// when `transposeA` is true (the default), or with
+///
+///   linalg.matmul_transpose_b(a, linalg.transpose(b))
+///
+/// when it is false. An explicit `cast` type function on the original op is
+/// carried over. Note: only 2-D matmuls with tensor semantics are supported.
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+  TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
+      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value input = matmulOp.getInputs()[transposeA ? 0 : 1];
+    auto type = cast<ShapedType>(input.getType());
+    if (type.getRank() != 2)
+      return rewriter.notifyMatchFailure(matmulOp,
+                                         "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    // The init tensor has the transposed shape [d1, d0], so its dynamic sizes
+    // must be listed in that order.
+    SmallVector<Value> dynamicDims;
+    if (type.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+    if (type.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    auto shape = type.getShape();
+    SmallVector<int64_t> transposedShape{shape[1], shape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, type.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeOp =
+        rewriter.create<linalg::TransposeOp>(loc, input, empty, perm);
+
+    // Carry over an explicit `cast` type function (e.g. `cast_unsigned`);
+    // otherwise the new op would silently use its default and could compute
+    // a different result than the original matmul.
+    SmallVector<NamedAttribute> attrs;
+    if (Attribute castAttr = matmulOp->getAttr("cast"))
+      attrs.push_back(rewriter.getNamedAttr("cast", castAttr));
+
+    if (transposeA)
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+          matmulOp.getOutputs(), attrs);
+    else
+      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
+          matmulOp, matmulOp.getResultTypes(),
+          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+          matmulOp.getOutputs(), attrs);
+
+    return success();
+  }
+
+private:
+  /// If true transpose A, otherwise transpose B.
+  bool transposeA;
+};
+} // namespace
+
+void mlir::linalg::populateTransposeMatmulPattern(RewritePatternSet &patterns,
+                                                  bool transposeA) {
+  patterns.add<TransposeMatmul>(patterns.getContext(), transposeA);
+}
+
+namespace {
+struct LinalgTransposeMatmulPass
+    : public impl::LinalgTransposeMatmulPassBase<LinalgTransposeMatmulPass> {
+  using impl::LinalgTransposeMatmulPassBase<
+      LinalgTransposeMatmulPass>::LinalgTransposeMatmulPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateTransposeMatmulPattern(patterns, transposeA);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
similarity index 65%
rename from mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
rename to mlir/test/Dialect/Linalg/transpose-matmul.mlir
index c3c8dff98ba3c9..470453c8ff78c1 100644
--- a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
+++ b/mlir/test/Dialect/Linalg/transpose-matmul.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -linalg-transpose-matmul -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-A
+// RUN: mlir-opt -linalg-transpose-matmul=transpose-a=false -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefixes=CHECK,TRANSPOSE-B
// CHECK-LABEL: func.func @static(
// CHECK-SAME: %[[A:.*]]: tensor<16x8xf32>,
@@ -6,9 +7,12 @@
// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
-// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
-// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
-// CHECK: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// TRANSPOSE-A: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
+// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
+// TRANSPOSE-B: %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<16x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
// CHECK: return %[[C]] : tensor<16x16xf32>
// CHECK: }
func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
@@ -31,10 +35,14 @@ func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf
// CHECK: %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
-// CHECK: %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
-// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
-// CHECK: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-A: %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// TRANSPOSE-A: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// TRANSPOSE-B: %[[B_DIM0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
+// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty(%[[B_DIM1]], %[[B_DIM0]]) : tensor<?x?xf32>
+// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<?x?xf32>) outs(%[[B_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// TRANSPOSE-B: %[[C:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: return %[[C]] : tensor<?x?xf32>
// CHECK: }
func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
@@ -59,9 +67,12 @@ func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>
// CHECK: %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
// CHECK: %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
-// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
-// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
-// CHECK: %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-A: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// TRANSPOSE-A: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// TRANSPOSE-A: %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// TRANSPOSE-B: %[[B_TRANSP_INIT:.*]] = tensor.empty() : tensor<16x8xf32>
+// TRANSPOSE-B: %[[B_TRANSP:.*]] = linalg.transpose ins(%[[B]] : tensor<8x16xf32>) outs(%[[B_TRANSP_INIT]] : tensor<16x8xf32>) permutation = [1, 0]
+// TRANSPOSE-B: %[[B0:.*]] = linalg.matmul_transpose_b ins(%[[A]], %[[B_TRANSP]] : tensor<?x8xf32>, tensor<16x8xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
// CHECK: return %[[B0]] : tensor<?x16xf32>
// CHECK: }
func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
>From 329f94596ffe938c77065ba2f2f299507918c151 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Fri, 19 Apr 2024 07:27:34 +0000
Subject: [PATCH 4/4] clarify expectations in pass
---
mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index bc93df5bf7033a..2d6c9cd59ab298 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -5,6 +5,12 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
+//
+// This file converts `linalg.matmul` operations to the transposed
+// `linalg.matmul_transpose_a` or `linalg.matmul_transpose_b` variants. It is
+// intended to be a simple high-level (target-agnostic) matmul transposition.
+//
+//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
More information about the Mlir-commits
mailing list