[Mlir-commits] [mlir] Arm sme vector transpose (PR #66760)

Cullen Rhodes llvmlistbot at llvm.org
Tue Sep 19 03:39:28 PDT 2023


https://github.com/c-rhodes created https://github.com/llvm/llvm-project/pull/66760

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the transpose to memory and
reloading vertically, building on top of the tile slice layout support.

Transposing via memory is obviously expensive; the current intention is
to avoid the transpose if possible. This is therefore intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on https://github.com/llvm/llvm-project/pull/66758.

>From 463c7c24d4c7ccb5237829b492d1f1a0ff3ea6e2 Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Sun, 17 Sep 2023 11:50:14 +0000
Subject: [PATCH 1/2] [mlir][ArmSME] Support vertical layout in load and store
 ops

In SME a ZA tile slice is a one-dimensional set of horizontally or
vertically contiguous elements within a ZA tile. Currently the load and
store ops only support horizontal tile slices. This patch adds a tile
slice layout attribute to the load and store ops to support both
horizontal and vertical tile slices.

When lowering from the Vector dialect, horizontal layout is the default.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |   5 +
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td | 116 ++--
 .../mlir/Dialect/ArmSME/IR/CMakeLists.txt     |   7 +
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  18 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |  14 +-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |  12 +
 mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt     |   1 +
 .../Transforms/LegalizeForLLVMExport.cpp      |  93 ++-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  42 +-
 .../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir |   6 +-
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 401 +++++++++++++
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 540 +++++++++++++++---
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |  16 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir | 110 ++++
 14 files changed, 1193 insertions(+), 188 deletions(-)
 create mode 100644 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index d1ed02abfd5c552..f947fc8fe1631b8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -21,6 +21,11 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
+
 #include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..1a4984f3bd6ba27 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -14,6 +14,7 @@
 #ifndef ARMSME_OPS
 #define ARMSME_OPS
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
@@ -36,6 +37,7 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
   let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let useDefaultAttributePrinterParser = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -83,6 +85,24 @@ def TileElementWidthMatchesTileID : TypesMatchWith<
                   "::llvm::cast<VectorType>($_self).getElementType())"
                   ".getWidth())">;
 
+//===----------------------------------------------------------------------===//
+// ArmSME attr definitions
+//===----------------------------------------------------------------------===//
+
+def TileSliceLayout : I32EnumAttr<"TileSliceLayout", "Layout of a tile slice", [
+  I32EnumAttrCase<"Horizontal", 0, "hor">,
+  I32EnumAttrCase<"Vertical", 1, "ver">,
+]> {
+  let cppNamespace = "::mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+/// An attribute that specifies the layout of a tile slice in a tile.
+def ArmSME_TileSliceLayoutAttr : EnumAttr<ArmSME_Dialect, TileSliceLayout,
+                                          "layout"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME op definitions
 //===----------------------------------------------------------------------===//
@@ -239,27 +259,33 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   let summary = "Tile load operation";
   let description = [{
     Loads a 2D SME "virtual tile" from memory defined by a base and indices,
-    with the shape defined by the 2D scalable vector type of the result tile.
-    The slice of memory must be contiguous. The memref must be either rank 1 or
-    rank 2 with dynamic dimensions, since the operation is scalable, and the
-    element type must be a scalar that matches the element type of the result.
+    with the shape defined by the 2D scalable vector type of the result tile. A
+    tile slice layout attribute specifies whether the slices of the tile being
+    loaded are horizontal or vertical. The slice of memory must be contiguous.
+    The memref must be either rank 1 or rank 2 with dynamic dimensions, since
+    the operation is scalable, and the element type must be a scalar that
+    matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Load an 8-bit element ZA tile from memory (ZA0.B).
+    Example 1: Load an 8-bit element ZA tile with horizontal layout from memory (ZA0.B).
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a FP 32-bit element ZA tile from memory.
+    Example 2: Load a FP 32-bit element ZA tile with vertical layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile = arm_sme.tile_load <ver>, %base[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a 128-bit element ZA tile from memory.
+    Example 3: Load a 128-bit element ZA tile with horizontal layout from memory.
     ```mlir
-    %tile = arm_sme.tile_load %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile = arm_sme.tile_load <hor>, %base[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base,
       Variadic<Index>:$indices);
   let results = (outs SMETile:$result);
@@ -274,7 +300,8 @@ def TileLoadOp : ArmSME_Op<"tile_load"> {
   }];
 
   let assemblyFormat =
-      "$base `[` $indices `]` attr-dict `:` type($base) `,` type($result)";
+      "$layout `,` $base `[` $indices `]` attr-dict "
+        "`:` type($base) `,` type($result)";
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store"> {
@@ -282,27 +309,32 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
   let description = [{
     Stores a 2D SME "virtual tile" to memory defined by a base and indices,
     with the shape defined by the 2D scalable vector type of the tile being
-    stored. The slice of memory must be contiguous. The memref must be either
-    rank 1 or rank 2 with dynamic dimensions, since the operation is scalable,
-    and the element type must be a scalar that matches the element type of the
-    result.
+    stored. A tile slice layout attribute specifies whether the slices of the
+    tile being stored are horizontal or vertical. The slice of memory must be
+    contiguous.  The memref must be either rank 1 or rank 2 with dynamic
+    dimensions, since the operation is scalable, and the element type must be a
+    scalar that matches the element type of the result.
+
+    The default tile slice layout when lowering from higher-level dialects is
+    horizontal.
 
-    Example 1: Store an 8-bit element ZA tile to memory (ZA0.B).
+    Example 1: Store an 8-bit element ZA tile with horizontal layout to memory (ZA0.B).
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
     ```
 
-    Example 2: Store a FP 32-bit element ZA tile to memory.
+    Example 2: Store a FP 32-bit element ZA tile with vertical layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.tile_store %tile, <ver>, %base[%c0, %c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
-    Example 3: Store a 128-bit element ZA tile to memory.
+    Example 3: Store a 128-bit element ZA tile with horizontal layout to memory.
     ```mlir
-    arm_sme.tile_store %tile, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.tile_store %tile, <hor>, %base[%c0, %c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$valueToStore,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -314,8 +346,9 @@ def TileStoreOp : ArmSME_Op<"tile_store"> {
     }
   }];
 
-  let assemblyFormat = "$valueToStore `,` $base `[` $indices `]` attr-dict "
-                       "`:` type($base) `,` type($valueToStore)";
+  let assemblyFormat =
+    "$valueToStore `,` $layout `,` $base `[` $indices `]` attr-dict "
+      "`:` type($base) `,` type($valueToStore)";
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
@@ -326,29 +359,32 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     Loads a 1D tile slice from memory into a 2D SME "virtual tile". The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is loaded to. The updated tile is returned as the result.
+    slice is loaded to. A tile slice layout attribute specifies whether the
+    tile slice being loaded at the given index is horizontal or vertical. The
+    updated tile is returned as the result.
 
     The slice of memory read is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the result.
 
-    Example 1: Load a vector<[16]xi8> tile slice from memory into tile at given index.
+    Example 1: Load a vector<[16]xi8> tile slice from memory into tile horizontally at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+    %tile_update = arm_sme.load_tile_slice <hor>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
 
-    Example 2: Load a vector<[4]xf32> tile slice from memory into tile at given index.
+    Example 2: Load a vector<[4]xf32> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
     ```
 
-    Example 3: Load a vector<[1]xi128> tile slice from memory into tile at given index.
+    Example 3: Load a vector<[1]xi128> tile slice from memory into tile vertically at given index.
     ```mlir
-    %tile_update = arm_sme.load_tile_slice %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+    %tile_update = arm_sme.load_tile_slice <ver>, %base[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
     ```
   }];
   let arguments = (ins
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to load from">:$base,
       SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index);
   let results = (outs SMETile:$result);
@@ -363,7 +399,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
   }];
 
   let assemblyFormat = [{
-    $base `[` $indices `]` `,` $tile `,` $tile_slice_index
+    $layout `,` $base `[` $indices `]` `,` $tile `,` $tile_slice_index
       attr-dict `:` type($base) `,` type($result)
   }];
 }
@@ -374,29 +410,31 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
     Stores a 1D tile slice from a 2D SME "virtual tile" into memory. The tile
     slice is defined by the dimension of the 2D scalable vector type pointed by
     the index. A tile slice index describes where in the input tile the tile
-    slice is stored from.
+    slice is stored from. A tile slice layout attribute specifies whether the
+    tile slice being stored from the given index is horizontal or vertical.
 
     The slice of memory written is defined by a base and indices and must be
     contiguous. The memref must be either rank 1 or rank 2, have dynamic
     dimensions since the operation is scalable, and the element type must be a
     scalar that matches the element type of the input tile.
 
-    Example 1: Store vector<[16]xi8> tile slice from tile at given index to memory.
+    Example 1: Store vector<[16]xi8> horizontal tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %base[%c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
     ```
 
-    Example 2: Store vector<[4]xf32> tile slice from tile at given index to memory.
+    Example 2: Store vector<[4]xf32> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[4]x[4]xf32>, memref<?x?xf32>
     ```
 
-    Example 3: Store a vector<[1]xi128> tile slice from tile at given index to memory.
+    Example 3: Store a vector<[1]xi128> vertical tile slice from tile at given index to memory.
     ```mlir
-    arm_sme.store_tile_slice %tile, %tile_slice_index, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
+    arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %base[%c0] : vector<[1]x[1]xi128>, memref<?x?xi128>
     ```
   }];
   let arguments = (ins SMETile:$tile, Index:$tile_slice_index,
+      ArmSME_TileSliceLayoutAttr:$layout,
       Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
       Variadic<Index>:$indices);
   let extraClassDeclaration = [{
@@ -409,7 +447,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice"> {
   }];
 
   let assemblyFormat = [{
-    $tile `,` $tile_slice_index `,` $base `[` $indices `]`
+    $tile `,` $tile_slice_index `,` $layout `,` $base `[` $indices `]`
       attr-dict `:` type($base) `,` type($tile)
   }];
 }
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index d20ee65e62e7dc0..7afd0d014541687 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -4,3 +4,10 @@ add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme)
 set(LLVM_TARGET_DEFINITIONS ArmSME.td)
 mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
