[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose A matrix of matmul op (PR #89075)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Apr 17 07:16:17 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-linalg
Author: Cullen Rhodes (c-rhodes)
<details>
<summary>Changes</summary>
This patch introduces a pass `-linalg-matmul-to-matmul-transpose-a`, which transposes the A matrix of a Linalg matmul operation with the aim of making memory accesses contiguous.
Our work enabling a lowering path from `linalg.matmul` to ArmSME has revealed that the current lowering results in non-contiguous memory accesses for the A matrix, and hence very poor performance.
This pass provides a simple option to fix this.
---
Full diff: https://github.com/llvm/llvm-project/pull/89075.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+9)
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+3)
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
- (added) mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp (+92)
- (added) mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir (+76)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..38be6e49f574c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,13 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
];
}
+def LinalgMatmulToMatmulTransposeAPass
+ : Pass<"linalg-matmul-to-matmul-transpose-a"> {
+ let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+ let dependentDialects = ["linalg::LinalgDialect"];
+ let description = [{
+ Transposes the A matrix of a `linalg.matmul` for contiguous access.
+ }];
+}
+
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..0d354c666b1742 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,9 @@ void populateSplitReductionPattern(
const ControlSplitReductionFn &controlSplitReductionFn,
bool useAlloc = false);
+/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
+void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+
} // namespace linalg
} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..bca4954f959da3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MatmulToMatmulTransposeA.cpp
MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
new file mode 100644
index 00000000000000..45551cd9167b60
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -0,0 +1,92 @@
+//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
+// with the aim of the memory accesses becoming contiguous.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with
+/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
+struct MatmulToMatmulTransposeA final
+ : public OpRewritePattern<linalg::MatmulOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+ PatternRewriter &rewriter) const override {
+ if (!bufferization::hasTensorSemantics(matmulOp))
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only matmul ops with tensors are supported");
+
+ Value a = matmulOp.getInputs()[0];
+ auto aType = cast<ShapedType>(a.getType());
+ if (aType.getRank() != 2)
+ return rewriter.notifyMatchFailure(
+ matmulOp, "only 2-D matmul ops are supported");
+
+ Location loc = matmulOp.getLoc();
+
+ SmallVector<Value> dynamicDims;
+ if (aType.isDynamicDim(1))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
+ if (aType.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
+
+ auto aShape = aType.getShape();
+ SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
+ Value empty = rewriter.create<tensor::EmptyOp>(
+ loc, transposedShape, aType.getElementType(), dynamicDims);
+ static constexpr std::array<int64_t, 2> perm = {1, 0};
+ auto transposeAOp =
+ rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
+ rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+ matmulOp, matmulOp.getResultTypes(),
+ ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
+ matmulOp.getOutputs());
+
+ return success();
+ }
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
+}
+
+namespace {
+struct LinalgMatmulToMatmulTransposeAPass
+ : public impl::LinalgMatmulToMatmulTransposeAPassBase<
+ LinalgMatmulToMatmulTransposeAPass> {
+ using impl::LinalgMatmulToMatmulTransposeAPassBase<
+ LinalgMatmulToMatmulTransposeAPass>::
+ LinalgMatmulToMatmulTransposeAPassBase;
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ RewritePatternSet patterns(op->getContext());
+ populateMatmulToMatmulTransposeAPattern(patterns);
+ (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+ }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
new file mode 100644
index 00000000000000..c3c8dff98ba3c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: func.func @static(
+// CHECK-SAME: %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
+// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// CHECK: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK: return %[[C]] : tensor<16x16xf32>
+// CHECK: }
+func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %init = tensor.empty() : tensor<16x16xf32>
+ %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
+ return %0 : tensor<16x16xf32>
+}
+
+//-----
+
+// CHECK-LABEL: func.func @dynamic(
+// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
+// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// CHECK: %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: return %[[C]] : tensor<?x?xf32>
+// CHECK: }
+func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+ %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+ %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+ %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+//-----
+
+// CHECK-LABEL: func.func @mixed(
+// CHECK-SAME: %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME: %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
+// CHECK: %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
+// CHECK: %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK: %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// CHECK: %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// CHECK: %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK: return %[[B0]] : tensor<?x16xf32>
+// CHECK: }
+func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+ %cst = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
+ %init = tensor.empty(%d0) : tensor<?x16xf32>
+ %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
+ %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
+ return %0 : tensor<?x16xf32>
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/89075
More information about the Mlir-commits
mailing list