[Mlir-commits] [mlir] [mlir][ArmSME] Lower multi-tile stores to a single loop (PR #96187)

Benjamin Maxwell llvmlistbot at llvm.org
Tue Jun 25 03:18:29 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/96187

>From 5420435e8ace78c406c03868cc79453a3e7fc11d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 20 Jun 2024 11:31:04 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Lower multi-tile stores to a single loop
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This adds a new pattern that can legalize a multi-tile transfer_write as
a single store loop. This is done as part of type decomposition because,
at this level, we know each tile write is disjoint; that information is
lost after decomposition (without analysis to reconstruct it).

Example (in pseudo-MLIR):

```
vector.transfer_write vector, dest[x, y], mask
  : vector<[16]x[4]xf32>, memref<?x?xf32>
```
Is rewritten to:
```
for i in range (0, 4 * vscale) {
  let sliceRow = i + tile_n.row * vscale;              ─┐
  let sliceCol = tile_n.col * vscale;                   |
  slice = vector.extract tile_n[i]                      |
    : vector<[4]xf32> from vector<[4]x[4]xf32>          |
  slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
    : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
  vector.transfer_write                                 |  [16]x[4]
    slice, dest[x + sliceRow, y + sliceCol], slice_mask |
    : vector<[4]xf32>, memref<?x?xf32>                  ┘
}
```
---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 134 +++++++++++++++++-
 .../Dialect/ArmSME/vector-legalization.mlir   |  69 ++++++++-
 .../Linalg/CPU/ArmSME/multi-tile-matmul.mlir  |   3 +-
 3 files changed, 196 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index b595c6dd8a684..e71e9e8969ffa 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,130 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done as
+/// part of type decomposition as at this level we know each tile write is
+/// disjoint, but that information is lost after decomposition (without
+/// static analysis).
+///
+/// Example (in pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write vector, dest[x, y], mask
+///   : vector<[16]x[4]xf32>, memref<?x?xf32>
+/// ```
+/// Is rewritten to:
+/// ```
+/// for i in range (0, 4 * vscale) {
+///   let sliceRow = i + tile_n.row * vscale;              ─┐
+///   let sliceCol = tile_n.col * vscale;                   |
+///   slice = vector.extract tile_n[i]                      |
+///     : vector<[4]xf32> from vector<[16]x[4]xf32>         |
+///   slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
+///     : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
+///   vector.transfer_write                                 |  [16]x[4]
+///     slice, dest[x + sliceRow, y + sliceCol], slice_mask |
+///     : vector<[4]xf32>, memref<?x?xf32>                  ┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          writeOp, "TODO: tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "TODO: transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // Get SME tile and slice types.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Create loop over all tile slices.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // For each tile sub-tile of the multi-tile `vectorType`.
+    auto inputSMETiles = adaptor.getVector();
+    auto inductionVar = storeLoop.getInductionVar();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+      // The coordinates of the tile within `vectorType`.
+      auto tileRow = createVscaleMultiple(smeTile.row);
+      auto tileCol = createVscaleMultiple(smeTile.col);
+
+      // The current slice of `vectorType` we are processing.
+      auto sliceIndex =
+          rewriter.create<arith::AddIOp>(loc, tileRow, inductionVar);
+
+      // Where in the destination memref the current slice will be stored.
+      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
+                                                     writeOp.getIndices()[0]);
+      auto storeCol =
+          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+
+      // Extract the mask for the current slice.
+      Value sliceMask = nullptr;
+      if (mask) {
+        sliceMask = rewriter.create<vector::ExtractOp>(
+            loc, mask, OpFoldResult(sliceIndex));
+        if (sliceMaskType != sliceMask.getType())
+          sliceMask = rewriter.create<vector::ScalableExtractOp>(
+              loc, sliceMaskType, sliceMask, smeTile.col);
+      }
+
+      // Extract and store the current slice slice.
+      Value tile = inputSMETiles[index];
+      auto slice = rewriter.create<vector::ExtractOp>(loc, tile, inductionVar);
+      rewriter.create<vector::TransferWriteOp>(
+          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
+          sliceMask,
+          rewriter.getBoolArrayAttr(
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ArmSME-specific fixup canonicalizations/folds
 //===----------------------------------------------------------------------===//
@@ -663,9 +788,12 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
+    // Note: These two patterns are added with a high benefit to ensure:
+    //  - Masked outer products are handled before unmasked ones
+    //  - Multi-tile writes are lowered as a store loop (if possible)
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
+                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
+                                                            /*benefit=*/1024);
     patterns.add<LegalizeArithConstantOpsByDecomposition,
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f43ef1cce787c..70d2a06797a31 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
 func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[BOTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,47 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 
 // -----
 
+// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+// CHECK-SAME:                                    %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:                                    %[[DIM_0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM_1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[MASK:.*]] =  vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE:.*]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +256,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +269,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
   // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index ada744b322fe9..03a7d25cffa76 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \

>From ee29dbeb8ab9f0bc4f6740d9e2038774f8972807 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 24 Jun 2024 11:07:22 +0000
Subject: [PATCH 2/4] Fixups

---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 39 ++++++++++++-------
 .../Dialect/ArmSME/vector-legalization.mlir   | 29 ++++++++++++++
 2 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index e71e9e8969ffa..4c6d03872cd3a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -376,27 +376,36 @@ struct LegalizeTransferWriteOpsByDecomposition
 
 /// Legalize a multi-tile transfer_write as a single store loop. This is done as
 /// part of type decomposition as at this level we know each tile write is
-/// disjoint, but that information is lost after decomposition (without
-/// static analysis).
+/// disjoint, but that information is lost after decomposition (without analysis
+/// to reconstruct it).
 ///
-/// Example (in pseudo-MLIR):
+/// Example:
 ///
 /// ```
-/// vector.transfer_write vector, dest[x, y], mask
-///   : vector<[16]x[4]xf32>, memref<?x?xf32>
+/// vector.transfer_write %vector, %dest[%y, %x], %mask
+///   : vector<[16]x[8]xi16>, memref<?x?xi16>
 /// ```
 /// Is rewritten to:
 /// ```
-/// for i in range (0, 4 * vscale) {
-///   let sliceRow = i + tile_n.row * vscale;              ─┐
-///   let sliceCol = tile_n.col * vscale;                   |
-///   slice = vector.extract tile_n[i]                      |
-///     : vector<[4]xf32> from vector<[16]x[4]xf32>         |
-///   slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
-///     : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
-///   vector.transfer_write                                 |  [16]x[4]
-///     slice, dest[x + sliceRow, y + sliceCol], slice_mask |
-///     : vector<[4]xf32>, memref<?x?xf32>                  ┘
+/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
+///   %upper_slice_y = arith.addi %slice_idx, %y : index
+///   %upper_slice_mask = vector.extract %mask[%slice_idx]
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>
+///   %upper_slice = vector.extract %upper_tile[%slice_idx]
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>
+///   vector.transfer_write %upper_slice,
+///     %dest[%upper_slice_y, %x], %upper_slice_mask
+///     : vector<[8]xi16>, memref<?x?xi16>
+///   // Same again for the lower tile:
+///   %lower_slice_idx = arith.addi %c8_vscale, %slice_idx : index
+///   %lower_slice_y = arith.addi %lower_slice_idx, %y : index
+///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>
+///   %lower_slice = vector.extract %lower_tile[%slice_idx]
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>
+///   vector.transfer_write %lower_slice,
+///     %dest[%lower_slice_y, %x], %lower_slice_mask
+///     : vector<[8]xi16>, memref<?x?xi16>
 /// }
 /// ```
 struct LegalizeMultiTileTransferWriteAsStoreLoop
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 70d2a06797a31..a9783bd945ac0 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -248,6 +248,35 @@ func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0:
 
 // -----
 
+// Tensor semantics are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// Transposes are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(

>From efbb5708c6de9d51fb53b72cd8065576f0c5caa0 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 25 Jun 2024 09:41:16 +0000
Subject: [PATCH 3/4] Fixups

---
 .../ArmSME/Transforms/VectorLegalization.cpp   | 13 ++++++++-----
 .../Dialect/ArmSME/vector-legalization.mlir    | 18 ++++++++++++++++--
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 4c6d03872cd3a..64db892c28ad7 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -434,6 +434,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
       return rewriter.notifyMatchFailure(writeOp,
                                          kMatchFailureNotSMETileTypeMultiple);
 
+    // Note: We also disallow masks where any dimension is larger than 16 as
+    // that won't be possible to lower to arm_sve.psel.
     auto mask = writeOp.getMask();
     if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                               vectorType.getDimSize(1) > 16)))
