[Mlir-commits] [mlir] 5ed5d72 - [mlir][ArmSME] Lower multi-tile stores to a single loop (#96187)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jun 25 04:47:00 PDT 2024


Author: Benjamin Maxwell
Date: 2024-06-25T12:46:56+01:00
New Revision: 5ed5d723db58f7556782427444642d0571cdf649

URL: https://github.com/llvm/llvm-project/commit/5ed5d723db58f7556782427444642d0571cdf649
DIFF: https://github.com/llvm/llvm-project/commit/5ed5d723db58f7556782427444642d0571cdf649.diff

LOG: [mlir][ArmSME] Lower multi-tile stores to a single loop (#96187)

This adds a new pattern that can legalize a multi-tile transfer_write as
a single store loop. This is done as part of type decomposition because,
at this level, we know each tile write is disjoint; that information is
lost after decomposition (without analysis to reconstruct it).

Example (pseudo-MLIR):

```
vector.transfer_write %vector, %dest[%y, %x], %mask
  : vector<[16]x[8]xi16>, memref<?x?xi16>
```
Is rewritten to:
```
scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
  %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
    : vector<[8]xi1> from vector<[16]x[8]xi1>           |
  %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
    : vector<[8]xi16> from vector<[8]x[8]xi16>          |
  vector.transfer_write %upper_slice,                   |
    %dest[%slice_idx + %y, %x], %upper_slice_mask       |
    : vector<[8]xi16>, memref<?x?xi16>                  ┘
  %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
  %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
    : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
  %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
    : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
  vector.transfer_write %lower_slice,                         |
    %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
    : vector<[8]xi16>, memref<?x?xi16>                        ┘
}
```

Added: 
    

Modified: 
    mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
    mlir/test/Dialect/ArmSME/vector-legalization.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index b595c6dd8a684..96dad6518fec8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,139 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done as
+/// part of type decomposition as at this level we know each tile write is
+/// disjoint, but that information is lost after decomposition (without analysis
+/// to reconstruct it).
+///
+/// Example (pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write %vector, %dest[%y, %x], %mask
+///   : vector<[16]x[8]xi16>, memref<?x?xi16>
+/// ```
+/// Is rewritten to:
+/// ```
+/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
+///   %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>           |
+///   %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>          |
+///   vector.transfer_write %upper_slice,                   |
+///     %dest[%slice_idx + %y, %x], %upper_slice_mask       |
+///     : vector<[8]xi16>, memref<?x?xi16>                  ┘
+///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
+///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
+///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
+///   vector.transfer_write %lower_slice,                         |
+///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
+///     : vector<[8]xi16>, memref<?x?xi16>                        ┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          writeOp, "TODO: tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "TODO: transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    // Note: We also disallow masks where any dimension is > 16 because that
+    // prevents the masking from being lowered to use arm_sve.psel.
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // SME tile type for the element type and the i1 mask type of one slice.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Single loop over slice indices [0, minTileSlices * vscale), step 1.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // Emit one slice store per SME sub-tile of `vectorType` in the loop body.
+    auto inputSMETiles = adaptor.getVector();
+    auto tileSliceIndex = storeLoop.getInductionVar();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+      // Row/column offset of the tile within `vectorType` (times vscale).
+      auto tileRow = createVscaleMultiple(smeTile.row);
+      auto tileCol = createVscaleMultiple(smeTile.col);
+
+      // Row of `vectorType` being stored: tile row offset + loop index.
+      auto sliceIndex =
+          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
+
+      // Where in the destination memref the current slice will be stored.
+      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
+                                                     writeOp.getIndices()[0]);
+      auto storeCol =
+          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+
+      // Extract the slice's mask row, then this tile's part if types differ.
+      Value sliceMask = nullptr;
+      if (mask) {
+        sliceMask = rewriter.create<vector::ExtractOp>(
+            loc, mask, OpFoldResult(sliceIndex));
+        if (sliceMaskType != sliceMask.getType())
+          sliceMask = rewriter.create<vector::ScalableExtractOp>(
+              loc, sliceMaskType, sliceMask, smeTile.col);
+      }
+
+      // Extract the slice and store it (first map result/in_bounds dropped).
+      Value tile = inputSMETiles[index];
+      auto slice =
+          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
+      rewriter.create<vector::TransferWriteOp>(
+          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
+          sliceMask,
+          rewriter.getBoolArrayAttr(
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ArmSME-specific fixup canonicalizations/folds
 //===----------------------------------------------------------------------===//
@@ -663,9 +797,12 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
+    // Note: These two patterns are added with a high benefit to ensure:
+    //  - Masked outer products are handled before unmasked ones
+    //  - Multi-tile writes are lowered as a store loop (if possible)
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
+                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
+                                                            /*benefit=*/1024);
     patterns.add<LegalizeArithConstantOpsByDecomposition,
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,

diff  --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index f43ef1cce787c..71d80bc16ea12 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
 func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,90 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 
 // -----
 
+// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+// CHECK-SAME:                                    %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:                                    %[[DIM_0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[DIM_1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                    %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:                                    %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[MASK:.*]] =  vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// Negative test: writes with tensor semantics are not supported by the store loop lowering, so no scf.for may be emitted.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// Transposes are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_transpose
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_transpose(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// Negative test: masked writes where any dimension of the mask is > 16 are not supported by the store loop lowering, so no scf.for may be emitted.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +299,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +312,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
   // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
index ada744b322fe9..03a7d25cffa76 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule  \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \


        


More information about the Mlir-commits mailing list