[Mlir-commits] [mlir] [mlir][ArmSME] Add masking support to memory ops (PR #69148)

Cullen Rhodes llvmlistbot at llvm.org
Sun Oct 15 23:40:40 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/69148

This patch series adds masking support to the ArmSME memory ops, as well as lowerings from `vector.transfer_read` and `vector.transfer_write`. The `transfer_read` lowering to ArmSME is more complete than the `transfer_write` lowering; for the latter, the VectorToArmSME lowering still needs fleshing out to support in-flight transpose via vertical load, and integration tests need to be added.

This support is part of a wider effort to lower linalg.matmul to SME. There are a lot of changes here, so I don't expect this to be reviewed as a whole, which is why I've created this as a draft PR. I plan to create separate PRs for each commit.

>From f04e2c08ba55cc858c865a071f98dc08f8813d18 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 07:04:33 +0000
Subject: [PATCH 1/9] [mlir][ArmSME] Update tile slice layout syntax

This patch prefixes tile slice layout with `layout` in the
assemblyFormat:

  - <vertical>   -> layout<vertical>
  - <horizontal> -> layout<horizontal>

The reason for this change is the current format doesn't play nicely
with additional optional operands (required to support padding and
masking in an upcoming patch), as it becomes ambiguous.

This affects the following ops:

  - arm_sme.tile_load
  - arm_sme.tile_store
  - arm_sme.load_tile_slice
  - arm_sme.store_tile_slice
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  39 ++---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |   4 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |   6 +-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |   8 +-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |  36 ++---
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 152 +++++++++---------
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |  36 ++---
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |   2 +-
 8 files changed, 139 insertions(+), 144 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 049c9759d70bf43..dab54b63d8d22be 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -76,6 +76,7 @@ def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
 def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
                                           "layout"> {
   let assemblyFormat = "`<` $value `>`";
+  let defaultValue = "TileSliceLayout::Horizontal";
 }
 
 //===----------------------------------------------------------------------===//
@@ -248,19 +249,18 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
 
     Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile = arm_sme.tile_load %base[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a 128-bit element ZA tile with horizontal layout (default) from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0], <horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
 
@@ -274,7 +274,7 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   }];
 
   let assemblyFormat =
-    "$base `[` $indices `]` (`,` $layout^)? attr-dict "
+    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
       "`:` type($base) `,` type($result)";
 }
 
@@ -296,19 +296,17 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
 
     Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.tile_store %tile, %base[%c0, %c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
     Example 3: Store a 128-bit element ZA tile with horizontal (default) layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0], <horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -320,7 +318,7 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
   }];
 
   let assemblyFormat =
-    "$valueToStore `,` $base `[` $indices `]` (`,` $layout^)? attr-dict "
+    "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
       "`:` type($base) `,` type($valueToStore)";
 }
 
@@ -348,19 +346,18 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 
     Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from">:$base,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
 
@@ -374,7 +371,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`,` $layout^)?
+    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
       attr-dict `:` type($base) `,` type($result)
   }];
 }
@@ -401,19 +398,17 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
 
     Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
     Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0], <vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices,
-    DefaultValuedAttr<ArmSME_TileSliceLayoutAttr,
-                      "::mlir::arm_sme::TileSliceLayout::Horizontal">:$layout
+    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -425,7 +420,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`,` $layout^)?
+    $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
       attr-dict `:` type($base) `,` type($tile)
   }];
 }
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 881cc8575fb4824..0ec51b7430c0213 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical>
+///  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical>
 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
 ///  ```
 ///
@@ -147,7 +147,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
 ///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx],
-///      <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+///      layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index cbc5e468c729372..d06eb4f5b01c950 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -67,7 +67,7 @@ namespace {
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_load ... <vertical>
+///   arm_sme.tile_load ... layout<vertical>
 struct TransferReadPermutationToArmSMELowering
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -368,8 +368,8 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
 ///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
 ///   %arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
 ///     : memref<?x?xi32>, vector<[4]x[4]xi32>
-///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0], <vertical>
-///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load %alloca[%c0, %c0]
+///     layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///
 /// NOTE: Tranposing via memory is obviously expensive, the current intention
 /// is to avoid the transpose if possible, this is therefore intended as a
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 09f148bcd42f593..4b3020970d6ccc1 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -21,10 +21,10 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
 // -----
 
 // CHECK-LABEL: @arm_sme_tile_load_ver
-// CHECK: arm_sme.load_tile_slice {{.*}} <vertical>
+// CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical>
 func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -50,10 +50,10 @@ func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
 // -----
 
 // CHECK-LABEL: @arm_sme_tile_store_ver
-// CHECK: arm_sme.store_tile_slice {{.*}} <vertical>
+// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
 func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 4c16e5c488a74cd..07485b3ee8ddf86 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -116,7 +116,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -126,7 +126,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
