[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose matmul op (PR #89075)

Diego Caballero llvmlistbot at llvm.org
Fri Apr 19 02:02:45 PDT 2024


================
@@ -0,0 +1,102 @@
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed matmul ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSEMATMULPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-transpose-matmul"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace
+///
+///   linalg.matmul(a, b)
+///
+/// with
+///
+///   linalg.matmul_transpose_a(linalg.transpose(a), b)
+///
+/// By default A is transposed. If `transposeA` is set to false then B is
+/// transposed. Note: only 2-D matrices are currently supported.
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+  TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
+      : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
----------------
dcaballe wrote:

Good point. We've been leaning towards vecmat as a canonical form, but I guess there's no reason to enforce that. Let's skip it for now. The batch matmul variant would be helpful, though; we can add it separately.

https://github.com/llvm/llvm-project/pull/89075
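The rewrite in the pattern's doc comment works because `linalg.matmul_transpose_a` reads its first operand as a K×M matrix. Feeding it `linalg.transpose(a)` therefore reproduces `linalg.matmul(a, b)`. The `transposeA = false` path does the same on the B side with `linalg.matmul_transpose_b`. Below is a minimal C++ sketch of that equivalence on plain row-major buffers. It assumes zero-initialized outputs, whereas the real ops accumulate into their `outs` operand. `Matrix`, `matmul`, `transpose`, `matmulTransposeA` and `matmulTransposeB` are hypothetical reference helpers, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical row-major dense matrix used only to illustrate the semantics.
struct Matrix {
  std::size_t rows, cols;
  std::vector<double> data;
  Matrix(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double &at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  double at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Reference semantics of linalg.matmul with a zero init:
// C[i][j] = sum_k A[i][k] * B[k][j], with A: MxK, B: KxN.
Matrix matmul(const Matrix &a, const Matrix &b) {
  assert(a.cols == b.rows && "inner dimensions must match");
  Matrix c(a.rows, b.cols);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < b.cols; ++j)
      for (std::size_t k = 0; k < a.cols; ++k)
        c.at(i, j) += a.at(i, k) * b.at(k, j);
  return c;
}

// Equivalent of linalg.transpose with permutation [1, 0].
Matrix transpose(const Matrix &m) {
  Matrix t(m.cols, m.rows);
  for (std::size_t i = 0; i < m.rows; ++i)
    for (std::size_t j = 0; j < m.cols; ++j)
      t.at(j, i) = m.at(i, j);
  return t;
}

// Semantics of linalg.matmul_transpose_a: the first operand is KxM,
// so C[i][j] = sum_k At[k][i] * B[k][j].
Matrix matmulTransposeA(const Matrix &at, const Matrix &b) {
  assert(at.rows == b.rows && "reduction dimensions must match");
  Matrix c(at.cols, b.cols);
  for (std::size_t i = 0; i < at.cols; ++i)
    for (std::size_t j = 0; j < b.cols; ++j)
      for (std::size_t k = 0; k < at.rows; ++k)
        c.at(i, j) += at.at(k, i) * b.at(k, j);
  return c;
}

// Semantics of linalg.matmul_transpose_b: the second operand is NxK,
// so C[i][j] = sum_k A[i][k] * Bt[j][k].
Matrix matmulTransposeB(const Matrix &a, const Matrix &bt) {
  assert(a.cols == bt.cols && "reduction dimensions must match");
  Matrix c(a.rows, bt.rows);
  for (std::size_t i = 0; i < a.rows; ++i)
    for (std::size_t j = 0; j < bt.rows; ++j)
      for (std::size_t k = 0; k < a.cols; ++k)
        c.at(i, j) += a.at(i, k) * bt.at(j, k);
  return c;
}
```

Comparing `matmul(a, b)` against `matmulTransposeA(transpose(a), b)` and `matmulTransposeB(a, transpose(b))` on small inputs should give identical results. A batch variant, as suggested above, would wrap each helper in an outer loop over the batch dimension.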