+
+mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls)
+mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArmSMEAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=arm_sme)
+mlir_tablegen(ArmSMEAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=arm_sme)
+add_public_tablegen_target(MLIRArmSMEEnumsIncGen)
+add_dependencies(mlir-headers MLIRArmSMEEnumsIncGen)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..86cabe67f2695f1 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -54,7 +54,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///
 ///  BEFORE:
 ///  ```mlir
-///  %tile = arm_sme.tile_load %src[%c0, %c0] :
+///  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] :
 ///    memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  ```
 ///
@@ -68,7 +68,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    %tile_update = arm_sme.load_tile_slice %src[%tile_slice_idx],
+///    %tile_update = arm_sme.load_tile_slice <hor>, %src[%tile_slice_idx],
 ///      %tile, %tile_slice_idx : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
@@ -116,9 +116,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(loc, tileType,
-                                              tileLoadOp.getBase(), tile,
-                                              memrefIndices, tileSliceIndex);
+    rewriter.create<arm_sme::LoadTileSliceOp>(
+        loc, tileType, tileLoadOp.getLayout(), tileLoadOp.getBase(), tile,
+        memrefIndices, tileSliceIndex);
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -134,7 +134,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///  arm_sme.tile_store %tile, %dest[%c0, %c0]
+///  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0]
 ///    : memref<?x?xi32>, vector<[4]x[4]xi32
 ///  ```
 ///