@@ -136,7 +136,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -146,7 +146,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
@@ -156,7 +156,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
@@ -166,7 +166,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
@@ -176,7 +176,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -186,7 +186,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
@@ -196,7 +196,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
@@ -316,7 +316,7 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -326,7 +326,7 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
@@ -336,7 +336,7 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
@@ -346,7 +346,7 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
@@ -356,7 +356,7 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
@@ -366,7 +366,7 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
@@ -376,7 +376,7 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -386,7 +386,7 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
@@ -396,7 +396,7 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index f6d19359b8e3af8..427154158e797fd 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -358,81 +358,81 @@ func.func @arm_sme_tile_load_hor_f64(%src : memref<?x?xf64>) {
 // -----
 
 func.func @arm_sme_tile_load_ver_i8(%src : memref<?x?xi8>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_i16(%src : memref<?x?xi16>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_i32(%src : memref<?x?xi32>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_i64(%src : memref<?x?xi64>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_i128(%src : memref<?x?xi128>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_f16(%src : memref<?x?xf16>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_bf16(%src : memref<?x?xbf16>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_f32(%src : memref<?x?xf32>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
-  // CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
@@ -442,7 +442,7 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
 func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0], <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile = arm_sme.tile_load %src[%c0, %c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -534,81 +534,81 @@ func.func @arm_sme_tile_store_hor_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
 // -----
 
 func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
 func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
-  // CHECK: arm_sme.tile_store {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
@@ -618,7 +618,7 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
 func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0], <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, %dest[%c0, %c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -710,81 +710,81 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
 func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
@@ -794,7 +794,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vecto
 func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index, <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -886,81 +886,81 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
 func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
@@ -970,7 +970,7 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0], <horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index b2c8fd8e01ac7ec..455b47a83e28f43 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -5,7 +5,7 @@
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @transfer_read_2d_transpose_i8
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
 func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i8
@@ -17,7 +17,7 @@ func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_i16
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
 func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i16
@@ -29,7 +29,7 @@ func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_i32
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i32
@@ -41,7 +41,7 @@ func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_i64
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
 func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i64
@@ -53,7 +53,7 @@ func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_i128
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
 func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i128
@@ -65,7 +65,7 @@ func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_f16
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
 func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f16
@@ -77,7 +77,7 @@ func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_bf16
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
 func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : bf16
@@ -89,7 +89,7 @@ func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_f32
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
 func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
@@ -101,7 +101,7 @@ func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
 // -----
 
 // CHECK-LABEL: @transfer_read_2d_transpose_f64
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
 func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
@@ -475,7 +475,7 @@ func.func @splat_vec2d_from_f16(%arg0: f16) {
 // CHECK:           %[[MIN_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK:           %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[MIN_TILE_SLICES]], %[[MIN_TILE_SLICES]]) : memref<?x?xi8>
 // CHECK:           arm_sme.tile_store %[[TILE]], %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
-// CHECK:           arm_sme.tile_load %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]], <vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:           arm_sme.tile_load %[[NUM_TILE_SLICES]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
 func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
@@ -487,7 +487,7 @@ func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
 // CHECK-LABEL: @transpose_i16
 // CHECK: arith.constant 8
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
 func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
@@ -499,7 +499,7 @@ func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
 // CHECK-LABEL: @transpose_i32
 // CHECK: arith.constant 4
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
@@ -511,7 +511,7 @@ func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
 // CHECK-LABEL: @transpose_i64
 // CHECK: arith.constant 2
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
 func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
   "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
@@ -524,7 +524,7 @@ func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
 func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
   "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
@@ -536,7 +536,7 @@ func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
 // CHECK-LABEL: @transpose_f16
 // CHECK: arith.constant 8
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
 func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
@@ -548,7 +548,7 @@ func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
 // CHECK-LABEL: @transpose_bf16
 // CHECK: arith.constant 8
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
 func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
@@ -560,7 +560,7 @@ func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
 // CHECK-LABEL: @transpose_f32
 // CHECK: arith.constant 4
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
 func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
@@ -572,7 +572,7 @@ func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
 // CHECK-LABEL: @transpose_f64
 // CHECK: arith.constant 2
 // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
-// CHECK: arm_sme.tile_load {{.*}}, <vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
 func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 8c7d8c954d38475..179e9fa83662ece 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -63,7 +63,7 @@ func.func @entry() {
   }
 
   // Load tile from "mem1" vertically.
-  %0 = arm_sme.tile_load %mem1[%c0, %c0], <vertical> : memref<?xi32>, vector<[4]x[4]xi32>
+  %0 = arm_sme.tile_load %mem1[%c0, %c0] layout<vertical> : memref<?xi32>, vector<[4]x[4]xi32>
 
   // 1. ORIGINAL HORIZONTAL LAYOUT
   // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least

>From b99b39b6ab62c65b50bbfdc1a45b9bd3306fc4f9 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:09:57 +0000
Subject: [PATCH 2/9] [mlir][ArmSME] Add optional padding and mask operands to
 tile_load

Padding and mask are optional, but if one is specified both must be
specified. This is consistent with vector.transfer_read.
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 50 +++++++++++++++++--
 mlir/test/Dialect/ArmSME/invalid.mlir         | 44 ++++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 10 ++++
 3 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..6f6b54aad0058e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "padding type matches element type of result (if present)",
+    "result", "padding",
+    "::llvm::cast<VectorType>($_self).getElementType()",
+    "!getPadding() || std::equal_to<>()"
+  >,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as result (if present)",
+    "result", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     dimensions, since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
+    An optional SSA value `padding` of the same elemental type as the MemRef is
+    provided to specify a fallback value in the case of masking.
+
+    An optional SSA value `mask` may be specified to mask out elements read
+    from the MemRef. The `mask` type is an `i1` vector with a shape that
+    matches how elements are read from the MemRef. Elements whose corresponding
+    mask element is `0` are masked out and replaced with `padding`.
+
+    If either `padding` or `mask` is specified, both must be specified.
+
     Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
+
+    Example 4: Masked load of a 32-bit element ZA tile with horizontal layout (default) from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
+    Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices, "TileSliceLayout":$layout), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($result)";
+    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..9229f0415c076c3 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
   return %0 : vector<[4]x[16]xi8>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
   return %0 : i8
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
   return %0 : i1
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
   return %0 : vector<[4]x[4]xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
@@ -97,3 +117,27 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note @-2 {{prior use here}}
+  // expected-error @+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note @-2 {{prior use here}}
+  // expected-error @+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..f6459f085843655 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
 
 // -----
 
+/// Padding and mask are optional
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+  // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

>From 57754f654e2fedcf6cee761f58532ad52ee4dd6d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:28:53 +0000
Subject: [PATCH 3/9] [mlir][ArmSME] Add mask operand to load_tile_slice

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  27 +++--
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  18 ++-
 .../Transforms/LegalizeForLLVMExport.cpp      |  37 +++---
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  15 ++-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |  76 ++++++------
 mlir/test/Dialect/ArmSME/invalid.mlir         |  13 ++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 114 +++++++++---------
 7 files changed, 173 insertions(+), 127 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6f6b54aad0058e5..8a05ed89799d564 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -367,7 +367,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
-    AllTypesMatch<["tile", "result"]>
+    AllTypesMatch<["tile", "result"]>,
+    TypesMatchWith<
+      "mask has i1 element type and same shape as result",
+      "result", "mask",
+      "VectorType("
+        "VectorType::Builder("
+          "::llvm::cast<mlir::VectorType>($_self)"
+        ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
+      ")">,
 ]> {
   let summary = "Tile slice load and update operation";
   let description = [{
@@ -383,23 +391,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
+    An SSA value `mask` specifies to mask out elements read from the MemRef.
+    The `mask` type is an `i1` vector with a shape that matches how elements
+    are read from the MemRef.
+
     Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
     ```
 
     Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from">:$base,
+    Arg<AnyMemRef, "the reference to load from">:$base, AnyVector:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -415,8 +427,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
-      attr-dict `:` type($base) `,` type($result)
+    $base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
+      (`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
+                                            type($result)
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 0ec51b7430c0213..9cfb13216d9bfe7 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 ///  AFTER:
 ///  ```mlir
+///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
 ///  %tile_id = arm_sme.get_tile_id : i32
 ///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
@@ -69,7 +70,8 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
 ///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
+///      %ptrue_s, %tile, %tile_slice_idx
+///        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -77,6 +79,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
+    if (tileLoadOp.getMask())
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has mask, needs masked pattern(s)");
+
     OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
@@ -109,6 +115,12 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    // Create an 'all true' predicate for the tile slice.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
     // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
     // tile.
     SmallVector<Value> memrefIndices;
@@ -117,8 +129,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
-        tileSliceIndex, tileLoadOp.getLayout());
+        loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..220e0bdd7097978 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering
         loc, rewriter.getI32Type(), tileSlice);
 
     // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto maskOp = loadTileSliceOp.getMask();
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
@@ -195,24 +190,24 @@ struct LoadTileSliceToArmSMELowering
       default:
         llvm_unreachable("unexpected element type!");
       case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       }
     } else {
@@ -220,23 +215,23 @@ struct LoadTileSliceToArmSMELowering
       default:
         llvm_unreachable("unexpected element type!");
       case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4b3020970d6ccc1..3fb320c0d219e60 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
@@ -10,8 +14,9 @@
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: func.func @arm_sme_tile_store_hor(
@@ -57,6 +66,10 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
   return
 }
 
+//===----------------------------------------------------------------------===//
+// vector.print
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf86..4fb4ca2f102ee74 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -8,9 +8,9 @@
 
 // CHECK-LABEL:   func.func @arm_sme_load_tile_slice_hor_i8(
 // CHECK-SAME:                                              %[[SRC:.*]]: memref<?x?xi8>,
+// CHECK-SAME:                                              %[[MASK:.*]]: vector<[16]xi1>,
 // CHECK-SAME:                                              %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index) {
-// CHECK:           %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
@@ -21,12 +21,12 @@
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
 // CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -34,9 +34,9 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -44,9 +44,9 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -54,9 +54,9 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -64,9 +64,9 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
 // CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -74,9 +74,9 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -84,9 +84,9 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -94,9 +94,9 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -104,9 +104,9 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
@@ -114,9 +114,9 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
 // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -124,9 +124,9 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -134,9 +134,9 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -144,9 +144,9 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -154,9 +154,9 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
 // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -164,9 +164,9 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -174,9 +174,9 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -184,9 +184,9 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -194,9 +194,9 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 9229f0415c076c3..60350a888c88441 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -141,3 +141,16 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64,
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as result}}
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index f6459f085843655..93b103fb83ac4d3 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -638,173 +638,173 @@ func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memre
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 /// Layout is optional and horizontal is the default, verify it's still parsed.
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<horizontal> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 

>From ff34fe7ab47231414f0362ff96b92e574078d87f Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:47:40 +0000
Subject: [PATCH 4/9] [mlir][ArmSME] Propagate pad and mask in
 vector.transfer_read lowering

This extends the vector.transfer_read -> arm_sme.tile_load lowering to
propagate pad and mask.

The restriction on the transfer_read being a transposition is also
removed; identity maps are now lowered to normal horizontal loads.
---
 .../VectorToArmSME/VectorToArmSME.cpp         |  57 ++++---
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 140 +++++++++---------
 2 files changed, 109 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..02a5bc64fa52c04 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
 
 namespace {
 
-/// Conversion pattern for vector.transfer_read op with transpose permutation
-/// map to vertical arm_sme.tile_load (in-flight transpose).
+/// Conversion pattern for vector.transfer_read.
+///
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_load:
+///
+///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
+///
+/// is converted to:
+///
+///   arm_sme.tile_load ...
+///
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
+///            (in-flight transpose):
 ///
 ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_load ... layout<vertical>
-struct TransferReadPermutationToArmSMELowering
+struct TransferReadToArmSMELowering
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not a 2 result permutation map");
 
-    AffineMap map = transferReadOp.getPermutationMap();
-
-    // Permutation map doesn't perform permutation, can be lowered to
-    // vector.load by TransferReadToVectorLoadLowering and then
-    // arm_sme.tile_load by VectorLoadToArmSMELowering.
-    if (map.isIdentity())
-      return rewriter.notifyMatchFailure(
-          transferReadOp, "map is an identity, apply another pattern");
-
     auto vectorType = transferReadOp.getVectorType();
     if (!arm_sme::isValidSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(transferReadOp,
@@ -96,26 +102,33 @@ struct TransferReadPermutationToArmSMELowering
     if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
 
-    if (transferReadOp.getMask())
-      // TODO: support masking.
-      return rewriter.notifyMatchFailure(transferReadOp,
-                                         "masking not yet supported");
-
     // Out-of-bounds dims are not supported.
     if (transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not inbounds transfer read");
 
+    arm_sme::TileSliceLayout layout;
+
     AffineExpr d0, d1;
     bindDims(transferReadOp.getContext(), d0, d1);
-    if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
-                              transferReadOp.getContext()))
+    AffineMap map = transferReadOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   transferReadOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
       return rewriter.notifyMatchFailure(transferReadOp,
-                                         "not true 2-D matrix transpose");
+                                         "unsupported permutation map");
 
+    // Padding isn't optional for transfer_read, but it is only used for
+    // out-of-bounds accesses (not supported here) and/or masking. The mask
+    // is optional; if it's not present, don't pass padding.
+    auto mask = transferReadOp.getMask();
+    auto padding = mask ? transferReadOp.getPadding() : nullptr;
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
         transferReadOp, vectorType, transferReadOp.getSource(),
-        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+        transferReadOp.getIndices(), padding, mask, layout);
 
     return success();
   }
@@ -432,7 +445,7 @@ struct TransposeOpToArmSMELowering
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-               SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
+               SplatOpToArmSMELowering, TransferReadToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
                VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f43..80ca3d3b8281321 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,181 +1,189 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// vector.transfer_read (with in-flight transpose)
+// vector.transfer_read
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i8
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
-func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+// CHECK-LABEL: @transfer_read_2d_i8
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i8
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
-func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+// CHECK-LABEL: @transfer_read_2d_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+// CHECK-LABEL: @transfer_read_2d_i32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
-func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+// CHECK-LABEL: @transfer_read_2d_i64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
   "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i128
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
-func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+// CHECK-LABEL: @transfer_read_2d_i128
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i128
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
   "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
-func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+// CHECK-LABEL: @transfer_read_2d_f16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_bf16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
-func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+// CHECK-LABEL: @transfer_read_2d_bf16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : bf16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
-func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+// CHECK-LABEL: @transfer_read_2d_f32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
-func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_f64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_with_mask_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) {
   %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
-  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  %pad = arith.constant 0 : i16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__non_memref_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+/// in-flight transpose
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  %pad = arith.constant 0 : i8
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
 // -----
 
-/// transfer_read with identity map should be lowered to vector.load by
-/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
-/// VectorLoadToArmSMELowering.
-
-// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
   return
 }
 

>From 7d658476dfe2a0de9c33355ac58a1b41b3141f75 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 10:19:34 +0000
Subject: [PATCH 5/9] [mlir][ArmSME] Add tile slice layout attr to vector <->
 tile ops

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 36 ++++++++++------
 .../Transforms/LegalizeForLLVMExport.cpp      | 43 +++++++++++++------
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 32 ++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 16 +++++++
 4 files changed, 100 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a05ed89799d564..e35725934315bb2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -498,21 +498,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
     of a 2-D scalable vector tile at the given index. The type of the 1-D
     scalable vector to be moved must match the type of the tile slice. A tile
     slice is a 1-D vector of horizontally or vertically contiguous elements
-    within a ZA tile. Horizontal tile slices are currently assumed when
-    lowering to intrinsics. The updated tile is returned as the result.
+    within a ZA tile. The updated tile is returned as the result.
 
-    Example 1: Move a vector<[16]xi8> into tile at given index.
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
+
+    Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
     ```mlir
     %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
     ```
 
-    Example 2: Move a vector<[2]xf64> into tile at given index.
+    Example 2: Move a vector<[2]xf64> into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
     ```
   }];
   let arguments = (ins
-      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout);
   let results = (outs SMETile:$result);
 
   let extraClassDeclaration = [{
@@ -522,7 +525,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $vector `,` $tile `,` $tile_slice_index
+    $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
       attr-dict `:` type($vector) `into` type($result)
   }];
 }
@@ -537,21 +540,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   let description = [{
     The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
     scalable tile at the given index. A tile slice is a 1-D vector of
-    horizontally or vertically contiguous elements within a ZA tile. Horizontal
-    tile slices are currently assumed when lowering to intrinsics.
+    horizontally or vertically contiguous elements within a ZA tile.
+
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
 
-    Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+    Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
     ```mlir
     %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
     ```
 
-    Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+    Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
     ```mlir
-    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
     ```
   }];
 
-  let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+  let arguments = (ins
+    SMETile:$tile, Index:$tile_slice_index,
+    ArmSME_TileSliceLayoutAttr:$layout
+  );
   let results = (outs SVEVector:$result);
 
   let extraClassDeclaration = [{
@@ -559,7 +567,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 
   let assemblyFormat = [{
-      $tile `[` $tile_slice_index `]` attr-dict
+      $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
       `:` type($result) `from` type($tile)
   }];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 220e0bdd7097978..86f245d82b16c77 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -345,8 +345,7 @@ struct StoreTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
 struct MoveVectorToTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
   using ConvertOpToLLVMPattern<
@@ -383,10 +382,19 @@ struct MoveVectorToTileSliceToArmSMELowering
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
 
-    // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
-    rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-        loc, tileI32, tileSliceI32, allActiveMask,
-        moveVectorToTileSliceOp.getVector());
+    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
+    switch (moveVectorToTileSliceOp.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.create<arm_sme::aarch64_sme_write_vert>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
     // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -397,8 +405,7 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
 struct MoveTileSliceToVectorArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
   using ConvertOpToLLVMPattern<
@@ -430,10 +437,19 @@ struct MoveTileSliceToVectorArmSMELowering
     auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getI32Type(), sliceIndex);
 
-    // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
-    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
-        moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-        tileIdI32, sliceIndexI32);
+    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
+    switch (moveTileSliceToVector.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    }
 
     return success();
   }
@@ -675,7 +691,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
       arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 4fb4ca2f102ee74..30ddb3c46860187 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
+  return %slice : vector<[1]xi128>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 93b103fb83ac4d3..f0704a75ed2fc3a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1069,6 +1069,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
   return
 }
 
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -1145,3 +1153,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  return %slice : vector<[2]xf64>
+}

>From 8589e503f836de48356cc4bbed105d661767f148 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 09:15:52 +0000
Subject: [PATCH 6/9] [mlir][ArmSME] Add support for lowering masked tile_load
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.

There are two lowerings, one for pad of constant zero and another for
non-zero pad. For the following example:

  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : <[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                          vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:
-----------------------------------------------------

  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %col_mask = vector.create_mask %num_cols : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx, %c0], %col_mask, %tile, %slice_idx :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>

The tile is zeroed to satisfy the padding and only active rows are
loaded.

The latter (non-zero pad) is lowered as follows:
------------------------------------------------

  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>

The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (will be lowered to SVE, not SME) loads each slice
for active rows, with padding specified as a passthru. For non-active
rows the slice is the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 251 +++++++++++++++++-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  56 ++++
 .../CPU/ArmSME/test-transfer-read-2d.mlir     | 237 +++++++++++++++++
 3 files changed, 543 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 9cfb13216d9bfe7..75b7b8acdd190c6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -141,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   }
 };
 
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 0 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %tile = arm_sme.zero : vector<[4]x[4]xi32>
+///  %col_mask = vector.create_mask %num_cols : vector<[4]xi1>
+///  scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+///    %tile_update = arm_sme.load_tile_slice
+///      %src[%tile_slice_idx, %c0], %col_mask, %tile, %tile_slice_idx :
+///      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+///  }
+///  ```
+///
+/// NOTE: Only mask of 'vector.create_mask' op is currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (!constPadOp || constPadOp.getValue() !=
+                           rewriter.getZeroAttr(tileType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Initialize tile with zero to satisfy padding. Inactive cols will be
+    // zeroed anyway since the loads use zeroing predication. For inactive rows
+    // however, no load will occur so these need to be zeroed.
+    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+    // Create a loop to load the active tile slices from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = numRows;
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile.
+    SmallVector<Value> memrefIndices;
+    auto tileSliceIndex = forOp.getInductionVar();
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     upperBound, memrefIndices, loc, rewriter);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+        tileSliceIndex, tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 1 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %tile_id = arm_sme.get_tile_id : i32
+///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///    %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+///    %slice = scf.if %row_is_active -> vector<[4]xi32> {
+///      %slice = vector.maskedload %src[%tile_slice_idx, %c0], %num_cols, %pad_1d
+///        : memref<?x?xi32>, vector<[4]xi1>,
+///          vector<[4]xi32> into vector<[4]xi32>
+///      scf.yield %slice : vector<[4]xi32>
+///    } else {
+///      scf.yield %pad_1d : vector<[4]xi32>
+///    }
+///    // Insert slice into tile
+///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+///      : vector<[4]xi32> into vector<[4]x[4]xi32>
+///  }
+///  ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (constPadOp &&
+        constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Create 'arm_sme.get_tile_id' op.
+    auto tileId = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
+    // use as input tile to 'arm_sme.move_vector_to_tile_slice' ops.
+    auto tile =
+        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    auto rowIsActive = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+    SmallVector<Value> memrefIndices;
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     numTileSlices, memrefIndices, loc, rewriter);
+
+    // Splat pad into 1-D vector matching type of tile slice.
+    auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+    Operation *slice = rewriter.create<scf::IfOp>(
+        loc, rowIsActive,
+        [&](OpBuilder &b, Location loc) {
+          // If the row is active, emit a masked load where the predicate is
+          // 'numCols'. Pad is used for inactive elements, taken from
+          // passthru.
+          auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+              loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+              numColsOp, /*passthru=*/pad1DOp);
+          rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+        },
+        [&](OpBuilder &b, Location loc) {
+          // Inactive rows are filled with pad.
+          rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+        });
+
+    // NOTE(review): the tile load's layout is forwarded to
+    // 'arm_sme.move_vector_to_tile_slice' below. An earlier TODO here said
+    // vertical loads were broken until that op supported a layout; confirm
+    // that vertical loads now behave as intended with a regular maskedload.
+
+    // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+        tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
 /// slice using `arm_sme.store_tile_slice`.
 ///
@@ -265,7 +513,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
                TileVectorPrintOpConversion>(patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 3fb320c0d219e60..4906812032ae903 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME:                                                          %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0 : i32
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
+// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK:             %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK:               %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:               scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK:             } else {
+// CHECK:               scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK:             }
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.tile_store
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..fe40dd13ce2912f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,237 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+// TODO: replace with vector.print <str> once #68695 lands.
+func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) attributes { enable_arm_streaming_ignore } {
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %str_bytes = llvm.getelementptr %str[%c0, %c0]
+    : (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c4 = arith.constant 4 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+    memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0 : vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+//    0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns a dynamic memref. It's the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_f32 = arith.constant 1.0 : f32
+  %c10_f32 = arith.constant 10.0 : f32
+
+  %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+  %init = arith.constant 0.0 : f32
+  scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+    scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+      memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+      %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+      scf.yield %inner_val_next : f32
+    }
+    %val_next = arith.addf %val, %c10_f32 : f32
+    scf.yield %val_next : f32
+  }
+
+  return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
+  // non-zero offsets while remaining inbounds.
+  %vscale = vector.vscale
+  %svl_s = arith.muli %c4, %vscale : index
+  %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+  %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+
+  // 1.a. Read 2D vector from 2D memref.
+  //
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 22, 23
+  // CHECK-NEXT: ( 30, 31, 32, 33
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 1.b. Same as 1.a., but with non-zero offsets.
+  //
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 12, 13, 14, 15
+  // CHECK-NEXT: ( 22, 23, 24, 25
+  // CHECK-NEXT: ( 32, 33, 34, 35
+  // CHECK-NEXT: ( 42, 43, 44, 45
+  call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 2. Same as 1.a., but with mask and a pad of constant zero.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 0
+  // CHECK-NEXT: ( 10, 11, 12, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  call @transfer_read_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 3. Same as 1.a., but with mask and non-zero pad.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, -42
+  // CHECK-NEXT: ( 10, 11, 12, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  call @transfer_read_2d_mask_non_zero_pad(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 4. Same as 1.a., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, 20, 30
+  // CHECK-NEXT: ( 1, 11, 21, 31
+  // CHECK-NEXT: ( 2, 12, 22, 32
+  // CHECK-NEXT: ( 3, 13, 23, 33
+  call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Same as 2., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, 0, 0
+  // CHECK-NEXT: ( 1, 11, 0, 0
+  // CHECK-NEXT: ( 2, 12, 0, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  call @transfer_read_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 6. Same as 3., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, -42, -42
+  // CHECK-NEXT: ( 1, 11, -42, -42
+  // CHECK-NEXT: ( 2, 12, -42, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  call @transfer_read_2d_mask_non_zero_pad_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  memref.dealloc %A : memref<?x?xf32>
+
+  return
+}
+
+llvm.mlir.global internal constant @tile_begin("TILE BEGIN:    \0A\00")

