[Mlir-commits] [mlir] [mlir][linalg] Add pass to transpose A matrix of matmul op (PR #89075)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 17 07:16:17 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Cullen Rhodes (c-rhodes)

<details>
<summary>Changes</summary>

This patch introduces a pass `-linalg-matmul-to-matmul-transpose-a`, which transposes the A matrix of a Linalg matmul operation with the aim of making memory accesses contiguous.

Our work enabling a lowering path from `linalg.matmul` to ArmSME has revealed that the current lowering results in non-contiguous memory accesses for the A matrix, and consequently very poor performance.

This pass provides a simple option to fix this.

---
Full diff: https://github.com/llvm/llvm-project/pull/89075.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+9) 
- (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+3) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp (+92) 
- (added) mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir (+76) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 85f11c66d29a73..38be6e49f574c9 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -141,4 +141,13 @@ def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInter
   ];
 }
 
+def LinalgMatmulToMatmulTransposeAPass
+    : Pass<"linalg-matmul-to-matmul-transpose-a"> {
+  let summary = "Converts `linalg.matmul` to `linalg.matmul_transpose_a`.";
+  let dependentDialects = ["linalg::LinalgDialect"];
+  let description = [{
+    Transposes the A matrix of a `linalg.matmul` for contiguous access.
+  }];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index feb3b3f03cf538..0d354c666b1742 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1616,6 +1616,9 @@ void populateSplitReductionPattern(
     const ControlSplitReductionFn &controlSplitReductionFn,
     bool useAlloc = false);
 
+/// Pattern to replace `linalg.matmul` with `linalg.matmul_transpose_a`.
+void populateMatmulToMatmulTransposeAPattern(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 513c54de5d7bfc..bca4954f959da3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MatmulToMatmulTransposeA.cpp
   MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
new file mode 100644
index 00000000000000..45551cd9167b60
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MatmulToMatmulTransposeA.cpp
@@ -0,0 +1,92 @@
+//===- MatmulToMatmulTransposeA.cpp - Linalg matmul to matmul_transpose_a -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This rewrite and pass transposes the A matrix of a `linalg.matmul` operation
+// with the aim of the memory accesses becoming contiguous.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGMATMULTOMATMULTRANSPOSEAPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-matmul-to-matmul-transpose-a"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+/// Pattern to replace `linalg.matmul(a, b)` with
+/// `linalg.matmul_transpose_a(linalg.transpose(a), b)`.
+struct MatmulToMatmulTransposeA final
+    : public OpRewritePattern<linalg::MatmulOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
+                                PatternRewriter &rewriter) const override {
+    if (!bufferization::hasTensorSemantics(matmulOp))
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only matmul ops with tensors are supported");
+
+    Value a = matmulOp.getInputs()[0];
+    auto aType = cast<ShapedType>(a.getType());
+    if (aType.getRank() != 2)
+      return rewriter.notifyMatchFailure(
+          matmulOp, "only 2-D matmul ops are supported");
+
+    Location loc = matmulOp.getLoc();
+
+    SmallVector<Value> dynamicDims;
+    if (aType.isDynamicDim(1))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 1));
+    if (aType.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, a, 0));
+
+    auto aShape = aType.getShape();
+    SmallVector<int64_t> transposedShape{aShape[1], aShape[0]};
+    Value empty = rewriter.create<tensor::EmptyOp>(
+        loc, transposedShape, aType.getElementType(), dynamicDims);
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    auto transposeAOp =
+        rewriter.create<linalg::TransposeOp>(loc, a, empty, perm);
+    rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
+        matmulOp, matmulOp.getResultTypes(),
+        ValueRange{transposeAOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateMatmulToMatmulTransposeAPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<MatmulToMatmulTransposeA>(patterns.getContext());
+}
+
+namespace {
+struct LinalgMatmulToMatmulTransposeAPass
+    : public impl::LinalgMatmulToMatmulTransposeAPassBase<
+          LinalgMatmulToMatmulTransposeAPass> {
+  using impl::LinalgMatmulToMatmulTransposeAPassBase<
+      LinalgMatmulToMatmulTransposeAPass>::
+      LinalgMatmulToMatmulTransposeAPassBase;
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateMatmulToMatmulTransposeAPattern(patterns);
+    (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
+  }
+};
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
new file mode 100644
index 00000000000000..c3c8dff98ba3c9
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/matmul-to-matmul-transpose-a.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt -linalg-matmul-to-matmul-transpose-a -cse -canonicalize -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL:   func.func @static(
+// CHECK-SAME:                      %[[A:.*]]: tensor<16x8xf32>,
+// CHECK-SAME:                      %[[B:.*]]: tensor<8x16xf32>) -> tensor<16x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C_INIT:.*]] = tensor.empty() : tensor<16x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty() : tensor<8x16xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<16x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x16xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x16xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<16x16xf32>) -> tensor<16x16xf32>
+// CHECK:           return %[[C]] : tensor<16x16xf32>
+// CHECK:         }
+func.func @static(%A: tensor<16x8xf32>, %B: tensor<8x16xf32>) -> (tensor<16x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %init = tensor.empty() : tensor<16x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x16xf32>) outs(%C : tensor<16x16xf32>) -> tensor<16x16xf32>
+  return %0 : tensor<16x16xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @dynamic(
+// CHECK-SAME:                       %[[A:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                       %[[B:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
+// CHECK:           %[[B_DIM1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]], %[[B_DIM1]]) : tensor<?x?xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           %[[A_DIM1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM1]], %[[A_DIM0]]) : tensor<?x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x?xf32>) outs(%[[A_TRANSP_INIT]] : tensor<?x?xf32>) permutation = [1, 0]
+// CHECK:           %[[C:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C_ZERO]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK:           return %[[C]] : tensor<?x?xf32>
+// CHECK:         }
+func.func @dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %B, %c1 : tensor<?x?xf32>
+  %init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @mixed(
+// CHECK-SAME:                     %[[A:.*]]: tensor<?x8xf32>,
+// CHECK-SAME:                     %[[B:.*]]: tensor<8x16xf32>) -> tensor<?x16xf32> {
+// CHECK:           %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[A_DIM0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x8xf32>
+// CHECK:           %[[C_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<?x16xf32>
+// CHECK:           %[[C_ZERO:.*]] = linalg.fill ins(%[[C0_F32]] : f32) outs(%[[C_INIT]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           %[[A_TRANSP_INIT:.*]] = tensor.empty(%[[A_DIM0]]) : tensor<8x?xf32>
+// CHECK:           %[[A_TRANSP:.*]] = linalg.transpose ins(%[[A]] : tensor<?x8xf32>) outs(%[[A_TRANSP_INIT]] : tensor<8x?xf32>) permutation = [1, 0]
+// CHECK:           %[[B0:.*]] = linalg.matmul_transpose_a ins(%[[A_TRANSP]], %[[B]] : tensor<8x?xf32>, tensor<8x16xf32>) outs(%[[C_ZERO]] : tensor<?x16xf32>) -> tensor<?x16xf32>
+// CHECK:           return %[[B0]] : tensor<?x16xf32>
+// CHECK:         }
+func.func @mixed(%A: tensor<?x8xf32>, %B: tensor<8x16xf32>) -> (tensor<?x16xf32>) {
+  %cst = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %A, %c0 : tensor<?x8xf32>
+  %init = tensor.empty(%d0) : tensor<?x16xf32>
+  %C = linalg.fill ins(%cst : f32) outs(%init : tensor<?x16xf32>) -> tensor<?x16xf32>
+  %0 = linalg.matmul ins(%A, %B : tensor<?x8xf32>, tensor<8x16xf32>) outs(%C : tensor<?x16xf32>) -> tensor<?x16xf32>
+  return %0 : tensor<?x16xf32>
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/89075


More information about the Mlir-commits mailing list