[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose matmul op (PR #89075)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Apr 19 09:33:02 PDT 2024
================
@@ -0,0 +1,161 @@
+//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// This is intended to be a simple high-level (target-agnostic) matmul
+// transposition transformation.
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGTRANSPOSEMATMULPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-transpose-matmul"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace
+///
+/// linalg.matmul(a, b)
+///
+/// with
+///
+/// linalg.matmul_transpose_a(linalg.transpose(a), b)
+///
+/// By default A is transposed. If `transposeA` is set to false then B is
+/// transposed.
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+ TransposeMatmul(MLIRContext *ctx, bool transposeA, PatternBenefit benefit = 1)
+ : OpRewritePattern(ctx, benefit), transposeA(transposeA) {}
+
+ LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+ PatternRewriter &rewriter) const override {
+ if (!bufferization::hasTensorSemantics(matmulOp))
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only matmul ops with tensors are supported");
+
+ Location loc = matmulOp.getLoc();
+ Value input = matmulOp.getInputs()[transposeA ? 0 : 1];
+ auto type = cast<ShapedType>(input.getType());
+
+ SmallVector<Value> dynamicDims;
+ if (type.isDynamicDim(1))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+ if (type.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+ auto shape = type.getShape();
+ SmallVector<int64_t> transposedShape{shape[1], shape[0]};
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transposedShape, type.getElementType(), dynamicDims);
----------------
MacDue wrote:
Nit: This does not need to be a vector:
```suggestion
std::array transposedShape{shape[1], shape[0]};
Value empty = rewriter.create<tensor::EmptyOp>(
loc, transposedShape, type.getElementType(), dynamicDims);
```
or
```suggestion
    Value empty = rewriter.create<tensor::EmptyOp>(
        loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(), dynamicDims);
```
Either suggestion should work.
https://github.com/llvm/llvm-project/pull/89075
More information about the Mlir-commits
mailing list