>From f2be769abfc3f432c118d46e6318e06993663fba Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:16:54 +0000
Subject: [PATCH 7/9] [mlir][ArmSME] Add optional mask operand to tile_store

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 37 +++++++++++++++++--
 .../VectorToArmSME/VectorToArmSME.cpp         |  4 +-
 mlir/test/Dialect/ArmSME/invalid.mlir         | 14 +++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       |  9 +++++
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 14 +++++++
 5 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index e35725934315bb2..85cbe22acad6fd3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -322,7 +322,18 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
       "attr-dict `:` type($base) `,` type($result)";
 }
 
-def TileStoreOp : ArmSME_Op<"tile_store"> {
+def TileStoreOp : ArmSME_Op<"tile_store", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as value to store (if present)",
+    "valueToStore", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile store operation";
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -333,6 +344,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     rank 2 with dynamic dimensions, since the operation is scalable, and the
     element type must be a scalar that matches the element type of the result.
 
+    An optional SSA value `mask` may be specified to mask out elements written
+    to the MemRef. The `mask` type is an `i1` vector with the same shape as the
+    tile being stored. Elements whose corresponding mask element is `0` are
+    masked out.
+
     Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
     ```mlir
     arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -347,10 +363,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     ```mlir
     arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
+
+    Example 4: Masked store of a 32-bit float element ZA tile with vertical layout to memory.
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
+    Variadic<Index>:$indices, Optional<AnyVector>:$mask,
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -361,9 +383,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, valueToStore, base, indices, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($valueToStore)";
+    "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 02a5bc64fa52c04..0cc5732c9212d18 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -157,8 +157,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
+        writeOp.getMask());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 60350a888c88441..7a2550b8576d744 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -142,6 +142,20 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64,
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-note at -2 {{prior use here}}
+  // expected-error at +1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>}}
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.load_tile_slice
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index f0704a75ed2fc3a..c0c5c539f3f083d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
 
 // -----
 
