[Mlir-commits] [mlir] [mlir][ArmSME] Add masking support to memory ops (PR #69148)

Cullen Rhodes llvmlistbot at llvm.org
Mon Oct 16 08:08:50 PDT 2023


https://github.com/c-rhodes updated https://github.com/llvm/llvm-project/pull/69148

>From 73a31af5309a309a2066dcc36828b43c04b5ac7c Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:09:57 +0000
Subject: [PATCH 1/9] [mlir][ArmSME] Add optional padding and mask operands to
 tile_load

Padding and mask are optional, but if one is specified both must be
specified. This is consistent with vector.transfer_read.
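
For context, a minimal sketch of the new form (a hypothetical example; `%base`, `%c0`, `%pad` and `%mask` are assumed values, not taken from this patch):

```mlir
// Hypothetical masked load: %pad supplies the value for masked-out
// elements, %mask selects which elements are read from memory.
%pad = arith.constant 0 : i8
%mask = arith.constant dense<true> : vector<[16]x[16]xi1>
%tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
```

Omitting both operands (as in the existing examples) remains valid; specifying only one of them is rejected, since the assembly format groups them together.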
---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 50 +++++++++++++++++--
 mlir/test/Dialect/ArmSME/invalid.mlir         | 44 ++++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 10 ++++
 3 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index dab54b63d8d22be..6f6b54aad0058e5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -231,7 +231,24 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
-def TileLoadOp : ArmSME_Op<"tile_load"> {
+def TileLoadOp : ArmSME_Op<"tile_load", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "padding type matches element type of result (if present)",
+    "result", "padding",
+    "::llvm::cast<VectorType>($_self).getElementType()",
+    "!getPadding() || std::equal_to<>()"
+  >,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as result (if present)",
+    "result", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
@@ -242,6 +259,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     dimensions, since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
+    An optional SSA value `padding`, of the same element type as the MemRef,
+    specifies the fallback value for elements that are masked out.
+
+    An optional SSA value `mask` may be specified to mask out elements read
+    from the MemRef. The `mask` type is an `i1` vector with a shape that
+    matches how elements are read from the MemRef. Elements whose corresponding
+    mask element is `0` are masked out and replaced with `padding`.
+
+    If either `padding` or `mask` are specified, both must be specified.
+
     Example 1: Load an 8-bit element ZA tile with horizontal layout (default) from memory (ZA0.B).
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
@@ -256,10 +283,16 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     ```mlir
     %tile = arm_sme.tile_load %base[%c0, %c0] layout<horizontal> : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
+
+    Example 4: Masked load of a 32-bit element ZA tile with horizontal layout (default) from memory.
+    ```mlir
+    %tile = arm_sme.tile_load %base[%c0, %c0], %pad, %mask : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins
     Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
     Variadic<Index>:$indices,
+    Optional<AnyType>:$padding, Optional<AnyVector>:$mask,
     ArmSME_TileSliceLayoutAttr:$layout
   );
   let results = (outs SMETile:$result);
@@ -273,9 +306,20 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices, "TileSliceLayout":$layout), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, layout);
+    }]>,
+    OpBuilder<(ins "VectorType":$resultType, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, resultType, base, indices, {}, {}, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($result)";
+    "$base `[` $indices `]` (`,` $padding `,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 431009b1b9ede2f..9229f0415c076c3 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
@@ -48,6 +52,10 @@ func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[1
   return %0 : vector<[4]x[16]xi8>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
@@ -64,6 +72,10 @@ func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -
   return %0 : i8
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_get_tile_id__bad_type() -> i1 {
@@ -72,6 +84,10 @@ func.func @arm_sme_get_tile_id__bad_type() -> i1 {
   return %0 : i1
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_vector_to_tile_slice_i8__bad_vector_type(%vector : vector<[8]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]x[16]xi8> {
@@ -90,6 +106,10 @@ func.func @arm_sme_move_vector_to_tile_slice_f32__bad_vector_type(%vector : vect
   return %0 : vector<[4]x[4]xf32>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_tile_slice_to_vector
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[2]xf64> {
@@ -97,3 +117,27 @@ func.func @arm_sme_move_tile_slice_to_vector__bad_result_type(%tile : vector<[4]
   %0 = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[4]x[4]xf32>
   return %0 : vector<[2]xf64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load__bad_padding_type(%src : memref<?x?xf64>, %pad : f32, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%pad' expects different type than prior uses: 'f64' vs 'f32'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  // expected-note@-2 {{prior use here}}
+  // expected-error@+1 {{use of value '%mask' expects different type than prior uses: 'vector<[2]x[2]xi1>' vs 'vector<[4]x[4]xi1>'}}
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 427154158e797fd..f6459f085843655 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -438,6 +438,16 @@ func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
 
 // -----
 
+/// Padding and mask are optional
+func.func @arm_sme_tile_load_hor_pad_f64(%src : memref<?x?xf64>, %pad : f64, %mask : vector<[2]x[2]xi1>) {
+  // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_load_explicit_hor(%src : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>

>From a8db19de0693461ba51581364d7380ae5fe3e59d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:28:53 +0000
Subject: [PATCH 2/9] [mlir][ArmSME] Add mask operand to load_tile_slice
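
The mask covers a single tile slice: its type is an `i1` vector whose shape
is the result type with the leading (tile slice) dimension dropped, as
encoded by the `TypesMatchWith` trait below. The LLVM lowering now forwards
this mask to the `ld1*` intrinsics instead of materialising an all-true
predicate. A minimal sketch of the new form (value names assumed):

```mlir
// For a vector<[4]x[4]xf32> tile, a slice holds [4] elements, so the
// slice mask is vector<[4]xi1>.
%tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index
  : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
```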

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  27 +++--
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  18 ++-
 .../Transforms/LegalizeForLLVMExport.cpp      |  37 +++---
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  15 ++-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |  76 ++++++------
 mlir/test/Dialect/ArmSME/invalid.mlir         |  13 ++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 114 +++++++++---------
 7 files changed, 173 insertions(+), 127 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 6f6b54aad0058e5..8a05ed89799d564 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -367,7 +367,15 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
-    AllTypesMatch<["tile", "result"]>
+    AllTypesMatch<["tile", "result"]>,
+    TypesMatchWith<
+      "mask has i1 element type and same shape as result",
+      "result", "mask",
+      "VectorType("
+        "VectorType::Builder("
+          "::llvm::cast<mlir::VectorType>($_self)"
+        ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
+      ")">,
 ]> {
   let summary = "Tile slice load and update operation";
   let description = [{
@@ -383,23 +391,27 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
+    An SSA value `mask` is used to mask out elements read from the MemRef.
+    The `mask` type is an `i1` vector with a shape that matches how elements
+    are read from the MemRef.
+
     Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally (default) at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
     ```
 
     Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice %base[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from">:$base,
+    Arg<AnyMemRef, "the reference to load from">:$base, AnyVector:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -415,8 +427,9 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
-      attr-dict `:` type($base) `,` type($result)
+    $base `[` $indices `]` `,` $mask `,` $tile `,` $tile_slice_index
+      (`layout` `` $layout^)? attr-dict `:` type($base) `,` type($mask) `,`
+                                            type($result)
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 0ec51b7430c0213..9cfb13216d9bfe7 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -60,6 +60,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 ///  AFTER:
 ///  ```mlir
+///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
 ///  %tile_id = arm_sme.get_tile_id : i32
 ///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
@@ -69,7 +70,8 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
 ///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
-///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
+///      %ptrue_s, %tile, %tile_slice_idx
+///        : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
@@ -77,6 +79,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
   LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
                                 PatternRewriter &rewriter) const override {
+    if (tileLoadOp.getMask())
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has mask, needs masked pattern(s)");
+
     OpBuilder::InsertionGuard g(rewriter);
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
@@ -109,6 +115,12 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    // Create an 'all true' predicate for the tile slice.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
     // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
     // tile.
     SmallVector<Value> memrefIndices;
@@ -117,8 +129,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), tile, memrefIndices,
-        tileSliceIndex, tileLoadOp.getLayout());
+        loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 5e13707ea0aa2b9..220e0bdd7097978 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -179,12 +179,7 @@ struct LoadTileSliceToArmSMELowering
         loc, rewriter.getI32Type(), tileSlice);
 
     // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto maskOp = loadTileSliceOp.getMask();
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
@@ -195,24 +190,24 @@ struct LoadTileSliceToArmSMELowering
       default:
         llvm_unreachable("unexpected element type!");
       case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
-            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
+                                                         tileI32, tileSliceI32);
         break;
       }
     } else {
@@ -220,23 +215,23 @@ struct LoadTileSliceToArmSMELowering
       default:
         llvm_unreachable("unexpected element type!");
       case 8:
-        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 16:
-        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 32:
-        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 64:
-        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       case 128:
-        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
+        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
                                                         tileI32, tileSliceI32);
         break;
       }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4b3020970d6ccc1..3fb320c0d219e60 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
@@ -10,8 +14,9 @@
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -28,6 +33,10 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
 // -----
 
 // CHECK-LABEL: func.func @arm_sme_tile_store_hor(
@@ -57,6 +66,10 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
   return
 }
 
+//===----------------------------------------------------------------------===//
+// vector.print
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 07485b3ee8ddf86..4fb4ca2f102ee74 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -8,9 +8,9 @@
 
 // CHECK-LABEL:   func.func @arm_sme_load_tile_slice_hor_i8(
 // CHECK-SAME:                                              %[[SRC:.*]]: memref<?x?xi8>,
+// CHECK-SAME:                                              %[[MASK:.*]]: vector<[16]xi1>,
 // CHECK-SAME:                                              %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index) {
-// CHECK:           %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
@@ -21,12 +21,12 @@
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
 // CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -34,9 +34,9 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -44,9 +44,9 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -54,9 +54,9 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -64,9 +64,9 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
 // CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -74,9 +74,9 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -84,9 +84,9 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
 // CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -94,9 +94,9 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
 // CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -104,9 +104,9 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
 // CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
@@ -114,9 +114,9 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
 // CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -124,9 +124,9 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -134,9 +134,9 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -144,9 +144,9 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -154,9 +154,9 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
 // CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -164,9 +164,9 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -174,9 +174,9 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
 // CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -184,9 +184,9 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vec
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
 // CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -194,9 +194,9 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vecto
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
 // CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 9229f0415c076c3..60350a888c88441 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -141,3 +141,16 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64,
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask : vector<[2]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  // expected-error@+1 {{op failed to verify that mask has i1 element type and same shape as result}}
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index f6459f085843655..93b103fb83ac4d3 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -638,173 +638,173 @@ func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memre
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 /// Layout is optional and horizontal is the default, verify it's still parsed.
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice %{{.*}}[{{.*}}], %{{.*}}, %{{.*}} : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<horizontal> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 

>From 05686112de89018c982c7e5f6879a2361c2fd562 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 08:47:40 +0000
Subject: [PATCH 3/9] [mlir][ArmSME] Propagate pad and mask in
 vector.transfer_read lowering

This extends the vector.transfer_read -> arm_sme.tile_load lowering to
propagate pad and mask.

The restriction on the transfer_read being a transposition is also
removed; identity maps are lowered to normal horizontal loads.
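
For example (an illustrative sketch based on the tests below), an
identity-map read with a mask:

  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask
    {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>

now lowers to a horizontal tile load that carries the pad and mask:

  %0 = arm_sme.tile_load %src[%c0, %c0], %pad, %mask
    : memref<?x?xi16>, vector<[8]x[8]xi16>

A transpose permutation map produces the same op with layout<vertical>.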
---
 .../VectorToArmSME/VectorToArmSME.cpp         |  57 ++++---
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 140 +++++++++---------
 2 files changed, 109 insertions(+), 88 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index d06eb4f5b01c950..02a5bc64fa52c04 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -60,15 +60,30 @@ getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
 
 namespace {
 
-/// Conversion pattern for vector.transfer_read op with transpose permutation
-/// map to vertical arm_sme.tile_load (in-flight transpose).
+/// Conversion pattern for vector.transfer_read.
+///
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_load:
+///
+///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d0, d1)
+///
+/// is converted to:
+///
+///   arm_sme.tile_load ...
+///
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_load
+///            (in-flight transpose):
 ///
 ///   vector.transfer_read ...  permutation_map: (d0, d1) -> (d1, d0)
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_load ... layout<vertical>
-struct TransferReadPermutationToArmSMELowering
+struct TransferReadToArmSMELowering
     : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -79,15 +94,6 @@ struct TransferReadPermutationToArmSMELowering
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not a 2 result permutation map");
 
-    AffineMap map = transferReadOp.getPermutationMap();
-
-    // Permutation map doesn't perform permutation, can be lowered to
-    // vector.load by TransferReadToVectorLoadLowering and then
-    // arm_sme.tile_load by VectorLoadToArmSMELowering.
-    if (map.isIdentity())
-      return rewriter.notifyMatchFailure(
-          transferReadOp, "map is an identity, apply another pattern");
-
     auto vectorType = transferReadOp.getVectorType();
     if (!arm_sme::isValidSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(transferReadOp,
@@ -96,26 +102,33 @@ struct TransferReadPermutationToArmSMELowering
     if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
 
-    if (transferReadOp.getMask())
-      // TODO: support masking.
-      return rewriter.notifyMatchFailure(transferReadOp,
-                                         "masking not yet supported");
-
     // Out-of-bounds dims are not supported.
     if (transferReadOp.hasOutOfBoundsDim())
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not inbounds transfer read");
 
+    arm_sme::TileSliceLayout layout;
+
     AffineExpr d0, d1;
     bindDims(transferReadOp.getContext(), d0, d1);
-    if (map != AffineMap::get(map.getNumDims(), 0, {d1, d0},
-                              transferReadOp.getContext()))
+    AffineMap map = transferReadOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   transferReadOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
       return rewriter.notifyMatchFailure(transferReadOp,
-                                         "not true 2-D matrix transpose");
+                                         "unsupported permutation map");
 
+    // Padding isn't optional for transfer_read, but it is only used for
+    // out-of-bounds accesses (not supported here) and masking. The mask is
+    // optional; if it's absent, don't pass the padding on.
+    auto mask = transferReadOp.getMask();
+    auto padding = mask ? transferReadOp.getPadding() : nullptr;
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
         transferReadOp, vectorType, transferReadOp.getSource(),
-        transferReadOp.getIndices(), arm_sme::TileSliceLayout::Vertical);
+        transferReadOp.getIndices(), padding, mask, layout);
 
     return success();
   }
@@ -432,7 +445,7 @@ struct TransposeOpToArmSMELowering
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-               SplatOpToArmSMELowering, TransferReadPermutationToArmSMELowering,
+               SplatOpToArmSMELowering, TransferReadToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
                VectorLoadToArmSMELowering, VectorStoreToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 455b47a83e28f43..80ca3d3b8281321 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,181 +1,189 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// vector.transfer_read (with in-flight transpose)
+// vector.transfer_read
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i8
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
-func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
+// CHECK-LABEL: @transfer_read_2d_i8
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i8
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
   "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
-func.func @transfer_read_2d_transpose_i16(%src : memref<?x?xi16>) {
+// CHECK-LABEL: @transfer_read_2d_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_i16(%src : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @transfer_read_2d_transpose_i32(%src : memref<?x?xi32>) {
+// CHECK-LABEL: @transfer_read_2d_i32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @transfer_read_2d_i32(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi32>, vector<[4]x[4]xi32>
   "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
-func.func @transfer_read_2d_transpose_i64(%src : memref<?x?xi64>) {
+// CHECK-LABEL: @transfer_read_2d_i64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_read_2d_i64(%src : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi64>, vector<[2]x[2]xi64>
   "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_i128
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
-func.func @transfer_read_2d_transpose_i128(%src : memref<?x?xi128>) {
+// CHECK-LABEL: @transfer_read_2d_i128
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @transfer_read_2d_i128(%src : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0 : i128
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xi128>, vector<[1]x[1]xi128>
   "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
-func.func @transfer_read_2d_transpose_f16(%src : memref<?x?xf16>) {
+// CHECK-LABEL: @transfer_read_2d_f16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @transfer_read_2d_f16(%src : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf16>, vector<[8]x[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_bf16
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
-func.func @transfer_read_2d_transpose_bf16(%src : memref<?x?xbf16>) {
+// CHECK-LABEL: @transfer_read_2d_bf16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_read_2d_bf16(%src : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : bf16
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f32
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
-func.func @transfer_read_2d_transpose_f32(%src : memref<?x?xf32>) {
+// CHECK-LABEL: @transfer_read_2d_f32
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_f32(%src : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d_transpose_f64
-// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
-func.func @transfer_read_2d_transpose_f64(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_f64
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_read_2d_f64(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
+// CHECK-LABEL: @transfer_read_2d_with_mask_i16
+// CHECK: arm_sme.tile_load %{{.*}}[{{.*}}], {{.*}}, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @transfer_read_2d_with_mask_i16(%src : memref<?x?xi16>, %mask : vector<[8]x[8]xi1>) {
   %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
-  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
+  %pad = arith.constant 0 : i16
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__non_memref_type
-// CHECK-NOT: arm_sme.tile_load
-// CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
+/// in-flight transpose
+
+// CHECK-LABEL: @transfer_read_2d_transpose_i8
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transfer_read_2d_transpose_i8(%src : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
-  %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  %pad = arith.constant 0 : i8
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
+// CHECK-LABEL: @transfer_read_2d_transpose_with_mask_f32
+// CHECK: arm_sme.tile_load {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @transfer_read_2d_transpose_with_mask_f32(%src : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_read_2d__bad_type
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_type(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [false, false]} : memref<?x?xf64>, vector<[4]x[4]xf64>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf64>) -> ()
   return
 }
 
 // -----
 
-// CHECK-LABEL: @transfer_read_2d__unsupported_mask
+// CHECK-LABEL: @transfer_read_2d__non_memref_type
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__unsupported_mask(%src : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+func.func @transfer_read_2d__non_memref_type(%src : tensor<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad, %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : tensor<?x?xf64>, vector<[2]x[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
 // -----
 
-/// transfer_read with identity map should be lowered to vector.load by
-/// TransferReadToVectorLoadLowering and then arm_sme.tile_load by
-/// VectorLoadToArmSMELowering.
-
-// CHECK-LABEL: @transfer_read_2d__non_permuting_map
+// CHECK-LABEL: @transfer_read_2d__bad_transfer_rank
 // CHECK-NOT: arm_sme.tile_load
 // CHECK: vector.transfer_read
-func.func @transfer_read_2d__non_permuting_map(%src : memref<?x?xf64>) {
+func.func @transfer_read_2d__bad_transfer_rank(%src : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f64
-  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0, d1)>, in_bounds = [true, true]} : memref<?x?xf64>, vector<[2]x[2]xf64>
-  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  %0 = vector.transfer_read %src[%c0, %c0], %pad {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} : memref<?x?xf64>, vector<[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]xf64>) -> ()
   return
 }
 

>From 554055a54323f56a482d037b387dcc582abea0b1 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 10:19:34 +0000
Subject: [PATCH 4/9] [mlir][ArmSME] Add tile slice layout attr to vector <->
 tile ops
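
The layout attribute defaults to horizontal; a vertical slice is written
or read with 'layout<vertical>', e.g. (taken from the examples added
below):

  %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile,
    %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index]
    layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>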

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 36 ++++++++++------
 .../Transforms/LegalizeForLLVMExport.cpp      | 43 +++++++++++++------
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 32 ++++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 16 +++++++
 4 files changed, 100 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a05ed89799d564..e35725934315bb2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -498,21 +498,24 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
     of a 2-D scalable vector tile at the given index. The type of the 1-D
     scalable vector to be moved must match the type of the tile slice. A tile
     slice is a 1-D vector of horizontally or vertically contiguous elements
-    within a ZA tile. Horizontal tile slices are currently assumed when
-    lowering to intrinsics. The updated tile is returned as the result.
+    within a ZA tile. The updated tile is returned as the result.
 
-    Example 1: Move a vector<[16]xi8> into tile at given index.
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
+
+    Example 1: Move a vector<[16]xi8> into tile horizontally (default) at given index.
     ```mlir
     %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[16]xi8> into vector<[16]x[16]xi8>
     ```
 
-    Example 2: Move a vector<[2]xf64> into tile at given index.
+    Example 2: Move a vector<[2]xf64> into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[2]xf64> into vector<[2]x[2]xf64>
+    %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[2]xf64> into vector<[2]x[2]xf64>
     ```
   }];
   let arguments = (ins
-      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index);
+      SVEVector:$vector, SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout);
   let results = (outs SMETile:$result);
 
   let extraClassDeclaration = [{
@@ -522,7 +525,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $vector `,` $tile `,` $tile_slice_index
+    $vector `,` $tile `,` $tile_slice_index (`layout` `` $layout^)?
       attr-dict `:` type($vector) `into` type($result)
   }];
 }
@@ -537,21 +540,26 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   let description = [{
     The tile slice to vector operation extracts a 1-D scalable slice from a 2-D
     scalable tile at the given index. A tile slice is a 1-D vector of
-    horizontally or vertically contiguous elements within a ZA tile. Horizontal
-    tile slices are currently assumed when lowering to intrinsics.
+    horizontally or vertically contiguous elements within a ZA tile.
+
+    An optional tile slice layout attribute specifies whether the tile slice is
+    horizontal (default) or vertical.
 
-    Example 1: Extract `vector<[16]xi8>` from tile at the given index.
+    Example 1: Extract `vector<[16]xi8>` from tile horizontally at the given index.
     ```mlir
     %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
     ```
 
-    Example 2: Extract `vector<[2]xf64>` from tile at the given index.
+    Example 2: Extract `vector<[2]xf64>` from tile vertically at the given index.
     ```mlir
-    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
+    %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
     ```
   }];
 
-  let arguments = (ins SMETile:$tile, Index:$tile_slice_index);
+  let arguments = (ins
+    SMETile:$tile, Index:$tile_slice_index,
+    ArmSME_TileSliceLayoutAttr:$layout
+  );
   let results = (outs SVEVector:$result);
 
   let extraClassDeclaration = [{
@@ -559,7 +567,7 @@ def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
   }];
 
   let assemblyFormat = [{
-      $tile `[` $tile_slice_index `]` attr-dict
+      $tile `[` $tile_slice_index `]` (`layout` `` $layout^)? attr-dict
       `:` type($result) `from` type($tile)
   }];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 220e0bdd7097978..86f245d82b16c77 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -345,8 +345,7 @@ struct StoreTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics.
 struct MoveVectorToTileSliceToArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveVectorToTileSliceOp> {
   using ConvertOpToLLVMPattern<
@@ -383,10 +382,19 @@ struct MoveVectorToTileSliceToArmSMELowering
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
 
-    // Create 'arm_sme.intr.write.horiz' to write vector to tile slice.
-    rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-        loc, tileI32, tileSliceI32, allActiveMask,
-        moveVectorToTileSliceOp.getVector());
+    // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
+    switch (moveVectorToTileSliceOp.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.create<arm_sme::aarch64_sme_write_horiz>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.create<arm_sme::aarch64_sme_write_vert>(
+          loc, tileI32, tileSliceI32, allActiveMask,
+          moveVectorToTileSliceOp.getVector());
+      break;
+    }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
     // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
@@ -397,8 +405,7 @@ struct MoveVectorToTileSliceToArmSMELowering
   }
 };
 
-/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. Only horizontal
-/// tile slices are currently supported.
+/// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics.
 struct MoveTileSliceToVectorArmSMELowering
     : public ConvertOpToLLVMPattern<arm_sme::MoveTileSliceToVectorOp> {
   using ConvertOpToLLVMPattern<
@@ -430,10 +437,19 @@ struct MoveTileSliceToVectorArmSMELowering
     auto sliceIndexI32 = rewriter.create<arith::IndexCastOp>(
         loc, rewriter.getI32Type(), sliceIndex);
 
-    // Create 'arm_sme.intr.read.horiz' to extract the tile slice.
-    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
-        moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-        tileIdI32, sliceIndexI32);
+    // Create 'arm_sme.intr.read.(horiz|vert)' to extract the tile slice.
+    switch (moveTileSliceToVector.getLayout()) {
+    case arm_sme::TileSliceLayout::Horizontal:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    case arm_sme::TileSliceLayout::Vertical:
+      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
+          moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
+          tileIdI32, sliceIndexI32);
+      break;
+    }
 
     return success();
   }
@@ -675,7 +691,8 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
       arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
       arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_mopa,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
       arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
   target.addIllegalOp<vector::OuterProductOp>();
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 4fb4ca2f102ee74..30ddb3c46860187 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -400,6 +400,29 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -485,3 +508,12 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
+  return %slice : vector<[1]xi128>
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 93b103fb83ac4d3..f0704a75ed2fc3a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1069,6 +1069,14 @@ func.func @arm_sme_move_vector_to_tile_slice_f64(%vector : vector<[2]xf64>, %til
   return
 }
 
+// -----
+
+func.func @arm_sme_move_vector_to_tile_slice_ver_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[16]xi8> into vector<[16]x[16]xi8>
+  return
+}
 
 //===----------------------------------------------------------------------===//
 // arm_sme.move_tile_slice_to_vector
@@ -1145,3 +1153,11 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
+
+// -----
+
+func.func @arm_sme_move_tile_slice_to_vector_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+  // CHECK: arm_sme.move_tile_slice_to_vector {{.*}} layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[2]xf64> from vector<[2]x[2]xf64>
+  return %slice : vector<[2]xf64>
+}

>From 4badd45c9d0152bb502b8b99b79dc4ef2d095f6d Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 09:15:52 +0000
Subject: [PATCH 5/9] [mlir][ArmSME] Add support for lowering masked tile_load
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_load
ops. Only masks created by 'vector.create_mask' are currently supported.

There are two lowerings, one for pad of constant zero and another for
non-zero pad. For the following example:

  %pad = arith.constant 0 : i32
  %num_rows = arith.constant 2 : index
  %num_cols = arith.constant 4 : index
  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>,
                                                          vector<[4]x[4]xi32>

The former (constant zero pad) is lowered as follows:
-----------------------------------------------------

  %tile = arm_sme.zero : vector<[4]x[4]xi32>
  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    %tile_update = arm_sme.load_tile_slice
      %src[%slice_idx], %num_cols, %tile, %tile_slice_idx :
      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>

The tile is zeroed to satisfy the padding and only active rows are
loaded.

The latter (non-zero pad) is lowered as follows:
------------------------------------------------

  scf.for %slice_idx = %c0 to %num_tile_slices step %c1 {
    %row_is_active = arith.cmpi ult %slice_idx, %num_rows : index
    %slice = scf.if %row_is_active -> vector<[4]xi32> {
      %slice = vector.maskedload %src[%slice_idx, %c0], %num_cols, %pad_1d :
        memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
      scf.yield %slice : vector<[4]xi32>
    } else {
      scf.yield %pad_1d : vector<[4]xi32>
    }
    arm_sme.move_vector_to_tile_slice %slice, %tile, %slice_idx
      : vector<[4]xi32> into vector<[4]x[4]xi32>
  }

The scalar pad is broadcast to a 1-D vector and a regular
'vector.maskedload' (which lowers to SVE, not SME) loads each slice
for active rows, with the padding specified as a passthru. For inactive
rows the slice is the 1-D pad. The resulting slice is inserted into the
tile with 'arm_sme.move_vector_to_tile_slice'.
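
For reference, the 2-D mask activates the top-left %num_rows x %num_cols
rectangle of the tile, so with a sketch like (values illustrative, as in
the integration test added below):

  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>

only rows 0-1 and columns 0-2 are loaded; all remaining elements take
the pad value.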
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 251 +++++++++++++++++-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  56 ++++
 .../CPU/ArmSME/test-transfer-read-2d.mlir     | 237 +++++++++++++++++
 3 files changed, 543 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 9cfb13216d9bfe7..75b7b8acdd190c6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -141,6 +141,254 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
   }
 };
 
+/// Lower `arm_sme.tile_load` with mask and pad of constant zero.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 0 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %tile = arm_sme.zero : vector<[4]x[4]xi32>
+///  %num_cols = vector.create_mask %c4 : vector<[4]xi1>
+///  scf.for %tile_slice_idx = %c0 to %num_rows step %c1 {
+///    %tile_update = arm_sme.load_tile_slice
+///      %src[%tile_slice_idx], %num_cols, %tile, %tile_slice_idx :
+///      memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+///  }
+///  ```
+///
+/// NOTE: Only masks produced by 'vector.create_mask' are currently supported.
+struct TileLoadOpWithMaskAndPadZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (!constPadOp || constPadOp.getValue() !=
+                           rewriter.getZeroAttr(tileType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has non-zero pad, needs non-zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Initialize the tile with zero to satisfy the padding. Inactive cols will
+    // be zeroed anyway since the loads use zeroing predication. For inactive
+    // rows, however, no load occurs, so these must be zeroed explicitly.
+    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+
+    // Create a loop to load the active tile slices from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = numRows;
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.load_tile_slice' to load tile slice from memory into
+    // tile.
+    SmallVector<Value> memrefIndices;
+    auto tileSliceIndex = forOp.getInductionVar();
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     upperBound, memrefIndices, loc, rewriter);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
+        tileSliceIndex, tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
+/// Lower `arm_sme.tile_load` with mask and non-zero pad.
+///
+///  BEFORE:
+///  ```mlir
+///  %pad = arith.constant 1 : i32
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %mask = vector.create_mask %num_rows, %num_cols : vector<[4]x[4]xi1>
+///  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask :
+///    memref<?x?xi32>, vector<[4]x[4]xi32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %pad_1d = arith.constant dense<1> : vector<[4]xi32>
+///  %num_rows = arith.constant 2 : index
+///  %num_cols = arith.constant 4 : index
+///  %tile_id = arm_sme.get_tile_id : i32
+///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %vscale = vector.vscale
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %min_svl_s = arith.constant 4 : index
+///  %svl_s = arith.muli %min_svl_s, %vscale : index
+///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
+///    %row_is_active = arith.cmpi ult %tile_slice_idx, %num_rows : index
+///    %slice = scf.if %row_is_active -> vector<[4]xi32> {
+///      %slice = vector.maskedload %base[%tile_slice_idx, %c0], %num_cols, %pad_1d
+///        : memref<?x?xi32>, vector<[4]xi1>,
+///          vector<[4]xi32> into vector<[4]xi32>
+///      scf.yield %slice : vector<[4]xi32>
+///    } else {
+///      scf.yield %pad_1d : vector<[4]xi32>
+///    }
+///    // Insert slice into tile
+///    arm_sme.move_vector_to_tile_slice %slice, %tile, %tile_slice_idx
+///      : vector<[4]xi32> into vector<[4]x[4]xi32>
+///  }
+///  ```
+struct TileLoadOpWithMaskAndPadNonZeroConversion
+    : public OpRewritePattern<arm_sme::TileLoadOp> {
+  using OpRewritePattern<arm_sme::TileLoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arm_sme::TileLoadOp tileLoadOp,
+                                PatternRewriter &rewriter) const override {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = tileLoadOp.getLoc();
+    auto tileType = tileLoadOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
+
+    auto maskOp = tileLoadOp.getMask();
+    if (!maskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has no mask, needs unmasked pattern");
+
+    auto padOp = tileLoadOp.getPadding();
+    assert(padOp && "expected padding when masking!");
+
+    auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "unsupported mask op, only 'vector.create_mask' is "
+                      "currently supported");
+
+    auto constPadOp = padOp.getDefiningOp<arith::ConstantOp>();
+    if (constPadOp &&
+        constPadOp.getValue() == rewriter.getZeroAttr(tileElementType))
+      return rewriter.notifyMatchFailure(
+          tileLoadOp, "op has constant zero pad, needs zero pad pattern");
+
+    auto numRows = createMaskOp.getOperands()[0];
+    auto numCols = createMaskOp.getOperands()[1];
+
+    VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto numColsOp =
+        rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+
+    // Create 'arm_sme.get_tile_id' op.
+    auto tileId = rewriter.create<arm_sme::GetTileID>(
+        loc, rewriter.getIntegerType(tileElementWidth));
+
+    // Create 'arm_sme.cast_tile_to_vector' to cast the tile ID to a vector
+    // type to use as the input tile to 'arm_sme.move_vector_to_tile_slice' ops.
+    auto tile =
+        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+
+    // Create a loop that loads each ZA tile slice from memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+    auto forOp =
+        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    auto tileSliceIndex = forOp.getInductionVar();
+
+    auto rowIsActive = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, tileSliceIndex, numRows);
+
+    SmallVector<Value> memrefIndices;
+    getMemrefIndices(tileLoadOp.getIndices(),
+                     tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
+                     numTileSlices, memrefIndices, loc, rewriter);
+
+    // Splat pad into 1-D vector matching type of tile slice.
+    auto pad1DOp = rewriter.create<vector::SplatOp>(loc, tileSliceType, padOp);
+
+    Operation *slice = rewriter.create<scf::IfOp>(
+        loc, rowIsActive,
+        [&](OpBuilder &b, Location loc) {
+          // If the row is active, emit a masked load where the predicate is
+          // 'numCols'. The pad value fills inactive elements via the
+          // passthru operand.
+          auto loadSlice = rewriter.create<vector::MaskedLoadOp>(
+              loc, tileSliceType, tileLoadOp.getBase(), memrefIndices,
+              numColsOp, /*passthru=*/pad1DOp);
+          rewriter.create<scf::YieldOp>(loc, loadSlice->getResult(0));
+        },
+        [&](OpBuilder &b, Location loc) {
+          // Inactive rows are filled with pad.
+          rewriter.create<scf::YieldOp>(loc, pad1DOp.getResult());
+        });
+
+    // TODO: If the load is vertical, the transpose can't be done in-flight
+    // with a regular (SVE) maskedload. Propagate the layout to
+    // 'arm_sme.move_vector_to_tile_slice' below once it supports layout; this
+    // case is currently broken.
+
+    // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
+    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, slice->getResult(0), tile, tileSliceIndex,
+        tileLoadOp.getLayout());
+
+    rewriter.setInsertionPointAfter(forOp);
+
+    // Replace 'arm_sme.tile_load' with the tile.
+    rewriter.replaceOp(tileLoadOp, tile);
+
+    return success();
+  }
+};
+
 /// Lower `arm_sme.tile_store` to a loop over the tile slices and store each
 /// slice using `arm_sme.store_tile_slice`.
 ///
@@ -265,7 +513,8 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
                TileVectorPrintOpConversion>(patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 3fb320c0d219e60..4906812032ae903 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -33,6 +33,62 @@ func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(
+// CHECK-SAME:                                                          %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[TILEZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[TILEZERO]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0 : i32
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
+// CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
+// CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
+// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
+// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[ROW_IS_ACTIVE:.*]] = arith.cmpi ult, %[[TILE_SLICE_INDEX]], %[[NUM_ROWS]] : index
+// CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
+// CHECK:             %[[SLICE:.*]] = scf.if %[[ROW_IS_ACTIVE]] -> (vector<[4]xi32>) {
+// CHECK:               %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[NUM_COLS]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
+// CHECK:               scf.yield %[[LOAD_SLICE]] : vector<[4]xi32>
+// CHECK:             } else {
+// CHECK:               scf.yield %[[PAD_1D]] : vector<[4]xi32>
+// CHECK:             }
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.tile_store
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
new file mode 100644
index 000000000000000..fe40dd13ce2912f
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -0,0 +1,237 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+// TODO: replace with vector.print <str> once #68695 lands.
+func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) attributes { enable_arm_streaming_ignore } {
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %str_bytes = llvm.getelementptr %str[%c0, %c0]
+    : (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// Vector load.
+func.func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c4 = arith.constant 4 : index
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad {in_bounds=[true, true]} :
+    memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load + transpose.
+func.func @transfer_read_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %pad = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %A[%base1, %base2], %pad
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0 : vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero.
+func.func @transfer_read_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and pad of zero + transpose.
+func.func @transfer_read_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant 0.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad.
+func.func @transfer_read_2d_mask_non_zero_pad(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Vector load with mask and non-zero pad + transpose.
+func.func @transfer_read_2d_mask_non_zero_pad_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %pad = arith.constant -42.0 : f32
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %0 = vector.transfer_read %A[%base1, %base2], %pad, %mask
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]}
+      : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+//    0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns a dynamic memref. It's the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_f32 = arith.constant 1.0 : f32
+  %c10_f32 = arith.constant 10.0 : f32
+
+  %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+  %init = arith.constant 0.0 : f32
+  scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+    scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+      memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+      %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+      scf.yield %inner_val_next : f32
+    }
+    %val_next = arith.addf %val, %c10_f32 : f32
+    scf.yield %val_next : f32
+  }
+
+  return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // Allocate enough memory to load a 32-bit tile plus a tiny bit more to test
+  // non-zero offsets while remaining in bounds.
+  %vscale = vector.vscale
+  %svl_s = arith.muli %c4, %vscale : index
+  %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+  %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+
+  // 1.a. Read 2D vector from 2D memref.
+  //
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 22, 23
+  // CHECK-NEXT: ( 30, 31, 32, 33
+  call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 1.b. Same as 1.a., but with non-zero offsets.
+  //
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 12, 13, 14, 15
+  // CHECK-NEXT: ( 22, 23, 24, 25
+  // CHECK-NEXT: ( 32, 33, 34, 35
+  // CHECK-NEXT: ( 42, 43, 44, 45
+  call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+
+  // 2. Same as 1.a., but with mask and a pad of constant zero.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 0
+  // CHECK-NEXT: ( 10, 11, 12, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  call @transfer_read_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 3. Same as 1.a., but with mask and non-zero pad.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, -42
+  // CHECK-NEXT: ( 10, 11, 12, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  call @transfer_read_2d_mask_non_zero_pad(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 4. Same as 1.a., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, 20, 30
+  // CHECK-NEXT: ( 1, 11, 21, 31
+  // CHECK-NEXT: ( 2, 12, 22, 32
+  // CHECK-NEXT: ( 3, 13, 23, 33
+  call @transfer_read_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Same as 2., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, 0, 0
+  // CHECK-NEXT: ( 1, 11, 0, 0
+  // CHECK-NEXT: ( 2, 12, 0, 0
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  call @transfer_read_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 6. Same as 3., but transpose the result.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 10, -42, -42
+  // CHECK-NEXT: ( 1, 11, -42, -42
+  // CHECK-NEXT: ( 2, 12, -42, -42
+  // CHECK-NEXT: ( -42, -42, -42, -42
+  call @transfer_read_2d_mask_non_zero_pad_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  memref.dealloc %A : memref<?x?xf32>
+
+  return
+}
+
+llvm.mlir.global internal constant @tile_begin("TILE BEGIN:    \0A\00")
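As a minimal sketch of the mask and pad semantics exercised above (assuming `%A`, `%c0`, `%c2` and `%c3` are bound as in the tests): `vector.create_mask %c2, %c3` produces a predicate that is true only for the first two rows and first three columns, so the masked read keeps that top-left region and fills every other lane with the pad value.

```mlir
// Sketch only, not part of the test file above.
%pad = arith.constant -42.0 : f32
%mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
%0 = vector.transfer_read %A[%c0, %c0], %pad, %mask
  {in_bounds = [true, true]} : memref<?x?xf32>, vector<[4]x[4]xf32>
// Rows 0-1 keep their first three elements; all other lanes read -42.
```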

>From 0bba1701313b116693a7243826169ff6670c4c31 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:16:54 +0000
Subject: [PATCH 6/9] [mlir][ArmSME] Add optional mask operand to tile_store

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 37 +++++++++++++++++--
 .../VectorToArmSME/VectorToArmSME.cpp         |  4 +-
 mlir/test/Dialect/ArmSME/invalid.mlir         | 14 +++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       |  9 +++++
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 14 +++++++
 5 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index e35725934315bb2..85cbe22acad6fd3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -322,7 +322,18 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
       "attr-dict `:` type($base) `,` type($result)";
 }
 
-def TileStoreOp : ArmSME_Op<"tile_store"> {
+def TileStoreOp : ArmSME_Op<"tile_store", [
+  AttrSizedOperandSegments,
+  TypesMatchWith<
+    "mask has i1 element type and same shape as value to store (if present)",
+    "valueToStore", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").setElementType(IntegerType::get($_self.getContext(), 1)))",
+    "!getMask() || std::equal_to<>()"
+  >
+]> {
   let summary = "Tile store operation";
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
@@ -333,6 +344,11 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     rank 2 with dynamic dimensions, since the operation is scalable, and the
     element type must be a scalar that matches the element type of the result.
 
+    An optional SSA value `mask` may be specified to mask out elements written
+    to the MemRef. The `mask` type is an `i1` vector with the same shape as the
+    vector being stored, matching how elements are written to the MemRef.
+    Elements whose corresponding mask element is `0` are not written.
+
     Example 1: Store an 8-bit element ZA tile with horizontal (default) layout to memory (ZA0.B).
     ```mlir
     arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -347,10 +363,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     ```mlir
     arm_sme.tile_store %tile, %base[%c0, %c0] layout<horizontal> : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
+
+    Example 4: Masked store of a 32-bit element ZA tile with vertical layout to memory.
+    ```mlir
+    arm_sme.tile_store %tile, %base[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+    ```
   }];
   let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
-    Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
+    Variadic<Index>:$indices, Optional<AnyVector>:$mask,
+    ArmSME_TileSliceLayoutAttr:$layout
   );
   let extraClassDeclaration = [{
     MemRefType getMemRefType() {
@@ -361,9 +383,16 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
+  let builders = [
+    OpBuilder<(ins "Value":$valueToStore, "Value":$base,
+                   "ValueRange":$indices), [{
+      build($_builder, $_state, valueToStore, base, indices, {});
+    }]>,
+  ];
+
   let assemblyFormat =
-    "$valueToStore `,` $base `[` $indices `]` (`layout` `` $layout^)? attr-dict "
-      "`:` type($base) `,` type($valueToStore)";
+    "$valueToStore `,` $base `[` $indices `]` (`,` $mask^)? (`layout` `` $layout^)?"
+      "attr-dict `:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 02a5bc64fa52c04..0cc5732c9212d18 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -157,8 +157,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
+        writeOp.getMask());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 60350a888c88441..7a2550b8576d744 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -142,6 +142,20 @@ func.func @arm_sme_tile_load__bad_mask_type(%src : memref<?x?xf64>, %pad : f64,
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store__bad_mask_type(%tile : vector<[16]x[16]xi8>, %mask : vector<[1]x[1]xi1>, %dest : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  // expected-note at -2 {{prior use here}}
+  // expected-error at +1 {{use of value '%mask' expects different type than prior uses: 'vector<[16]x[16]xi1>' vs 'vector<[1]x[1]xi1>'}}
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // arm_sme.load_tile_slice
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index f0704a75ed2fc3a..c0c5c539f3f083d 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -624,6 +624,15 @@ func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memre
 
 // -----
 
+func.func @arm_sme_tile_store_with_mask_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask : vector<[4]x[4]xi1>) {
+  // CHECK: arm_sme.tile_store {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
 /// Layout is optional and horizontal is the default, verify it's still parsed.
 func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_store %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 80ca3d3b8281321..f9251edbe658b63 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -323,6 +323,20 @@ func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?
 
 // -----
 
+// CHECK-LABEL: func.func @transfer_write_2d_with_mask_f64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xf64>,
+// CHECK-SAME:                                             %[[MASK:.*]]: vector<[2]x[2]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>, %mask : vector<[2]x[2]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
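As a rough before/after sketch of what the mask forwarding in this patch enables (illustrative only, assuming the usual vector-to-ArmSME lowering pipeline; not verbatim compiler output):

```mlir
// Before: masked 2-D write at tile granularity.
vector.transfer_write %vector, %dest[%c0, %c0], %mask {in_bounds = [true, true]}
  : vector<[2]x[2]xf64>, memref<?x?xf64>

// After lowering (sketch): the mask is forwarded onto the tile store.
arm_sme.tile_store %vector, %dest[%c0, %c0], %mask
  : memref<?x?xf64>, vector<[2]x[2]xf64>
```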

>From e5e3e1405051a88a8490c5e13f8c29645d2bdf72 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:39:18 +0000
Subject: [PATCH 7/9] [mlir][ArmSME] Add mask operand to store_tile_slice

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       |  28 +++--
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |   9 +-
 .../Transforms/LegalizeForLLVMExport.cpp      |  28 ++---
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |   3 +-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir |  76 ++++++------
 mlir/test/Dialect/ArmSME/invalid.mlir         |  14 +++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 114 +++++++++---------
 7 files changed, 151 insertions(+), 121 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 85cbe22acad6fd3..36fc4b9a3972836 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -462,7 +462,16 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 }
 
-def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
+def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
+  TypesMatchWith<
+    "mask has i1 element type and same shape as tile slice",
+    "tile", "mask",
+    "VectorType("
+      "VectorType::Builder("
+        "::llvm::cast<mlir::VectorType>($_self)"
+      ").dropDim(0).setElementType(IntegerType::get($_self.getContext(), 1))"
+    ")">
+]> {
   let summary = "Tile slice store operation";
   let description = [{
     Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
@@ -477,22 +486,27 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the input tile.
 
+    An SSA value `mask` is used to mask out elements written to the MemRef.
+    The `mask` type is an `i1` vector with a shape that matches how elements
+    are written to the MemRef. Elements whose corresponding mask element is
+    `0` are not written.
+
     Example 1: Store vector<[16]xi8> horizontal (default) tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
     ```
 
     Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
     ```
 
     Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] layout<vertical> : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %base[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
     ```
   }];
-  let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+  let arguments = (ins
+    SMETile:$tile, Index:$tile_slice_index, AnyVector:$mask,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
     Variadic<Index>:$indices, ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -506,8 +520,8 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]` (`layout` `` $layout^)?
-      attr-dict `:` type($base) `,` type($tile)
+    $tile `,` $tile_slice_index `,` $mask `,` $base `[` $indices `]` (`layout` `` $layout^)?
+      attr-dict `:` type($base) `,` type($mask) `,` type($tile)
   }];
 }
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 75b7b8acdd190c6..e72064651c5cae3 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -437,6 +437,12 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
+    // Create an 'all true' predicate for the tile slice.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
     getMemrefIndices(tileStoreOp.getIndices(),
@@ -444,7 +450,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
+        allTruePredicate, tileStoreOp.getBase(), memrefIndices,
+        tileStoreOp.getLayout());
 
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 86f245d82b16c77..bbfe41a34e15057 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -278,13 +278,7 @@ struct StoreTileSliceToArmSMELowering
     auto tileSliceI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), tileSlice);
 
-    // Create all active predicate mask.
-    auto one = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI1Type(),
-        rewriter.getIntegerAttr(rewriter.getI1Type(), 1));
-    auto predTy = VectorType::get(tileType.getShape()[0], rewriter.getI1Type(),
-                                  /*scalableDims=*/{true});
-    auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
+    auto maskOp = storeTileSliceOp.getMask();
 
     Value tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
@@ -295,23 +289,23 @@ struct StoreTileSliceToArmSMELowering
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       }
     } else {
@@ -320,23 +314,23 @@ struct StoreTileSliceToArmSMELowering
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
         break;
       }
     }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 4906812032ae903..55ea56f42c96ed9 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -104,8 +104,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
 // CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK:           %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index 30ddb3c46860187..8fdcf69958244f3 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -209,8 +209,8 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // CHECK-LABEL:   func.func @arm_sme_store_tile_slice_hor_i8(
 // CHECK-SAME:                                               %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                               %[[TILE_SLICE_INDEX:.*]]: index,
+// CHECK-SAME:                                               %[[MASK:.*]]: vector<[16]xi1>,
 // CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>) {
-// CHECK:           %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
@@ -221,12 +221,12 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
 // CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -234,9 +234,9 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -244,9 +244,9 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -254,9 +254,9 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -264,9 +264,9 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
 // CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -274,9 +274,9 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -284,9 +284,9 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
 // CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -294,9 +294,9 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
 // CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -304,9 +304,9 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
 // CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
@@ -314,9 +314,9 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
 // CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
@@ -324,9 +324,9 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
@@ -334,9 +334,9 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
@@ -344,9 +344,9 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
@@ -354,9 +354,9 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
 // CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
@@ -364,9 +364,9 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
@@ -374,9 +374,9 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
 // CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
@@ -384,9 +384,9 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
 // CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
@@ -394,9 +394,9 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
 // CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index 7a2550b8576d744..c29ae0581d39203 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -168,3 +168,17 @@ func.func @arm_sme_load_tile_slice__bad_mask_type(%src : memref<?x?xi8>, %mask :
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[2]xi1>, vector<[16]x[16]xi8>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_store_tile_slice__bad_mask_type(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi8>) -> () {
+  %c0 = arith.constant 0 : index
+  // expected-error at +1 {{op failed to verify that mask has i1 element type and same shape as tile slice}}
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[8]xi1>, vector<[16]x[16]xi8>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index c0c5c539f3f083d..640ca3835e88a9a 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -823,173 +823,173 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
 /// Layout is optional and horizontal is the default, verify it's still parsed.
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
-  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, {{.*}}, %{{.*}}[{{.*}}] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] layout<horizontal> : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<horizontal> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
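
A short usage sketch of the updated form (`%tile`, `%tile_slice_index` and `%dest` assumed bound; the all-true predicate mirrors what the unmasked `tile_store` lowering now materialises):

```mlir
// Store one horizontal f32 tile slice, predicated on an all-true mask.
%c0 = arith.constant 0 : index
%ptrue = arith.constant dense<true> : vector<[4]xi1>
arm_sme.store_tile_slice %tile, %tile_slice_index, %ptrue, %dest[%c0]
  : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
```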
 

>From f5bbd5a8006545d8350b33a4d3e656c56192b927 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 15 Oct 2023 13:57:59 +0000
Subject: [PATCH 8/9] [mlir][ArmSME] Add support for lowering masked tile_store
 ops

This patch extends ArmSMEToSCF to support lowering of masked tile_store
ops. Only masks created by 'vector.create_mask' are currently supported.

Example:

  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>

Produces:

  %num_rows = arith.constant 3 : index
  %num_cols = vector.create_mask %c2 : vector<[4]xi1>
  scf.for %slice_idx = %c0 to %num_rows step %c1
    arm_sme.store_tile_slice %tile, %slice_idx, %num_cols, %dest[%slice_idx, %c0]
      : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
---
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    | 65 ++++++++++++-------
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           | 25 ++++++-
 2 files changed, 66 insertions(+), 24 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index e72064651c5cae3..86d1172ac4957b1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -420,38 +420,59 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     auto tileType = tileStoreOp.getVectorType();
     auto tileElementType = tileType.getElementType();
 
-    // Create a loop that stores each ZA tile slice from memory.
+    auto predicateType =
+        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
+
+    Value maskCols;
+    Value upperBound;
+    auto maskOp = tileStoreOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            tileStoreOp, "unsupported mask op, only 'vector.create_mask' is "
+                         "currently supported");
+
+      auto numRows = createMaskOp.getOperands()[0];
+      auto numCols = createMaskOp.getOperands()[1];
+
+      upperBound = numRows;
+      maskCols =
+          rewriter.create<vector::CreateMaskOp>(loc, predicateType, numCols);
+    } else {
+      // Store all tile slices if no mask.
+      auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
+          loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
+      auto vscale =
+          rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+      // This describes both the number of ZA tile slices and the number of
+      // elements in a vector of SVL bits for a given element type (SVL_B,
+      // SVL_H, ..., SVL_Q).
+      auto numTileSlices =
+          rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
+
+      upperBound = numTileSlices;
+      // Create an 'all true' predicate for the tile slice.
+      maskCols = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(predicateType, true));
+    }
+
+    // Create a loop that stores each active ZA tile slice to memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
-        loc, arm_sme::getSMETileSliceMinNumElts(tileElementType));
-    auto vscale =
-        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
     auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    // This describes both the number of ZA tile slices and the number of
-    // elements in a vector of SVL bits for a given element type (SVL_B, SVL_H,
-    // ..., SVL_Q).
-    auto numTileSlices =
-        rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
-    auto forOp =
-        rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
 
     rewriter.setInsertionPointToStart(forOp.getBody());
 
-    // Create an 'all true' predicate for the tile slice.
-    auto predicateType =
-        VectorType::get(tileType.getDimSize(1), rewriter.getI1Type(), true);
-    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
-        loc, DenseElementsAttr::get(predicateType, true));
-
     SmallVector<Value> memrefIndices;
     auto tileSliceIndex = forOp.getInductionVar();
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
-                     numTileSlices, memrefIndices, loc, rewriter);
+                     upperBound, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        allTruePredicate, tileStoreOp.getBase(), memrefIndices,
-        tileStoreOp.getLayout());
+        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+        tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
 
     return success();
   }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 55ea56f42c96ed9..58c6998870edd98 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -102,9 +102,9 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
 // CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
-// CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
+// CHECK-DAG:     %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK:           %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[PTRUE_S]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
@@ -123,6 +123,27 @@ func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x
   return
 }
 
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor_with_mask(
+// CHECK-SAME:                                             %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[NUM_ROWS:.*]] = arith.constant 3 : index
+// CHECK-DAG:     %[[NUM_COLS:.*]] = vector.create_mask %c2 : vector<[4]xi1>
+// CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_ROWS]] step %[[C1]] {
+// CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
+// CHECK-NEXT:      arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[NUM_COLS]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
+  arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.print
 //===----------------------------------------------------------------------===//

>From 14aac4339638dedc0ed18cc5ab35a346cda32e79 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Mon, 16 Oct 2023 11:43:32 +0000
Subject: [PATCH 9/9] [mlir][ArmSME] Lower transfer_write + transpose to
 vertical store

This patch extends the lowering of vector.transfer_write in
VectorToArmSME to support in-flight transpose via SME vertical store.
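
For example (mirroring the examples in the pattern's doc comment below),
a transfer_write whose permutation map swaps the two dimensions:

  vector.transfer_write %vector, %source[%c0, %c0]
    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
     in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>

is converted to a vertical tile store, so the transpose happens in-flight
as the tile is written out slice by slice:

  arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
    : memref<?x?xi8>, vector<[16]x[16]xi8>

Any other non-identity permutation map is rejected with a match failure.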
---
 .../VectorToArmSME/VectorToArmSME.cpp         |  47 ++++-
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |  42 +++++
 .../CPU/ArmSME/test-transfer-write-2d.mlir    | 174 ++++++++++++++++++
 3 files changed, 260 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir

diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0cc5732c9212d18..40e8378306bbf21 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -136,13 +136,31 @@ struct TransferReadToArmSMELowering
 
 /// Conversion pattern for vector.transfer_write.
 ///
-///   vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>,
-///                                                      memref<?x?xi8>
+/// ---
+///
+/// Example 1: op with identity permutation map to horizontal
+///            arm_sme.tile_store:
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
 ///
 /// is converted to:
 ///
 ///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
 ///                                                   vector<[16]x[16]xi8>
+/// ---
+///
+/// Example 2: op with transpose permutation map to vertical arm_sme.tile_store
+///            (in-flight transpose):
+///
+///   vector.transfer_write %vector, %source[%c0, %c0]
+///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+///      in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
+///
+/// is converted to:
+///
+///   arm_sme.tile_store %vector, %source[%c0, %c0] layout<vertical>
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -153,12 +171,35 @@ struct TransferWriteToArmSMELowering
     if (!arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
+    assert(writeOp.getTransferRank() == 2 &&
+           "expected a permutation_map with result dims of the same rank as "
+           "the vector type");
+
     if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
       return failure();
 
+    // Out-of-bounds dims are not supported.
+    if (writeOp.hasOutOfBoundsDim())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "not inbounds transfer write");
+
+    arm_sme::TileSliceLayout layout;
+
+    AffineExpr d0, d1;
+    bindDims(writeOp.getContext(), d0, d1);
+    AffineMap map = writeOp.getPermutationMap();
+    if (map.isIdentity())
+      layout = arm_sme::TileSliceLayout::Horizontal;
+    else if (map == AffineMap::get(map.getNumDims(), 0, {d1, d0},
+                                   writeOp.getContext()))
+      layout = arm_sme::TileSliceLayout::Vertical;
+    else
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "unsupported permutation map");
+
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
         writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
-        writeOp.getMask());
+        writeOp.getMask(), layout);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index f9251edbe658b63..e1a8a9ff9bf10a8 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -337,6 +337,37 @@ func.func @transfer_write_2d_with_mask_f64(%vector : vector<[2]x[2]xf64>, %dest
 
 // -----
 
+/// In-flight transpose via vertical store.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_i64(
+// CHECK-SAME:                                             %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
+// CHECK-SAME:                                             %[[DEST:.*]]: memref<?x?xi64>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] layout<vertical> : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @transfer_write_2d_transpose_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
+  return
+}
+
+// -----
+
+/// In-flight transpose via vertical store with mask.
+
+// CHECK-LABEL: func.func @transfer_write_2d_transpose_with_mask_bf16(
+// CHECK-SAME:                                                        %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
+// CHECK-SAME:                                                        %[[DEST:.*]]: memref<?x?xbf16>,
+// CHECK-SAME:                                                        %[[MASK:.*]]: vector<[8]x[8]xi1>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : index
+// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]], %[[MASK]] layout<vertical> : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>, %mask : vector<[8]x[8]xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element size
 // and number of scalable dims.
@@ -398,6 +429,17 @@ func.func @transfer_write_2d__fixed(%vector : vector<16x16xi8>, %dest : memref<?
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transfer_write_2d__out_of_bounds
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.tile_store
+func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vector, %dest[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // vector.broadcast
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
new file mode 100644
index 000000000000000..1cb685d7bc27cd6
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -0,0 +1,174 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:  -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:  -e %{entry_point} -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+// TODO: replace with vector.print <str> once #68695 lands.
+func.func @print_str(%str: !llvm.ptr<array<17 x i8>>) attributes { enable_arm_streaming_ignore } {
+  %c0 = llvm.mlir.constant(0 : index) : i64
+  %str_bytes = llvm.getelementptr %str[%c0, %c0]
+    : (!llvm.ptr<array<17 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%str_bytes) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+// Vector store.
+func.func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c0 = arith.constant 0.0 : f32
+  %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+  vector.transfer_write %zero, %A[%base1, %base2] {in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store.
+func.func @transfer_write_2d_mask(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c0 = arith.constant 0.0 : f32
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %mask = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1>
+  %zero = vector.splat %c0 : vector<[4]x[4]xf32>
+  vector.transfer_write %zero, %A[%base1, %base2], %mask {in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Vector store + transpose.
+func.func @transfer_write_2d_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2] {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Masked vector store + transpose.
+func.func @transfer_write_2d_mask_transposed(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %mask = vector.create_mask %c4, %c2 : vector<[4]x[4]xi1>
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  vector.transfer_write %0, %A[%base1, %base2], %mask {permutation_map = affine_map<(d0, d1) -> (d1, d0)>, in_bounds=[true, true]} :
+    vector<[4]x[4]xf32>, memref<?x?xf32>
+  return
+}
+
+// Vector load + print.
+func.func @load_and_print(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %tile_begin_str = llvm.mlir.addressof @tile_begin : !llvm.ptr<array<17 x i8>>
+
+  %0 = vector.load %A[%base1, %base2] : memref<?x?xf32>, vector<[4]x[4]xf32>
+
+  func.call @print_str(%tile_begin_str) : (!llvm.ptr<array<17 x i8>>) -> ()
+  vector.print %0: vector<[4]x[4]xf32>
+
+  return
+}
+
+// Allocate heap memory of size 'd0' x 'd1' and initialize.
+//
+// Example:
+//
+// initialize_memory(%c4, %c5)
+//
+//    0,  1,  2,  3,  4
+//   10, 11, 12, 13, 14
+//   20, 21, 22, 23, 24
+//   30, 31, 32, 33, 34
+//
+// Returns a dynamic memref. It is the caller's responsibility to free the
+// returned memref.
+func.func @initialize_memory(%d0 : index, %d1 : index) -> memref<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_f32 = arith.constant 1.0 : f32
+  %c10_f32 = arith.constant 10.0 : f32
+
+  %A = memref.alloc(%d0, %d1) : memref<?x?xf32>
+
+  %init = arith.constant 0.0 : f32
+  scf.for %i = %c0 to %d0 step %c1 iter_args(%val = %init) -> f32 {
+    scf.for %j = %c0 to %d1 step %c1 iter_args(%inner_val = %val) -> f32 {
+      memref.store %inner_val, %A[%i, %j] : memref<?x?xf32>
+      %inner_val_next = arith.addf %inner_val, %c1_f32 : f32
+      scf.yield %inner_val_next : f32
+    }
+    %val_next = arith.addf %val, %c10_f32 : f32
+    scf.yield %val_next : f32
+  }
+
+  return %A : memref<?x?xf32>
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // Allocate enough memory for a 32-bit element tile plus two extra rows and
+  // columns, to test non-zero offsets while remaining in bounds.
+  %vscale = vector.vscale
+  %svl_s = arith.muli %c4, %vscale : index
+  %svl_s_plus_two = arith.addi %svl_s, %c2 : index
+
+  // 1. Initialize memory
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 22, 23
+  // CHECK-NEXT: ( 30, 31, 32, 33
+  %A = call @initialize_memory(%svl_s_plus_two, %svl_s_plus_two) : (index, index) -> memref<?x?xf32>
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 2. Write 2-D vector of zeroes to 1. at offset [2, 2].
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 10, 11, 12, 13
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d(%A, %c2, %c2) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 3. Write 2-D vector of zeroes to 2. but with mask (nrows=2, ncols=3).
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 0, 3
+  // CHECK-NEXT: ( 0, 0, 0, 13
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 4. Reload 3. + store + transpose.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 3, 13, 0, 0
+  call @transfer_write_2d_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  // 5. Reload 4. + store + transpose but with mask (nrows=4, ncols=2).
+  // The mask is applied after the permutation.
+  // CHECK-LABEL: TILE BEGIN:
+  // CHECK-NEXT: ( 0, 0, 20, 30
+  // CHECK-NEXT: ( 0, 0, 21, 31
+  // CHECK-NEXT: ( 20, 21, 0, 0
+  // CHECK-NEXT: ( 30, 31, 0, 0
+  call @transfer_write_2d_mask_transposed(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  call @load_and_print(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+
+  memref.dealloc %A : memref<?x?xf32>
+
+  return
+}
+
+llvm.mlir.global internal constant @tile_begin("TILE BEGIN:    \0A\00")



More information about the Mlir-commits mailing list