@@ -146,8 +146,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 ///  %min_svl_s = arith.constant 4 : index
 ///  %svl_s = arith.muli %min_svl_s, %vscale : index
 ///  scf.for %tile_slice_idx = %c0 to %svl_s step %c1 {
-///    arm_sme.store_tile_slice %tile, %tile_slice_idx, %dest[%tile_slice_idx]
-///      : memref<?x?xi32>, vector<[4]x[4]xi32>
+///    arm_sme.store_tile_slice %tile, %tile_slice_idx, <ver>,
+///      %dest[%tile_slice_idx] : memref<?x?xi32>, vector<[4]x[4]xi32>
 ///  }
 ///  ```
 struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
@@ -184,7 +184,7 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
                      numTileSlices, memrefIndices, loc, rewriter);
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
-        tileStoreOp.getBase(), memrefIndices);
+        tileStoreOp.getLayout(), tileStoreOp.getBase(), memrefIndices);
 
     return success();
   }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 0a1a087d9c8d6c7..feaec0e035ed9fd 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -65,8 +65,8 @@ namespace {
 ///
 /// is converted to:
 ///
-///   arm_sme.tile_store %vector, %source[%c0, %c0] : memref<?x?xi8>,
-///                                                   vector<[16]x[16]xi8>
+///   arm_sme.tile_store %vector, <hor>, %source[%c0, %c0]
+///     : memref<?x?xi8>, vector<[16]x[16]xi8>
 struct TransferWriteToArmSMELowering
     : public OpRewritePattern<vector::TransferWriteOp> {
   using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
@@ -81,8 +81,8 @@ struct TransferWriteToArmSMELowering
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(),
-        writeOp.getIndices());
+        writeOp, writeOp.getVector(), arm_sme::TileSliceLayout::Horizontal,
+        writeOp.getSource(), writeOp.getIndices());
     return success();
   }
 };
@@ -97,7 +97,8 @@ struct VectorLoadToArmSMELowering : public OpRewritePattern<vector::LoadOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
-        load, load.getVectorType(), load.getBase(), load.getIndices());
+        load, load.getVectorType(), arm_sme::TileSliceLayout::Horizontal,
+        load.getBase(), load.getIndices());
 
     return success();
   }