+func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 80ca3d3b8281321..f9251edbe658b63 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -323,6 +323,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
 
 // -----
 
+// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xf64>,
+// CHECK-SAME:                                             %[[MASK:.*]]: vector<[2]x[2]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.

>From 5596acc43932ac7960243899bdf9889435ce98d9 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:39:18 +0000
Subject: [PATCH 8/9] [mlir][ArmSME] Add mask operand to store_tile_slice

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  28 +++--
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |   9 +-
 .../Transforms/LegalizeForLLVMExport.cpp      |  28 ++---
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |   3 +-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |  76 ++++++------
 mlir/test/Dialect/ArmSME/invalid.mlir         |  14 +++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 114 +++++++++---------
 7 files changed, 151 insertions(+), 121 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 85cbe22acad6fd3..36fc4b9a3972836 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -462,7 +462,16 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 }
 
-def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
+def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
+  TypesMatchWith<
+    "mask has i1 element type and same shape as tile slice",
+    "tile", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
+    ")">
+]> {
   let summary = "Tile slice store operation";
   let description = [{
     Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -477,22 +486,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the input tile.
 
+    An SSA value `mask` specifies which elements of the tile slice are masked
+    out when writing to the MemRef. The `mask` type is an `i1` vector with the
+    same shape as the 1D tile slice.
+
     Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
     ```
 
     Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
     ```
   }];
-  let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+  let arguments = (ins
+    SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
     Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -506,8 +520,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
-      attr-dict `:` type($base) `,` type($tile)
+    $tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)?
+      attr-dict `:` type($base) `,` type($mask) `,` type($tile)
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 75b7b8acdd190c6..e72064651c5cae3 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -437,6 +437,12 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    // Create an 'all true' predicate for the tile slice.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
     getMemrefIndices(tileStoreOp.getIndices(),
@@ -444,7 +450,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
+        allTruePredicate, tileStoreOp.getBase(), memrefIndices,
+        tileStoreOp.getLayout());
 
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 86f245d82b16c77..bbfe41a34e15057 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -278,13 +278,7 @@ struct StoreTileSliceToArmSMELowering
     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), tileSlice);
 
