[Mlir-commits] [mlir] [mlir][ArmSME] Lower transfer_write + transpose to vertical store (PR #71181)

Cullen Rhodes llvmlistbot at llvm.org
Fri Nov 3 06:15:10 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/71181

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
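
For example (taken from the pattern documentation added by this patch), a
vector.transfer_write with a transpose permutation map:

  vector.transfer_write %vector, %source[%c0, %c0]
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
     in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>

is converted to a vertical tile store:

  arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
    : memref<?x?xi8>, vector<[16]x[16]xi8>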

From 5b48f3f2cb06ced56a86f8a1e59832c990ed96d8 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:57:59 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Add support for lowering masked tile_store
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_store
ops. Only masks created by 'vector.create_mask' are currently supported.

Example:

  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask
    : memref<?x?xi32>, vector<[4]x[4]xi32>

Produces:

  %num_rows = arith.constant 3 : index
  %num_cols = vector.create_mask %c2 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    arm_sme.store_tile_slice %tile, %slice_idx, %num_cols, %dest[%slice_idx, %c0]
      : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
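
Masks defined by any op other than 'vector.create_mask' are rejected with
a match failure. For example (illustrative, not part of the patch), the
following constant mask would not be lowered by this pattern:

  %mask = arith.constant dense<true> : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask
    : memref<?x?xi32>, vector<[4]x[4]xi32>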
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  64 ++++++----
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  25 +++-
 .../CPU/ArmSME/test-transfer-write-2d.mlir    | 121 ++++++++++++++++++
 3 files changed, 186 insertions(+), 24 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 80da6ffda1ed2ea..0511e270c992451 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -173,38 +173,58 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     auto tileType = tileStoreOp.getVectorType();
     auto tileElementType = tileType.getElementType();
 
-    // Create a loop that stores each ZA tile slice from memory.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+    Value maskCols;
+    Value upperBound;
+    auto maskOp = tileStoreOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
+                         "currently supported");
+
+      auto numRows = createMaskOp.getOperands()[0];
+      auto numCols = createMaskOp.getOperands()[1];
+
+      upperBound = numRows;
+      maskCols =
+          rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+    } else {
+      // Store all tile slices if no mask.
+      auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+          loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+      auto vscale =
+          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+      // This describes both the number of ZA tile slices and the number of
+      // elements in a vector of SVL bits for a given element type (SVL_B,
+      // SVL_H, ..., SVL_Q).
+      auto numTileSlices =
+          rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+      upperBound = numTileSlices;
+      // Create an 'all true' predicate for the tile slice.
+      maskCols = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predicateType, true));
+    }
+
+    // Create a loop that stores each (active) ZA tile slice to memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
-    // Create an 'all true' predicate for the tile slice.
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
-
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
+                     upperBound, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        allTruePredicate, tileStoreOp.getBase(), memrefIndices,
-        tileStoreOp.getLayout());
+        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
 
     return success();
   }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index d61f588941b408c..e839c2e9e06db02 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -46,9 +46,9 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
-// CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK:           %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
@@ -67,6 +67,27 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor_with_mask(
+// CHECK-SAME:                                             %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.print
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
new file mode 100644
index 000000000000000..2555cf7ad73fb18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -0,0 +1,121 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// Vector store.
+func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c0 = arith.constant 0.0 : f32
+  %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+  vector.transfer_write %zero, %A[%base1, %base2] {in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store.
+func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c0 = arith.constant 0.0 : f32
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+  vector.transfer_write %zero, %A[%base1, %base2], %mask {in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Vector load + print.
+func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  vector.print str "TILE BEGIN:"
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+//    0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns dynamic memref. It's the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_f32 = arith.constant 1.0 : f32
+  %c10_f32 = arith.constant 10.0 : f32
+
+  %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+  %init = arith.constant 0.0 : f32
+  scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+    scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+      memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+      %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+      scf.yield %inner_val_next : f32
+    }
+    %val_next = arith.addf %val, %c10_f32 : f32
+    scf.yield %val_next : f32
+  }
+
+  return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
+  // non-zero offsets while remaining inbounds.
+  %vscale = vector.vscale
+  %svl_s = arith.muli %c4, %vscale : index
+  %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+  // 1. Initialize memory
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 22, 23
+  // CHECK-NEXT: ( 30, 31, 32, 33
+  %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 2. Write 2-D vector of zeroes to 1. at offset [2, 2].
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d(%A, %c2, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 3. Write 2-D vector of zeroes to 2. but with mask (nrows=2, ncols=3).
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 0, 3
+  // CHECK-NEXT: ( 0, 0, 0, 13
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  memref.dealloc %A : memref<?x?xf32>
+
+  return
+}

From 0c13506d5d3ecfcddba74fcbc2b0d709125b55f0 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 16 Oct 2023 11:43:32 +0000
Subject: [PATCH 2/2] [mlir][ArmSME] Lower transfer_write + transpose to
 vertical store

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
---
 .../VectorToArmSME/VectorToArmSME.cpp         | 47 +++++++++++++++++--
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 42 +++++++++++++++++
 .../CPU/ArmSME/test-transfer-write-2d.mlir    | 38 +++++++++++++++
 3 files changed, 124 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 5491f7dd30629ad..a8956f0d38fba9d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
 
 /// Conversion pattern for vector.transfer_write.
 ///
-///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-///                                                      memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_store:
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
 ///                                                   vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
+///            (in-flight transpose):
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
     if (!arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
+    assert(writeOp.getTransferRank() == 2 &&
+           "expected a permutation_map with result dims of the same rank as "
+           "the vector type");
+
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
+    // Out-of-bounds dims are not supported.
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+
+    arm_sme::TileSliceLayout layout;
+
+    AffineExpr d0, d1;
+    bindDims(writeOp.getContext(), d0, d1);
+    AffineMap map = writeOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   writeOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
-        writeOp.getMask());
+        writeOp.getMask(), layout);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ed33f8508dba0bf..a1ad25ed77aa8ef 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
 
 // -----
 
+/// In-flight transpose via vertical store.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+  return
+}
+
+// -----
+
+/// In-flight transpose via vertical store with mask.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME:                                                        %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME:                                                        %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME:                                                        %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 2555cf7ad73fb18..6c14b632dfd3d25 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -32,6 +32,25 @@ func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: i
   return
 }
 
+// Vector store + transpose.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store + transpose.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 // Vector load + print.
 func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -115,6 +134,25 @@ func.func @entry() {
   call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
   call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
 
+  // 4. Reload 3. + store + transpose.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 3, 13, 0, 0
+  call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
+  // The mask applies after the permutation.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
   memref.dealloc %A : memref<?x?xf32>
 
   return