@@ -113,7 +114,8 @@ struct VectorStoreToArmSMELowering : public OpRewritePattern<vector::StoreOp> {
       return failure();
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        store, store.getValueToStore(), store.getBase(), store.getIndices());
+        store, store.getValueToStore(), arm_sme::TileSliceLayout::Horizontal,
+        store.getBase(), store.getIndices());
 
     return success();
   }
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 750627421215dfb..92fb146691a0beb 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,8 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
@@ -22,13 +24,23 @@ using namespace mlir::arm_sme;
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+
 void ArmSMEDialect::initialize() {
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 9b6332a478ade33..79f6a46c7c5889e 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
 
   DEPENDS
   MLIRArmSMEIncGen
+  MLIRArmSMEEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRIR
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6c8843fbb4546e6..eeb822aae09180b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -204,29 +204,51 @@ struct LoadTileSliceToArmSMELowering
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-    // Create 'arm_sme.intr.ld1*.horiz' intrinsic to load ZA tile slice.
+    arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
+
+    // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
     switch (tileElementWidth) {
     default:
       llvm_unreachable("unexpected element type!");
     case 8:
-      rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, allActiveMask, ptr,
-                                                       tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(
+            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, allActiveMask, ptr,
+                                                        tileI32, tileSliceI32);
       break;
     case 16:
-      rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, allActiveMask, ptr,
-                                                       tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(
+            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, allActiveMask, ptr,
+                                                        tileI32, tileSliceI32);
       break;
     case 32:
-      rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, allActiveMask, ptr,
-                                                       tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(
+            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, allActiveMask, ptr,
+                                                        tileI32, tileSliceI32);
       break;
     case 64:
-      rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, allActiveMask, ptr,
-                                                       tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(
+            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, allActiveMask, ptr,
+                                                        tileI32, tileSliceI32);
       break;
     case 128:
-      rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, allActiveMask, ptr,
-                                                       tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(
+            loc, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, allActiveMask, ptr,
+                                                        tileI32, tileSliceI32);
       break;
     }
 
@@ -280,28 +302,50 @@ struct StoreTileSliceToArmSMELowering
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
     Value tileI32 = castTileIDToI32(tile, loc, rewriter);
+    arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
+
     switch (tileElementWidth) {
     default:
       llvm_unreachable("unexpected element type!");
     case 8:
-      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 16:
-      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 32:
-      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 64:
-      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     case 128:
-      rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-          storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      if (layout == arm_sme::TileSliceLayout::Horizontal)
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
+      else
+        rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
+            storeTileSliceOp, allActiveMask, ptr, tileI32, tileSliceI32);
       break;
     }
 