-    // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto maskOp = storeTileSliceOp.getMask();
 
     Value tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
@@ -295,23 +289,23 @@ struct StoreTileSliceToArmSMELowering
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       }
     } else {
@@ -320,23 +314,23 @@ struct StoreTileSliceToArmSMELowering
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       }
     }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4906812032ae903..55ea56f42c96ed9 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -104,8 +104,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK:           %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 30ddb3c46860187..8fdcf69958244f3 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -209,8 +209,8 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // CHECK-LABEL:   func.func @arm_sme_store_tile_slice_hor_i8(
 // CHECK-SAME:                                               %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                               %[[TILE_SLICE_INDEX:.*]]: index,
+// CHECK-SAME:                                               %[[MASK:.*]]: vector<[16]xi1>,
 // CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>) {
-// CHECK:           %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
@@ -221,12 +221,12 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
 // CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -234,9 +234,9 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -244,9 +244,9 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -254,9 +254,9 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -264,9 +264,9 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
 // CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -274,9 +274,9 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -284,9 +284,9 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -294,9 +294,9 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -304,9 +304,9 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
@@ -314,9 +314,9 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
 // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -324,9 +324,9 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -334,9 +334,9 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -344,9 +344,9 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -354,9 +354,9 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
 // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -364,9 +364,9 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -374,9 +374,9 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -384,9 +384,9 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -394,9 +394,9 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 7a2550b8576d744..c29ae0581d39203 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -168,3 +168,17 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi8>) -> () {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as tile slice}}
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[8]xi1>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index c0c5c539f3f083d..640ca3835e88a9a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -823,173 +823,173 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 /// Layout is optional and horizontal is the default, verify it's still parsed.
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<horizontal> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 