@@ -462,9 +464,9 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
         rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
     rewriter.setInsertionPointToStart(storeLoop.getBody());
 
-    // For each tile sub-tile of the multi-tile `vectorType`.
+    // For each sub-tile of the multi-tile `vectorType`.
     auto inputSMETiles = adaptor.getVector();
-    auto inductionVar = storeLoop.getInductionVar();
+    auto tileSliceIndex = storeLoop.getInductionVar();
     for (auto [index, smeTile] : llvm::enumerate(
              decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
       // The coordinates of the tile within `vectorType`.
@@ -473,7 +475,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
 
       // The current slice of `vectorType` we are processing.
       auto sliceIndex =
-          rewriter.create<arith::AddIOp>(loc, tileRow, inductionVar);
+          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
 
       // Where in the destination memref the current slice will be stored.
       auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
@@ -491,9 +493,10 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
               loc, sliceMaskType, sliceMask, smeTile.col);
       }
 
-      // Extract and store the current slice slice.
+      // Extract and store the current slice.
       Value tile = inputSMETiles[index];
-      auto slice = rewriter.create<vector::ExtractOp>(loc, tile, inductionVar);
+      auto slice =
+          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
       rewriter.create<vector::TransferWriteOp>(
           loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
           AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a9783bd945ac0..71d80bc16ea12 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -182,8 +182,8 @@ func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector
   // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
   // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
   // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
-  // CHECK-NEXT:   %[[BOTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
-  // CHECK-NEXT:   vector.transfer_write %[[BOTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
   // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
@@ -277,6 +277,20 @@ func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32
 
 // -----
 
+// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(

>From e4c986d7f60ed3bccbfa988527498d44c2fa9619 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 25 Jun 2024 10:17:23 +0000
Subject: [PATCH 4/4] Comments

---
 .../ArmSME/Transforms/VectorLegalization.cpp  | 35 +++++++++----------
 1 file changed, 16 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 64db892c28ad7..3bd538b93dad4 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -379,7 +379,7 @@ struct LegalizeTransferWriteOpsByDecomposition
 /// disjoint, but that information is lost after decomposition (without analysis
 /// to reconstruct it).
 ///
-/// Example:
+/// Example (pseudo-MLIR):
 ///
 /// ```
 /// vector.transfer_write %vector, %dest[%y, %x], %mask
@@ -388,24 +388,21 @@ struct LegalizeTransferWriteOpsByDecomposition
 /// Is rewritten to:
 /// ```
 /// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
-///   %upper_slice_y = arith.addi %slice_idx, %y : index
-///   %upper_slice_mask = vector.extract %mask[%slice_idx]
-///     : vector<[8]xi1> from vector<[16]x[8]xi1>
-///   %upper_slice = vector.extract %upper_tile[%slice_idx]
-///     : vector<[8]xi16> from vector<[8]x[8]xi16>
-///   vector.transfer_write %upper_slice,
-///     %dest[%upper_slice_y, %x], %upper_slice_mask
-///     : vector<[8]xi16>, memref<?x?xi16>
-///   // Same again for the lower tile:
-///   %lower_slice_idx = arith.addi %c8_vscale, %slice_idx : index
-///   %lower_slice_y = arith.addi %lower_slice_idx, %y : index
-///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]
-///     : vector<[8]xi1> from vector<[16]x[8]xi1>
-///   %lower_slice = vector.extract %lower_tile[%slice_idx]
-///     : vector<[8]xi16> from vector<[8]x[8]xi16>
-///   vector.transfer_write %lower_slice,
-///     %dest[%lower_slice_y, %x], %lower_slice_mask
-///     : vector<[8]xi16>, memref<?x?xi16>
+///   %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>           |
+///   %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>          |
+///   vector.transfer_write %upper_slice,                   |
+///     %dest[%slice_idx + %y, %x], %upper_slice_mask       |
+///     : vector<[8]xi16>, memref<?x?xi16>                  ┘
+///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
+///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
+///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
+///   vector.transfer_write %lower_slice,                         |
+///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
+///     : vector<[8]xi16>, memref<?x?xi16>                        ┘
 /// }
 /// ```
 struct LegalizeMultiTileTransferWriteAsStoreLoop



More information about the Mlir-commits mailing list