@@ -479,7 +523,12 @@ void mlir::configureArmSMELegalizeForExportTarget(
       arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
       arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
       arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
-      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_write_horiz,
       arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_za_enable,
       arm_sme::aarch64_sme_za_disable>();
   target.addLegalOp<GetTileID>();
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 9ab1d79794d7659..13c2e0479ab3957 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -convert-arm-sme-to-scf -cse -split-input-file | FileCheck %s
 
-// CHECK-LABEL: func.func @arm_sme_tile_load(
-// CHECK-SAME:                               %[[SRC:.*]]: memref<?x?xi32>) {
+// CHECK-LABEL: func.func @arm_sme_tile_load_hor(
+// CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
 // CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
@@ -11,18 +11,28 @@
 // CHECK-NEXT:    %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
+// CHECK-NEXT:      arm_sme.load_tile_slice <hor>, %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-// CHECK-LABEL: func.func @arm_sme_tile_store(
-// CHECK-SAME:                                %[[TILE:.*]]: vector<[4]x[4]xi32>,
-// CHECK-SAME:                                %[[DEST:.*]]: memref<?x?xi32>) {
+// CHECK-LABEL: @arm_sme_tile_load_ver
+// CHECK: arm_sme.load_tile_slice <ver>
+func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @arm_sme_tile_store_hor(
+// CHECK-SAME:                                    %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                                    %[[DEST:.*]]: memref<?x?xi32>) {
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -30,9 +40,19 @@ func.func @arm_sme_tile_load(%src : memref<?x?xi32>) {
 // CHECK:         %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
 // CHECK:         scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
 // CHECK:           %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
-func.func @arm_sme_tile_store(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+// CHECK:           arm_sme.store_tile_slice %[[TILE]], %[[TILE_SLICE_INDEX]], <hor>, %[[DEST]]{{\[}}%[[OFFSET]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_hor(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_tile_store_ver
+// CHECK: arm_sme.store_tile_slice {{.*}} <ver>
+func.func @arm_sme_tile_store_ver(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
index 2c26c62ad42481e..49bf188f913bc8c 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
@@ -15,7 +15,7 @@
 func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.zero : vector<[16]x[16]xi8>
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
@@ -32,7 +32,7 @@ func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
 // CHECK: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile = arm_sme.tile_load <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return %tile : vector<[16]x[16]xi8>
 }
 
@@ -46,6 +46,6 @@ func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
 // CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
 func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
new file mode 100644
index 000000000000000..c249372a013308d
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -0,0 +1,401 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+
+// Test conversion of higher-level ArmSME ops to LLVM intrinsics.
+
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL:   func.func @arm_sme_load_tile_slice_hor_i8(
+// CHECK-SAME:                                              %[[SRC:.*]]: memref<?x?xi8>,
+// CHECK-SAME:                                              %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK:           %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
+// CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
+// CHECK: arm_sme.intr.ld1h.horiz
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
+// CHECK: arm_sme.intr.ld1w.horiz
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
+// CHECK: arm_sme.intr.ld1d.horiz
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
+// CHECK: arm_sme.intr.ld1q.horiz
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
+// CHECK: arm_sme.intr.ld1h.horiz
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
+// CHECK: arm_sme.intr.ld1h.horiz
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
+// CHECK: arm_sme.intr.ld1w.horiz
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
+// CHECK: arm_sme.intr.ld1d.horiz
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
+// CHECK: arm_sme.intr.ld1b.vert
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
+// CHECK: arm_sme.intr.ld1h.vert
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
+// CHECK: arm_sme.intr.ld1w.vert
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
+// CHECK: arm_sme.intr.ld1d.vert
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
+// CHECK: arm_sme.intr.ld1q.vert
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
+// CHECK: arm_sme.intr.ld1h.vert
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
+// CHECK: arm_sme.intr.ld1h.vert
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
+// CHECK: arm_sme.intr.ld1w.vert
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
+// CHECK: arm_sme.intr.ld1d.vert
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL:   func.func @arm_sme_store_tile_slice_hor_i8(
+// CHECK-SAME:                                               %[[TILE:.*]]: vector<[16]x[16]xi8>,
+// CHECK-SAME:                                               %[[TILE_SLICE_INDEX:.*]]: index,
+// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK:           %[[PTRUE_B:.*]] = arith.constant dense<true> : vector<[16]xi1>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
+// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
+// CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
+// CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
+// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
+// CHECK:           "arm_sme.intr.st1b.horiz"(%[[PTRUE_B]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
+// CHECK: arm_sme.intr.st1w.horiz
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
+// CHECK: arm_sme.intr.st1d.horiz
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
+// CHECK: arm_sme.intr.st1q.horiz
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
+// CHECK: arm_sme.intr.st1h.horiz
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
+// CHECK: arm_sme.intr.st1w.horiz
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
+// CHECK: arm_sme.intr.st1d.horiz
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
+// CHECK: arm_sme.intr.st1b.vert
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
+// CHECK: arm_sme.intr.st1h.vert
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
+// CHECK: arm_sme.intr.st1w.vert
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
+// CHECK: arm_sme.intr.st1d.vert
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
+// CHECK: arm_sme.intr.st1q.vert
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
+// CHECK: arm_sme.intr.st1h.vert
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
+// CHECK: arm_sme.intr.st1h.vert
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
+// CHECK: arm_sme.intr.st1w.vert
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
+// CHECK: arm_sme.intr.st1d.vert
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index bae48be87b2dcdc..66518107c17bc4f 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_tile_to_vector
+//===----------------------------------------------------------------------===//
+
 func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
   // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
   %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
@@ -70,6 +74,10 @@ func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64
   return %0 : vector<[2]x[2]xf64>
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.cast_vector_to_tile
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
@@ -142,6 +150,10 @@ func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64
   return %0 : i64
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.get_tile_id
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_get_tile_id_i8() -> i8 {
@@ -182,6 +194,10 @@ func.func @arm_sme_get_tile_id_i128() -> i128 {
   return %0 : i128
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.zero
+//===----------------------------------------------------------------------===//
+
 // -----
 
 func.func @arm_sme_zero_i8() {
@@ -254,332 +270,676 @@ func.func @arm_sme_zero_f64() {
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_load
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_load_hor_i8(%src : memref<?x?xi8>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
 // -----
 
-func.func @arm_sme_tile_load_i8(%src : memref<?x?xi8>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_load_hor_i16(%src : memref<?x?xi16>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_i16(%src : memref<?x?xi16>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_tile_load_hor_i32(%src : memref<?x?xi32>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_i32(%src : memref<?x?xi32>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_load_hor_i64(%src : memref<?x?xi64>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_i64(%src : memref<?x?xi64>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_tile_load_hor_i128(%src : memref<?x?xi128>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_i128(%src : memref<?x?xi128>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_tile_load_hor_f16(%src : memref<?x?xf16>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_f16(%src : memref<?x?xf16>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_tile_load_hor_bf16(%src : memref<?x?xbf16>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_bf16(%src : memref<?x?xbf16>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_tile_load_hor_f32(%src : memref<?x?xf32>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_f32(%src : memref<?x?xf32>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_tile_load_hor_f64(%src : memref<?x?xf64>) {
+  // CHECK: arm_sme.tile_load <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %tile = arm_sme.tile_load <hor>, %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_load_f64(%src : memref<?x?xf64>) {
-  // CHECK: arm_sme.tile_load {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_tile_load_ver_i8(%src : memref<?x?xi8>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+func.func @arm_sme_tile_load_ver_i16(%src : memref<?x?xi16>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i32(%src : memref<?x?xi32>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i64(%src : memref<?x?xi64>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_i128(%src : memref<?x?xi128>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f16(%src : memref<?x?xf16>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_bf16(%src : memref<?x?xbf16>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f32(%src : memref<?x?xf32>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_load_ver_f64(%src : memref<?x?xf64>) {
+  // CHECK: arm_sme.tile_load <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile = arm_sme.tile_load <ver>, %src[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.tile_store
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_tile_store_hor_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+func.func @arm_sme_tile_store_hor_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+func.func @arm_sme_tile_store_hor_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+func.func @arm_sme_tile_store_hor_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
+func.func @arm_sme_tile_store_hor_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+func.func @arm_sme_tile_store_hor_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+func.func @arm_sme_tile_store_hor_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+func.func @arm_sme_tile_store_hor_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+func.func @arm_sme_tile_store_hor_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.tile_store %tile, <hor>, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_ver_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @arm_sme_tile_store_ver_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+func.func @arm_sme_tile_store_ver_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+func.func @arm_sme_tile_store_ver_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+func.func @arm_sme_tile_store_ver_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+func.func @arm_sme_tile_store_ver_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @arm_sme_tile_store_ver_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+func.func @arm_sme_tile_store_ver_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_load_tile_slice_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @arm_sme_tile_store_ver_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
+  // CHECK: arm_sme.tile_store {{.*}}, <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.tile_store %tile, <ver>, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
+//===----------------------------------------------------------------------===//
+// arm_sme.load_tile_slice
+//===----------------------------------------------------------------------===//
+
 // -----
 
-func.func @arm_sme_load_tile_slice_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
-  // CHECK: arm_sme.load_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  %tile_update = arm_sme.load_tile_slice %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <hor>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <hor>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+  // CHECK: arm_sme.load_tile_slice <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  %tile_update = arm_sme.load_tile_slice <ver>, %src[%c0], %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.store_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
 
 // -----
 
-func.func @arm_sme_store_tile_slice_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
   // CHECK: arm_sme.store_tile_slice {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
   %c0 = arith.constant 0 : index
-  arm_sme.store_tile_slice %tile, %tile_slice_index, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <hor>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
 
 // -----
 
+func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %dest : memref<?x?xi8>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %dest : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %dest : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %dest : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %dest : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %dest : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %dest : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %dest : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %dest : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.store_tile_slice {{.*}}, <ver>, {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  arm_sme.store_tile_slice %tile, %tile_slice_index, <ver>, %dest[%c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.move_vector_to_tile_slice
+//===----------------------------------------------------------------------===//
+
+// -----
+
 func.func @arm_sme_move_vector_to_tile_slice_i8(%vector : vector<[16]xi8>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> () {
   // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} : vector<[16]xi8> into vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index cb35de11ab5b3ed..3012ed156578059 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -4,7 +4,7 @@
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi8>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
 func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref<?x?xi8>
@@ -17,7 +17,7 @@ func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref<?
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xi16>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi16>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi16>, vector<[8]x[8]xi16>
 func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi16>, memref<?x?xi16>
@@ -30,7 +30,7 @@ func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref<?
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi32>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi32>, vector<[4]x[4]xi32>
 func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xi32>, memref<?x?xi32>
@@ -43,7 +43,7 @@ func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref<?
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[2]x[2]xi64>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi64>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi64>, vector<[2]x[2]xi64>
 func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref<?x?xi64>
@@ -56,7 +56,7 @@ func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref<?
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xf16>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf16>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf16>, vector<[8]x[8]xf16>
 func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
@@ -69,7 +69,7 @@ func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref<?
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[8]x[8]xbf16>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xbf16>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
 func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref<?x?xbf16>
@@ -82,7 +82,7 @@ func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf32>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<[4]x[4]xf32>
 func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
@@ -95,7 +95,7 @@ func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref<?
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[2]x[2]xf64>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xf64>) {
 // CHECK:         %[[C0:.*]] = arith.constant 0 : index
-// CHECK:         arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
+// CHECK:         arm_sme.tile_store %[[VECTOR]], <hor>, %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xf64>, vector<[2]x[2]xf64>
 func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
   vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref<?x?xf64>
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
new file mode 100644
index 000000000000000..ea9cc61d931266e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -0,0 +1,110 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_i32 = arith.constant 1 : i32
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
+
+  // Fill each "row" of "mem1" with row number.
+  //
+  // For example, assuming an SVL of 128-bits:
+  //
+  //   0, 0, 0, 0
+  //   1, 1, 1, 1
+  //   2, 2, 2, 2
+  //   3, 3, 3, 3
+  //
+  %init_0 = arith.constant 0 : i32
+  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    %val_next = arith.addi %val, %c1_i32 : i32
+    scf.yield %val_next : i32
+  }
+
+  // Load tile from "mem1" vertically.
+  %0 = arm_sme.tile_load <ver>, %mem1[%c0, %c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Store tile back to "mem2" to print.
+  // TODO: Support vector.print for 2-D scalable vectors so we don't have to
+  // spill to memory and reload to print.
+  vector.store %0, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 3, 3, 3, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")

>From 3ad71b9afff663101868c350f3397a36832bb47f Mon Sep 17 00:00:00 2001
From: Cullen Rhodes <cullen.rhodes at arm.com>
Date: Tue, 19 Sep 2023 10:03:47 +0000
Subject: [PATCH 2/2] [mlir][ArmSME] Add support for vector.transpose

This patch adds support for lowering vector.transpose to ArmSME. It's
implemented by storing the input tile of the transpose to memory and
reloading vertically, building on top of the tile slice layout support.

Transposing via memory is obviously expensive, the current intention is
to avoid the transpose if possible, this is therefore intended as a
fallback and to provide base support for Vector ops. If it turns out
transposes can't be avoided then this should be replaced with a more
optimal implementation, perhaps with tile <-> vector (MOVA) ops.

Depends on #66758.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td |   3 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |  76 ++++++++++-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   1 +
 mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt     |   1 +
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     | 122 ++++++++++++++++++
 .../Vector/CPU/ArmSME/test-transpose.mlir     | 113 ++++++++++++++++
 6 files changed, 314 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 1a4984f3bd6ba27..884773aa559bcf9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -36,7 +36,8 @@ def ArmSME_Dialect : Dialect {
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
-  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect"];
+  let dependentDialects = ["scf::SCFDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect"];
   let useDefaultAttributePrinterParser = 1;
 }
 
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index feaec0e035ed9fd..30c516ffbe1e900 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
 
@@ -241,11 +242,84 @@ struct BroadcastOpToArmSMELowering
   }
 };
 
+/// Conversion pattern for vector.transpose.
+///
+/// Stores the input tile to memory and reloads vertically.
+///
+/// Example:
+///
+///   %transposed_src = vector.transpose %src, [1, 0]
+///     : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+///
+/// is converted to:
+///
+///   %alloca = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
+///   arm_sme.tile_store %src, <hor>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///   %transposed_src = arm_sme.tile_load <ver>, %alloca[%c0, %c0]
+///     : memref<?x?xi32>, vector<[4]x[4]xi32>
+///
+/// NOTE: Transposing via memory is obviously expensive, the current intention
+/// is to avoid the transpose if possible, this is therefore intended as a
+/// fallback and to provide base support for Vector ops. If it turns out
+/// transposes can't be avoided then this should be replaced with a more optimal
+/// implementation, perhaps with tile <-> vector (MOVA) ops.
+struct TransposeOpToArmSMELowering
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const final {
+    auto tileType = transposeOp.getResultVectorType();
+    if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
+      return failure();
+
+    SmallVector<int64_t> transp;
+    for (auto attr : transposeOp.getTransp())
+      transp.push_back(cast<IntegerAttr>(attr).getInt());
+    // Only the permutation [1, 0] is handled; reject anything else.
+    if (transp[0] != 1 || transp[1] != 0)
+      return failure();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    auto loc = transposeOp.getLoc();
+
+    // Allocate buffer to store input tile to.
+    Value vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    Value minTileSlices = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIndexAttr(tileType.getDimSize(0)));
+    Value c0 =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    Value numTileSlices =
+        rewriter.create<arith::MulIOp>(loc, vscale, minTileSlices);
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamic, ShapedType::kDynamic},
+                        tileType.getElementType());
+    auto buffer = rewriter.create<memref::AllocaOp>(
+        loc, bufferType, ValueRange{numTileSlices, numTileSlices});
+
+    Value input = transposeOp.getVector();
+
+    // Store input tile.
+    auto tileStoreOp = rewriter.create<arm_sme::TileStoreOp>(
+        loc, input, arm_sme::TileSliceLayout::Horizontal, buffer,
+        ValueRange{c0, c0});
+
+    // Reload input tile vertically.
+    rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
+        transposeOp, tileType, arm_sme::TileSliceLayout::Vertical,
+        tileStoreOp.getBase(), tileStoreOp.getIndices());
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering, VectorLoadToArmSMELowering,
                VectorStoreToArmSMELowering, ConstantOpToArmSMELowering,
-               BroadcastOpToArmSMELowering>(&ctx);
+               BroadcastOpToArmSMELowering, TransposeOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 92fb146691a0beb..5f48604f56ce85b 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
 
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 79f6a46c7c5889e..8c1f1a508caad47 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRMemRefDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRVectorDialect
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 3012ed156578059..53f8188df69695f 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
+// =============================================================================
+// vector.transfer_write
+// =============================================================================
+
 // CHECK-LABEL: func.func @transfer_write_2d_i8(
 // CHECK-SAME:                                   %[[VECTOR:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                   %[[DEST:.*]]: memref<?x?xi8>) {
@@ -215,3 +219,121 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
   "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
   return
 }
+
+// =============================================================================
+// vector.transpose
+// =============================================================================
+
+// -----
+
+// CHECK-LABEL:   func.func @transpose_i8(
+// CHECK-SAME:                            %[[TILE:.*]]: vector<[16]x[16]xi8>)
+// CHECK:           %[[C16:.*]] = arith.constant 16 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[VSCALE:.*]] = vector.vscale
+// CHECK:           %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
+// CHECK:           %[[ALLOCA:.*]] = memref.alloca(%[[NUM_TILE_SLICES]], %[[NUM_TILE_SLICES]]) : memref<?x?xi8>
+// CHECK:           arm_sme.tile_store %[[TILE]], <hor>, %[[ALLOCA]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+// CHECK:           arm_sme.tile_load <ver>, %[[ALLOCA]]{{\[}}%[[C0]], %[[C0]]] : memref<?x?xi8>, vector<[16]x[16]xi8>
+func.func @transpose_i8(%arg0: vector<[16]x[16]xi8>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[16]x[16]xi8> to vector<[16]x[16]xi8>
+  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i16
+// CHECK: arith.constant 8
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i16(%arg0: vector<[8]x[8]xi16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xi16> to vector<[8]x[8]xi16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xi16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i32
+// CHECK: arith.constant 4
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i32(%arg0: vector<[4]x[4]xi32>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i64
+// CHECK: arith.constant 2
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i64(%arg0: vector<[2]x[2]xi64>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xi64> to vector<[2]x[2]xi64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xi64>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_i128
+// CHECK: %[[VSCALE:.*]] = vector.vscale
+// CHECK: %[[ALLOCA:.*]] = memref.alloca(%[[VSCALE]], %[[VSCALE]]) : memref<?x?xi128>
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_i128(%arg0: vector<[1]x[1]xi128>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[1]x[1]xi128> to vector<[1]x[1]xi128>
+  "prevent.dce"(%0) : (vector<[1]x[1]xi128>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f16
+// CHECK: arith.constant 8
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_f16(%arg0: vector<[8]x[8]xf16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xf16> to vector<[8]x[8]xf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_bf16
+// CHECK: arith.constant 8
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_bf16(%arg0: vector<[8]x[8]xbf16>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[8]x[8]xbf16> to vector<[8]x[8]xbf16>
+  "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f32
+// CHECK: arith.constant 4
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_f32(%arg0: vector<[4]x[4]xf32>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[4]x[4]xf32> to vector<[4]x[4]xf32>
+  "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_f64
+// CHECK: arith.constant 2
+// CHECK: arm_sme.tile_store
+// CHECK: arm_sme.tile_load
+func.func @transpose_f64(%arg0: vector<[2]x[2]xf64>) {
+  %0 = vector.transpose %arg0, [1, 0] : vector<[2]x[2]xf64> to vector<[2]x[2]xf64>
+  "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
new file mode 100644
index 000000000000000..4350abbd13eca75
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -0,0 +1,113 @@
+// DEFINE: %{entry_point} = entry
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+llvm.func @printCString(!llvm.ptr<i8>)
+
+func.func @printTileBegin() {
+  %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<11 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @printTileEnd() {
+  %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
+  %1 = llvm.mlir.constant(0 : index) : i64
+  %2 = llvm.getelementptr %0[%1, %1]
+    : (!llvm.ptr<array<9 x i8>>, i64, i64) -> !llvm.ptr<i8>
+  llvm.call @printCString(%2) : (!llvm.ptr<i8>) -> ()
+  return
+}
+
+func.func @entry() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c1_i32 = arith.constant 1 : i32
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem1 = memref.alloca(%za_s_size) : memref<?xi32>
+  %mem2 = memref.alloca(%za_s_size) : memref<?xi32>
+
+  // Fill each "row" of "mem1" with row number.
+  //
+  // For example, assuming an SVL of 128-bits:
+  //
+  //   0, 0, 0, 0
+  //   1, 1, 1, 1
+  //   2, 2, 2, 2
+  //   3, 3, 3, 3
+  //
+  %init_0 = arith.constant 0 : i32
+  scf.for %i = %c0 to %za_s_size step %svl_s iter_args(%val = %init_0) -> (i32) {
+    %splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
+    vector.store %splat_val, %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    %val_next = arith.addi %val, %c1_i32 : i32
+    scf.yield %val_next : i32
+  }
+
+  // Load tile from "mem1".
+  %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Transpose tile.
+  %transposed_tile = vector.transpose %tile, [1, 0] : vector<[4]x[4]xi32> to vector<[4]x[4]xi32>
+
+  // Store tile back to "mem2" to print.
+  // TODO: Replace this with vector.print when
+  // https://github.com/llvm/llvm-project/pull/66691 lands.
+  vector.store %transposed_tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+
+  // Dump "mem1". The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 0, 0, 0
+  // CHECK-NEXT: ( 1, 1, 1, 1
+  // CHECK-NEXT: ( 2, 2, 2, 2
+  // CHECK-NEXT: ( 3, 3, 3, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem1[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  // Dump "mem2". The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xi32.
+  //
+  // CHECK:      TILE BEGIN
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK-NEXT: ( 0, 1, 2, 3
+  // CHECK:      TILE END
+  func.call @printTileBegin() : () -> ()
+  scf.for %i = %c0 to %za_s_size step %svl_s {
+    %tileslice = vector.load %mem2[%i] : memref<?xi32>, vector<[4]xi32>
+    vector.print %tileslice : vector<[4]xi32>
+  }
+  func.call @printTileEnd() : () -> ()
+
+  return
+}
+
+llvm.mlir.global internal constant @str_tile_begin("TILE BEGIN\0A")
+llvm.mlir.global internal constant @str_tile_end("TILE END\0A")



More information about the Mlir-commits mailing list