>From 5a336eef187e2448cd0393f76b2a9815fe25dad4 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:57:59 +0000
Subject: [PATCH 9/9] [mlir][ArmSME] Add support for lowering masked tile_store
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_store
ops. Only masks created by 'vector.create_mask' are currently supported.

Example:

  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>,
vector<[4]x[4]xi32>

Produces:

  %num_rows = arith.constant 3 : index
  %num_cols = vector.create_mask %c2 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    arm_sme.store_tile_slice %tile, %slice_idx, %num_cols, %dest[%slice_idx, %c0]
      : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 65 ++++++++++++-------
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           | 25 ++++++-
 2 files changed, 66 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e72064651c5cae3..86d1172ac4957b1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -420,38 +420,59 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     auto tileType = tileStoreOp.getVectorType();
     auto tileElementType = tileType.getElementType();
 
-    // Create a loop that stores each ZA tile slice from memory.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+    Value maskCols;
+    Value upperBound;
+    auto maskOp = tileStoreOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
+                         "currently supported");
+
+      auto numRows = createMaskOp.getOperands()[0];
+      auto numCols = createMaskOp.getOperands()[1];
+
+      upperBound = numRows;
+      maskCols =
+          rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+    } else {
+      // Store all tile slices if no mask.
+      auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+          loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+      auto vscale =
+          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+      // This describes both the number of ZA tile slices and the number of
+      // elements in a vector of SVL bits for a given element type (SVL_B,
+      // SVL_H,
+      // ..., SVL_Q).
+      auto numTileSlices =
+          rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+      upperBound = numTileSlices;
+      // Create an 'all true' predicate for the tile slice.
+      maskCols = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predicateType, true));
+    }
+
+    // Create a loop that stores each (active) ZA tile slice to memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
-    // Create an 'all true' predicate for the tile slice.
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
-
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
+                     upperBound, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        allTruePredicate, tileStoreOp.getBase(), memrefIndices,
-        tileStoreOp.getLayout());
+        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
 
     return success();
   }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 55ea56f42c96ed9..58c6998870edd98 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -102,9 +102,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
-// CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK:           %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
@@ -123,6 +123,27 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor_with_mask(
+// CHECK-SAME:                                             %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.print
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list