[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Nov 23 08:43:09 PST 2023


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/73253

This reworks the ArmSME dialect to use attributes for tile allocation. This has a number of advantages and corrects some issues with the previous approach:

 * Tile allocation can now be done ASAP (i.e. immediately after `-convert-vector-to-arm-sme`)
 * SSA form for control flow is now supported (e.g. `scf.for` loops that yield tiles)
 * ArmSME ops can be converted to intrinsics very late (i.e. after lowering to control flow)
 * Tests are simplified by removing constants and casts
 * Avoids correctness issues with representing LLVM `immargs` as MLIR values
    - The tile ID on the SME intrinsics is an `immarg` (so it is required to be a compile-time constant); `immargs` should be mapped to MLIR attributes (this is already the case for intrinsics in the LLVM dialect)
    - Using MLIR values for `immargs` can lead to invalid LLVM IR being generated (and to passes such as `-cse` performing incorrect optimizations), as sketched below
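
For illustration, the shape of the change at the intrinsic level is roughly as follows (a hand-written sketch, not exact printer output; the attribute syntax may differ slightly):

```mlir
// Before: the tile ID is an SSA value, so nothing guarantees it remains a
// compile-time constant through later rewrites.
"arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()

// After: the tile ID is an attribute, so it is always an immediate.
"arm_sme.intr.ld1w.horiz"(%pg, %ptr, %vnum) {tile_id = 0 : i32} : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
```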

As part of this patch we bid farewell to the following operations:

```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```

These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```

The new tile allocation works by operations implementing the `ArmSMETileOpInterface`. This interface says that an operation needs to be assigned a tile ID, and may conditionally allocate a new SME tile.

Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning what type of tile the op allocates (ZAB, ZAH, etc).

Operations that don't allocate a tile return `std::nullopt` (which is the default behaviour).
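
For example, `arm_sme.zero` (which allocates the tile that holds its result) implements this as follows in this patch (shown here as a standalone snippet):

```c++
// From ZeroOp's extraClassDeclaration: report that a tile matching the
// result vector type must be allocated, e.g. vector<[4]x[4]xi32> -> ZAS.
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
  return arm_sme::getSMETileType(getVectorType());
}
```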

Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```

Allocating operations become the roots for the tile allocation pass, which currently just (naively) assigns all transitive uses of a root operation the same tile ID. However, this is enough to handle current use cases.
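
After tile allocation the assignments simply show up as `tile_id` integer attributes on the ops (a hand-written sketch; the actual IDs are chosen by the allocator):

```mlir
// Two allocating ops (roots) get distinct tile IDs; every op that
// (transitively) uses one of these tiles is assigned the same tile_id.
%tile_a = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
%tile_b = arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xi32>
```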

Once tile IDs have been allocated, subsequent rewrites can forward the tile IDs to any newly created operations.
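
The interface provides helpers for this. For example, the ArmSMEToSCF patterns in this patch create their replacement ops via `createOpAndForwardTileId`, so an already-assigned tile ID is carried over (adapted from the `TileLoadOpConversion` changes; the surrounding variables come from that lowering):

```c++
// New ArmSME ops are created through the interface helper so they inherit
// tileLoadOp's tile ID (if one has been allocated yet).
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
    rewriter, loc, tileType);
tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
    rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
    memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
```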

>From a9595e75f75a7ad132318c4edfcb31892bba6e29 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 23 Nov 2023 14:48:32 +0000
Subject: [PATCH] [mlir][ArmSME] Switch to an attribute-based tile allocation
 scheme

This reworks the ArmSME dialect to use attributes for tile allocation.
This has a number of advantages and corrects some issues with the
previous approach:

 * Tile allocation can now be done ASAP (i.e. immediately after
   `-convert-vector-to-arm-sme`)
 * SSA form for control flow is now supported (e.g. `scf.for` loops that
   yield tiles)
 * ArmSME ops can be converted to intrinsics very late (i.e. after
   lowering to control flow)
 * Tests are simplified by removing constants and casts
 * Avoids correctness issues with representing LLVM `immargs` as MLIR
   values
    - The tile ID on the SME intrinsics is an `immarg` (so it is required
      to be a compile-time constant); `immargs` should be mapped to
      MLIR attributes (this is already the case for intrinsics in the
      LLVM dialect)
    - Using MLIR values for `immargs` can lead to invalid LLVM IR being
      generated (and to passes such as `-cse` performing incorrect
      optimizations)

As part of this patch we bid farewell to the following operations:

```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```

These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```

The new tile allocation works by operations implementing the
`ArmSMETileOpInterface`. This interface says that an operation needs
to be assigned a tile ID, and may conditionally allocate a new SME tile.

Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning what type of tile the op allocates (ZAB, ZAH, etc).

Operations that don't allocate a tile return `std::nullopt` (which
is the default behaviour).

Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```

Allocating operations become the roots for the tile allocation pass,
which currently just (naively) assigns all transitive uses of a root
operation the same tile ID. However, this is enough to handle current
use cases.

Once tile IDs have been allocated, subsequent rewrites can forward the
tile IDs to any newly created operations.
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |   4 +-
 .../mlir/Dialect/ArmSME/IR/ArmSMEEnums.h      |  16 +
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   |  52 ++-
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 283 ++++++------
 .../mlir/Dialect/ArmSME/IR/CMakeLists.txt     |   6 +
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  16 +-
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 197 ++++-----
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  47 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |  23 +-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |  22 +-
 mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt     |   1 +
 .../ArmSME/Transforms/TileAllocation.cpp      | 105 +++--
 mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt  |   2 -
 mlir/lib/Dialect/ArmSME/Utils/Utils.cpp       |  44 +-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  10 +-
 .../test/Dialect/ArmSME/arith-ops-to-sme.mlir |   6 +-
 .../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir |  51 ---
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 252 ++++++-----
 mlir/test/Dialect/ArmSME/canonicalize.mlir    |  40 +-
 mlir/test/Dialect/ArmSME/cse.mlir             |  36 +-
 mlir/test/Dialect/ArmSME/invalid.mlir         |  62 +--
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 193 ++-------
 mlir/test/Dialect/ArmSME/tile-allocation.mlir | 376 ++++++++--------
 mlir/test/Dialect/ArmSME/tile-zero-masks.mlir | 102 +----
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    | 410 +++++++++---------
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |   6 +-
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |   4 +-
 .../Linalg/CPU/ArmSME/matmul-transpose-a.mlir |   4 +-
 .../Dialect/Linalg/CPU/ArmSME/matmul.mlir     |   4 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |   4 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |   4 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |   4 +-
 .../CPU/ArmSME/test-transfer-read-2d.mlir     |   4 +-
 .../CPU/ArmSME/test-transfer-write-2d.mlir    |   4 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |   4 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |   4 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |   4 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |   5 +-
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir  |  15 +-
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 331 +++++++-------
 mlir/tools/mlir-query/mlir-query.cpp          |   6 +-
 41 files changed, 1271 insertions(+), 1492 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
 delete mode 100644 mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..1da8e488a4c4647 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..430f3571001c8f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEEnums.h - Arm SME Dialect Enums --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..2a0167afa8bae9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
   }];
 }
 
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = [],
+                    list<int> overloadedOperands = [],
                     list<Trait> traits = [], int numResults = 0,
                     list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
@@ -64,16 +67,26 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
           /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
+          /*int numResults=*/numResults,
+          /*bit requiresAccessGroup=*/0,
+          /*bit requiresAliasAnalysis=*/0,
+          /*bit requiresFastmath=*/0,
+          /*list<int> immArgPositions=*/immArgPositions,
+          /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 // Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero",
+                            /*immArgPositions=*/[0],
+                            /*immArgAttrNames=*/["tile_mask"]>,
+                            Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
 
 // MOP's
 class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[0],
+      /*immArgAttrNames=*/["tile_id"],
+      /*overloadedOperands=*/[4]>,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
                  Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
                  Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +105,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[2],
+      /*immArgAttrNames=*/["tile_id"]>;
+
 // Loads
 class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Load address">:$load_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +131,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
 
 // Stores
 class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +156,28 @@ def LLVM_aarch64_sme_str
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+    : ArmSME_IntrOp<"write." # direction,
+                    /*immArgPositions=*/[0],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[3],
                     [AllShapesMatch<["predicate", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
 // Tile slice to vector
 class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+    : ArmSME_IntrOp<"read." # direction,
+                    /*immArgPositions=*/[2],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[],
                     [AllShapesMatch<["vector", "predicate", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
                     /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
-                     Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..abcc9b649c4a530 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,99 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+    [
+      I32EnumAttrCase<"ZAB", 0, "za.b">,
+      I32EnumAttrCase<"ZAH", 1, "za.h">,
+      I32EnumAttrCase<"ZAS", 2, "za.s">,
+      I32EnumAttrCase<"ZAD", 3, "za.d">,
+      I32EnumAttrCase<"ZAQ", 4, "za.q">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+  let description = [{
+    An interface for operations that use or allocate Arm SME tiles. These
+    operations need to be assigned a tile ID, an i32 attribute, which specifies
+    which virtual tile within the ZA storage to use. The number of tiles
+    available depends on the type of the tile. This is summarized below:
+
+    | Tile Vector Types                                                       | Possible Tile IDs   |
+    |-------------------------------------------------------------------------|---------------------|
+    | `vector<[16]x[16]xi8>`                                                  | 0                   |
+    | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1             |
+    | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
+    | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
+    | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
+
+    Operations that allocate a new tile (such as arm_sme.get_tile) are used as
+    the roots for tile allocation, with all operations that (transitively)
+    depend on a root being assigned the same tile ID.
+  }];
+  let methods = [
+    InterfaceMethod<
+      "Sets the tile ID for this operation.",
+      /*returnType=*/"void",
+      /*methodName=*/"setTileId",
+      /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        if (!tileId)
+          return;
+        ::mlir::Operation* op = this->getOperation();
+        op->setAttr("tile_id", tileId);
+      }]
+    >,
+    InterfaceMethod<
+      "Returns the (possibly null) tile ID assigned to this operation.",
+      /*returnType=*/"mlir::IntegerAttr",
+      /*methodName=*/"getTileId",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        ::mlir::Operation* op = this->getOperation();
+        return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+      }]
+    >,
+    InterfaceMethod<
+      "The type of tile this operation allocates (or none)",
+      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+      /*methodName=*/"getAllocatedTileType",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        // Do not allocate a new tile.
+        return std::nullopt;
+      }]
+    >
+  ];
+
+  let extraSharedClassDeclaration = [{
+    // A helper to create a new operation and propagate this operation's tile ID.
+    template<typename T, typename... Args>
+    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+        tileOp.setTileId($_op.getTileId());
+      return op;
+    }
+
+    // A helper to replace this operation and forward any tile ID.
+    template<typename T, typename... Args>
+    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+      rewriter.replaceOp($_op, newOp);
+      return newOp;
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -44,7 +137,8 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
-                        "a vector type that fits into a SME tile">
+                        "a vector type that fits into a SME tile",
+                        "VectorType">
 {
   let description = [{
     Possible vector types:
@@ -66,40 +160,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
   }];
 }
 
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
-                       "an identifier of a virtual tile (of a size) within the ZA storage">
-{
-  let description = [{
-    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
-    the integer indicates the tile to use, and the bit size indicates the size
-    of tile. The number of tiles available and the element types of those depend
-    on the size. This is summarised below:
-
-    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
-    |--------------|---------------------|-------------------------------------------------------------------------|
-    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
-    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
-    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
-    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
-    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
-  }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
-  "`tile_id` has the same number of bits as elements in `vector`",
-  "vector", "tile_id",
-  "IntegerType::get("
-      "$_self.getContext(),"
-      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
-          "? ::llvm::cast<IntegerType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth()"
-          ": ::llvm::cast<FloatType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth())">;
-
 class HasMatchingMaskTypeConstraint<string vector, string mask> :
   OptionalTypesMatchWith<
     mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +222,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+  let summary = "Returns a SME virtual tile";
   let description = [{
-    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
-    scalable vector type, which represents an SME "virtual tile". This would
-    normally be used when lowering operations that return "virtual tile" vector
-    types to model the output. This is required to preserve dataflow as SME
-    intrinsics have no return values.
+    Allocates a new SME "virtual tile" within a function. The contents of the
+    tile returned from this operation are undefined.
 
-    Example:
+    Example 1:
 
-    Input:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 8-bit element "virtual tile"
+    %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
-    After lowering `vector.load`:
+    Example 2:
+
     ```mlir
-    %tile_id = arm_sme.get_tile_id : i32
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-    }
-    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate two 16-bit element "virtual tiles"
+    %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+    %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
-    In the example above, the `vector.load` can't be replaced with an SME
-    intrinsic that has no outputs since it is used by the `vector.store`.
-    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
-    the `vector.load` can be replaced. This enables "local" rewrites on
-    individual vector ops, rather than "global" rewrites that would have to
-    look at the vector op uses and also lower them.
-
-    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
-    the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
-  }];
-  let arguments = (ins TileID:$tile_id);
-  let results = (outs SMETile:$vector);
-  let assemblyFormat =
-    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
-  let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from 2-d scalable vector type to tile id";
-  let description = [{
-    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
-    type, which represents an SME "virtual tile", to a tile id. This is
-    required to preserve dataflow as the SME intrinsics have no return values.
-
-    Example:
-
-    Input:
+    Example 3:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate a 128-bit element "virtual tile"
+    %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
+  }];
 
-    After lowering `vector.store`:
-    ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
-      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
+
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
     }
-    ```
 
-    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
-    the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getTileType());
+    }
   }];
-  let arguments = (ins SMETile:$vector);
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat =
-    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
-  let hasCanonicalizeMethod = 1;
 }
 
-def GetTileID : ArmSME_Op<"get_tile_id"> {
-  let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+  let summary = "SME tile placeholder";
   let description = [{
-    A `get_tile_id` operation returns a scalar integer representing an SME
-    "virtual tile" id. The bitwidth of the scalar indicates the element
-    bitwidth of the "virtual tile".
-
-    The scope of a tile id is a function and cannot be passed or returned from
-    functions.
+    A placeholder to preserve dataflow while lowering to SME intrinsics (which
+    do not take or return tile values). This operation is intended to be DCE'd
+    once all ArmSME operations have been lowered.
 
-    Example:
-    ```mlir
-    // Allocate and return an 8-bit element "virtual tile" id
-    %za0_b = arm_sme.get_tile_id : i8
-    ```
-
-    Example:
-    ```
-    // Allocate and return two 16-bit element "virtual tile" ids
-    %za0_h = arm_sme.get_tile_id : i16
-    %za1_h = arm_sme.get_tile_id : i16
-    ```
-
-    Example:
-    ```
-    // Allocate and return an 128-bit element "virtual tile" id
-    %za0_q = arm_sme.get_tile_id : i128
-    ```
+    This operation is not intended to be used outside of the ArmSME -> LLVM
+    conversion.
   }];
-
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat = "attr-dict `:` type($tile_id)";
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
 //
 // Tile reset.
 //
 
-def ZeroOp : ArmSME_Op<"zero", [Pure]> {
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let summary = "Initialize the two-dimensional ZA array with 0s";
   let results = (outs SMETile:$res);
   let description = [{
@@ -303,11 +305,15 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getVectorType());
+    }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
 def TileLoadOp : ArmSME_Op<"tile_load", [
+  ArmSMETileOpInterface,
   AttrSizedOperandSegments,
   OptionalTypesMatchWith<
     "padding type matches element type of result",
@@ -375,6 +381,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getVectorType());
+    }
   }];
 
   let builders = [
@@ -394,6 +403,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store", [
+  ArmSMETileOpInterface,
   AttrSizedOperandSegments,
   HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
 ]> {
@@ -457,6 +467,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
+  ArmSMETileOpInterface,
   AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
 ]> {
   let summary = "Tile slice load and update operation";
@@ -515,6 +526,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 }
 
 def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
+  ArmSMETileOpInterface,
   TileSliceMaskConstraint<"tile", "mask">
 ]> {
   let summary = "Tile slice store operation";
@@ -570,6 +582,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
+    ArmSMETileOpInterface,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -617,7 +630,8 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 }
 
-def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
+def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
+    ArmSMETileOpInterface,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -670,7 +684,8 @@ class OuterProductResultTileTypeConstraint<string operand> :
     "}()">;
 
 def OuterProductOp :
-  ArmSME_Op<"outerproduct", [Pure,
+  ArmSME_Op<"outerproduct", [
+    ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
     HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
@@ -715,7 +730,7 @@ def OuterProductOp :
     ```
   }];
 
-  let arguments = (ins
+let arguments = (ins
     SVEVector:$lhs, SVEVector:$rhs,
     Optional<SVEPredicate>:$lhsMask,
     Optional<SVEPredicate>:$rhsMask,
@@ -736,6 +751,12 @@ def OuterProductOp :
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      // The outerproduct op allocates a new tile if no accumulator is passed.
+      if (!getAcc())
+        return arm_sme::getSMETileType(getResultType());
+      return std::nullopt;
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 9319a8042c93f4a..9801d8b099e3f0a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,6 +15,12 @@ set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
 mlir_tablegen(ArmSMEOpsConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
 
+# Generate op interface declarations and definitions
+set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
+mlir_tablegen(ArmSMEOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ArmSMEOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRArmSMEOpInterfaces)
+
 # Generate declarations and definitions of ArmSME intrinsic Ops
 set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
 mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 0941592497beaae..954371cb9d5a0b1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -15,10 +15,11 @@
 #ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include <optional>
 
-namespace mlir {
-namespace arm_sme {
+namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
 
@@ -34,13 +35,8 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+std::optional<ArmSMETileType> getSMETileType(VectorType);
 
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..4f4b090dd10e3c0 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,6 +32,26 @@ using namespace mlir;
 
 namespace {
 
+IntegerAttr getTileIdOrError(ArmSMETileOpInterface op) {
+  auto tileId = op.getTileId();
+  if (!tileId)
+    op.emitOpError(
+        "expected tile ID to be allocated before conversion to LLVM");
+  return tileId;
+}
+
+struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTile> {
+  using ConvertOpToLLVMPattern<arm_sme::GetTile>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::GetTile getTile, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+        getTile, getTile.getTileType());
+    return success();
+  }
+};
+
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
@@ -61,9 +81,9 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     unsigned tileElementWidth =
         zero.getVectorType().getElementType().getIntOrFloatBitWidth();
 
-    // Get Tile ID for the `zero` intrinsic.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
+    auto tileId = getTileIdOrError(zero);
+    if (!tileId)
+      return failure();
 
     // Get the base mask for tile based on the element size.
     // The base mask is just the mask to zero the first tile (of a size).
@@ -93,9 +113,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
         llvm_unreachable("bad element size");
       }
     }();
-    auto maskType = rewriter.getI32Type();
-    auto baseMask = rewriter.create<arith::ConstantOp>(
-        loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));
 
     // The actual mask is just the base mask shifted by the tile ID.
     // This will be folded to a constant after tile allocation.
@@ -118,13 +135,13 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     //  * Mask    -> 10001000 = (00010001 << 3)
     //
     // This holds for all tile sizes.
-    auto tileMask = rewriter.create<arith::ShLIOp>(
-        loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
-    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
+    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
+    rewriter.create<arm_sme::aarch64_sme_zero>(
+        loc, rewriter.getI32IntegerAttr(zeroMask));
 
-    // Create `CastTileToVectorOp` to use as the output.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
-                                                           tileId);
+    // Create a placeholder op to preserve dataflow.
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+        zero, zero.getVectorType());
 
     return success();
   }
@@ -141,15 +158,9 @@ struct LoadTileSliceConversion
                   arm_sme::LoadTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = loadTileSliceOp.getLoc();
-    auto tileType = loadTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
-
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
-    // loaded to.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        loadTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(loadTileSliceOp);
+    if (!tileId)
+      return failure();
 
     Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
                                            adaptor.getBase(),
@@ -164,7 +175,9 @@ struct LoadTileSliceConversion
     // Create all active predicate mask.
     auto maskOp = loadTileSliceOp.getMask();
 
-    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+    auto tileType = loadTileSliceOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
 
     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
@@ -174,23 +187,23 @@ struct LoadTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 16:
         rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 32:
         rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 64:
         rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 128:
         rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       }
     } else {
@@ -199,31 +212,30 @@ struct LoadTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 16:
         rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 32:
         rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 64:
         rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 128:
         rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       }
     }
 
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
-    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
-                                                           tileType, tile);
+    // the input tile to preserve dataflow.
+    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
 
     return success();
   }
@@ -244,11 +256,9 @@ struct StoreTileSliceConversion
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
-    // being stored.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        storeTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(storeTileSliceOp);
+    if (!tileId)
+      return failure();
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
     Value ptr = this->getStridedElementPtr(
@@ -263,7 +273,6 @@ struct StoreTileSliceConversion
 
     auto maskOp = storeTileSliceOp.getMask();
 
-    Value tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
 
     if (layout == arm_sme::TileSliceLayout::Horizontal) {
@@ -272,23 +281,23 @@ struct StoreTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       }
     } else {
@@ -297,23 +306,23 @@ struct StoreTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       }
     }
@@ -334,14 +343,10 @@ struct MoveVectorToTileSliceConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveVectorToTileSliceOp.getLoc();
     auto tileType = moveVectorToTileSliceOp.getTileType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
-    // loaded to.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        moveVectorToTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
+    if (!tileId)
+      return failure();
 
     auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
 
@@ -357,26 +362,24 @@ struct MoveVectorToTileSliceConversion
                                   /*scalableDims=*/{true});
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
-    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-
     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
     switch (moveVectorToTileSliceOp.getLayout()) {
     case arm_sme::TileSliceLayout::Horizontal:
       rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-          loc, tileI32, tileSliceI32, allActiveMask,
+          loc, tileId, tileSliceI32, allActiveMask,
           moveVectorToTileSliceOp.getVector());
       break;
     case arm_sme::TileSliceLayout::Vertical:
       rewriter.create<arm_sme::aarch64_sme_write_vert>(
-          loc, tileI32, tileSliceI32, allActiveMask,
+          loc, tileId, tileSliceI32, allActiveMask,
           moveVectorToTileSliceOp.getVector());
       break;
     }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
-    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
-        moveVectorToTileSliceOp, tileType, tile);
+    // the input tile to preserve dataflow.
+    rewriter.replaceOp(moveVectorToTileSliceOp,
+                       moveVectorToTileSliceOp.getTile());
 
     return success();
   }
@@ -394,12 +397,11 @@ struct MoveTileSliceToVectorConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveTileSliceToVector.getLoc();
     auto sliceType = moveTileSliceToVector.getSliceType();
-    auto tile = moveTileSliceToVector.getTile();
     auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
 
-    // Cast tile to i32 tile ID.
-    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
-    Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+    auto tileId = getTileIdOrError(moveTileSliceToVector);
+    if (!tileId)
+      return failure();
 
     // Create an 'all true' predicate for the tile slice.
     auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
@@ -419,12 +421,12 @@ struct MoveTileSliceToVectorConversion
     case arm_sme::TileSliceLayout::Horizontal:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
           moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-          tileIdI32, sliceIndexI32);
+          tileId, sliceIndexI32);
       break;
     case arm_sme::TileSliceLayout::Vertical:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
           moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-          tileIdI32, sliceIndexI32);
+          tileId, sliceIndexI32);
       break;
     }
 
@@ -454,6 +456,10 @@ struct OuterProductOpConversion
   matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                   arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto tileId = getTileIdOrError(outerProductOp);
+    if (!tileId)
+      return failure();
+
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
       // different features [1]:
@@ -498,13 +504,8 @@ struct OuterProductOpConversion
     Value acc = outerProductOp.getAcc();
     if (!acc)
       // Initalize accumulator with zero.
-      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
-
-    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
-    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(elementWidth), acc);
-
-    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+      acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+          rewriter, loc, resultVectorType);
 
     Value lhsMask = outerProductOp.getLhsMask();
     Value rhsMask = outerProductOp.getRhsMask();
@@ -519,13 +520,13 @@ struct OuterProductOpConversion
     }
 
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                                outerProductOp.getLhs(),
                                                outerProductOp.getRhs());
 
-    // Create `CastTileToVectorOp` to use as the output.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
-        outerProductOp, resultVectorType, tileId);
+    // The outerproduct intrinsics have no result, replace
+    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
+    rewriter.replaceOp(outerProductOp, acc);
 
     return success();
   }
@@ -557,21 +558,20 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
-      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
-      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
-      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
-      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
-      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
-      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
-      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
-      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
-      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
-      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
-      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
-      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
-      arm_sme::aarch64_sme_mopa>();
+      arm_sme::MaterializeSSATile, arm_sme::aarch64_sme_zero,
+      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
@@ -580,7 +580,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(
     ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-               OuterProductOpConversion, ZeroOpConversion>(converter);
+               OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
+      converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index a4dfd4b7edc77c5..541d711fbd95f29 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -87,16 +87,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
     auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.load_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+    // Allocate a new SME tile.
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+        rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -128,8 +122,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+        rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
         memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -209,7 +203,8 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // Initialize tile with zero to satisfy padding. Inactive cols will be
     // zeroed anyway since the loads use zeroing predication. For inactive rows
     // however, no load will occur so these need to be zeroed.
-    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+        rewriter, loc, tileType);
 
     // Create a loop to load the active tile slices from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -226,9 +221,9 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
-        tileSliceIndex, tileLoadOp.getLayout());
+    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+        rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
+        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -276,7 +271,6 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
     auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
     auto maskOp = tileLoadOp.getMask();
     if (!maskOp)
@@ -304,14 +298,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), numCols);
 
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.load_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+    // Allocate a new SME tile.
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+        rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -356,8 +345,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
+    tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+        rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
         tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -450,8 +439,9 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+
+    tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+        rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
         tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
 
     return success();
@@ -506,6 +496,7 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
       rewriter.setInsertionPointToStart(forOp.getBody());
       // Extract the current row from the tile.
       Value rowIndex = forOp.getInductionVar();
+      // FIXME: Forward tile IDs.
       auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
           loc, printOp.getSource(), rowIndex);
       // Print the row with a 1D vector.print.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index f1e92d1ac9708b7..109b04ce34a88b7 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -44,20 +44,6 @@ static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
   return forOp;
 }
 
-/// Returns a tile of the given vector type.
-static arm_sme::CastTileToVector
-getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
-                          VectorType type) {
-  unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
-
-  // Create 'arm_sme.get_tile' op.
-  auto tileId = rewriter.create<arm_sme::GetTileID>(
-      loc, rewriter.getIntegerType(tileElementWidth));
-
-  // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
-  return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
-}
-
 namespace {
 
 /// Conversion pattern for vector.transfer_read.
@@ -267,8 +253,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
 
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
     auto tileSliceIndex = forOp.getInductionVar();
@@ -330,8 +315,7 @@ struct BroadcastOpToArmSMELowering
     else
       return failure();
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
 
     // Create a loop over ZA tile slices.
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
@@ -387,8 +371,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
         loc, tileSliceType, splatOp.getInput());
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 9df15420b9c9b6f..29fa9085a0a963d 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -28,6 +28,8 @@ using namespace mlir::arm_sme;
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOps.cpp.inc"
 
@@ -54,23 +56,3 @@ void ArmSMEDialect::initialize() {
 #include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
       >();
 }
-
-// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id
-LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
-                                             PatternRewriter &rewriter) {
-  if (auto castTileToVectorOp = op.getVector().getDefiningOp<CastTileToVector>()) {
-    op.replaceAllUsesWith(castTileToVectorOp.getTileId());
-    return success();
-  }
-  return failure();
-}
-
-// cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile
-LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
-                                             PatternRewriter &rewriter) {
-  if (auto castVectorToTileOp = op.getTileId().getDefiningOp<CastVectorToTile>()) {
-    op.replaceAllUsesWith(castVectorToTileOp.getVector());
-    return success();
-  }
-  return failure();
-}
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 3e448ec4fb1e04d..66062335fa842a4 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRVectorDialect
+  MLIRArmSMEUtils
 )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e0462a6dc124131..0a65efeb266d15b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -41,10 +41,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "allocate-arm-sme-tiles"
 
@@ -107,73 +109,85 @@ enum class TileMask : unsigned {
 };
 
 /// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
-  static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
-  static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
-                                                   TileMask::kZA1H};
-  static const SmallVector<TileMask> ZA_S_MASKS = {
-      TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
-  static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+                                            TileMask::kZA2S, TileMask::kZA3S};
+  static constexpr std::array ZA_D_MASKS = {
       TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
       TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
-  static const SmallVector<TileMask> ZA_Q_MASKS = {
+  static constexpr std::array ZA_Q_MASKS = {
       TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
       TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
       TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
       TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
-  switch (cast<IntegerType>(type).getWidth()) {
-  default:
-    llvm_unreachable("unexpected type!");
-  case 8:
+  switch (type) {
+  case ArmSMETileType::ZAB:
     return ZA_B_MASKS;
-  case 16:
+  case ArmSMETileType::ZAH:
     return ZA_H_MASKS;
-  case 32:
+  case ArmSMETileType::ZAS:
     return ZA_S_MASKS;
-  case 64:
+  case ArmSMETileType::ZAD:
     return ZA_D_MASKS;
-  case 128:
+  case ArmSMETileType::ZAQ:
     return ZA_Q_MASKS;
   }
 }
 
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
-                             unsigned &tileID) {
-  auto masks = getMasks(tileIDOp.getType());
-  for (const auto &it : llvm::enumerate(masks)) {
-    const auto tileMask = it.value();
+/// Allocates and returns a tile ID, or failure if there are no tiles left.
+static FailureOr<unsigned> getTile(ArmSMETileType tileType,
+                                   TileMask &tilesInUse) {
+  auto masks = getMasks(tileType);
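+  // The index of the first unused mask is the allocated tile ID (e.g. for
+  // ZAS, masks ZA0S..ZA3S correspond to tile IDs 0..3).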
+  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
     if ((tilesInUse & tileMask) == TileMask::kNone) {
       tilesInUse |= tileMask;
-      tileID = it.index();
-      return success();
+      return tileId;
     }
   }
-  return tileIDOp.emitError("ran out of SME virtual tiles!");
+  return failure();
 }
 
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+struct AssignTileIDsPattern
+    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
                                 PatternRewriter &rewriter) const override {
-    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
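+    // Nothing to do if this op has already been assigned a tile ID.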
+    if (tileOp.getTileId())
+      return failure();
+
+    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+    if (!tileType)
+      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
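+    // Tiles currently in use are tracked via an attribute on the enclosing
+    // function, updated below once a tile is allocated.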
     TileMask tilesInUse;
-    if (auto tilesInUseAttr =
-            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+    if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
       tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
     else
       tilesInUse = TileMask::kNone;
 
-    unsigned tileID;
-    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
-      return failure();
+    auto tileId = getTile(*tileType, tilesInUse);
+    if (failed(tileId))
+      return tileOp.emitError("ran out of SME virtual tiles!");
 
-    funcOp->setAttr(kTilesInUseAttr,
-                    rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+    func->setAttr(kTilesInUseAttr,
+                  rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+
+    // Find all the ops that (transitively) depend on this tile.
+    SetVector<Operation *> dependantOps;
+    getForwardSlice(tileOp.getOperation(), &dependantOps);
+
+    // Set all operations to use the same tile ID.
+    // This is a naive tile allocation scheme, but it works for common cases.
+    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
+    tileOp.setTileId(tileIDAttr);
+    for (auto *op : dependantOps) {
+      if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op))
+        tileOp.setTileId(tileIDAttr);
+    }
 
-    auto tileType = tileIDOp.getType();
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-        tileIDOp, tileType, rewriter.getIntegerAttr(tileType, tileID));
     return success();
   }
 };
@@ -182,13 +196,14 @@ struct TileAllocationPass
     : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    ConversionTarget target(getContext());
-    patterns.add<GetTileIDConversion>(patterns.getContext());
-    target.addLegalOp<arith::ConstantOp>();
-    target.addIllegalOp<GetTileID>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    patterns.add<AssignTileIDsPattern>(patterns.getContext());
+    GreedyRewriteConfig config;
+    // This ensures tiles are allocated in program order.
+    config.useTopDownTraversal = true;
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(patterns), config))) {
       signalPassFailure();
+    }
   }
 };
 } // namespace
diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
index da8517aaf80a9fa..ecf774a215d24f8 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
@@ -5,7 +5,5 @@ add_mlir_dialect_library(MLIRArmSMEUtils
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
 
   LINK_LIBS PUBLIC
-  MLIRArmSMEDialect
-  MLIRDialect
   MLIRIR
   )
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index f17077ff8565d59..c3cdf5703bcbc28 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,24 +12,20 @@
 
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+namespace mlir::arm_sme {
 
-using namespace mlir;
-using namespace mlir::arm_sme;
-
-unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+unsigned getSMETileSliceMinNumElts(Type type) {
   assert(isValidSMETileElementType(type) && "invalid tile type!");
   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
 }
 
-bool mlir::arm_sme::isValidSMETileElementType(Type type) {
+bool isValidSMETileElementType(Type type) {
   return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
          type.isInteger(64) || type.isInteger(128) || type.isF16() ||
          type.isBF16() || type.isF32() || type.isF64() || type.isF128();
 }
 
-bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
+bool isValidSMETileVectorType(VectorType vType) {
   if ((vType.getRank() != 2) || !vType.allDimsScalable())
     return false;
 
@@ -37,22 +33,30 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
   if (!isValidSMETileElementType(elemType))
     return false;
 
-  unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
   if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
     return false;
 
   return true;
 }
 
-Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
-                                     RewriterBase &rewriter) {
-  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
-             tile.getDefiningOp())) &&
-         "expected ArmSME GetTileID or CastVectorToTile op!");
-  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
-  if (tileElementWidth < 32)
-    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
-  if (tileElementWidth > 32)
-    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
-  return tile;
+std::optional<ArmSMETileType> getSMETileType(VectorType type) {
+  if (!isValidSMETileVectorType(type))
+    return {};
+  switch (type.getElementTypeBitWidth()) {
+  case 8:
+    return ArmSMETileType::ZAB;
+  case 16:
+    return ArmSMETileType::ZAH;
+  case 32:
+    return ArmSMETileType::ZAS;
+  case 64:
+    return ArmSMETileType::ZAD;
+  case 128:
+    return ArmSMETileType::ZAQ;
+  default:
+    llvm_unreachable("unknown SME tile type");
+  }
 }
+
+} // namespace mlir::arm_sme
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 38a2332ffd5e0a3..fc28645a7acf7c0 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,8 +6,7 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -16,7 +15,7 @@
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
 // CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -60,8 +59,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
 // CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
 // CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
-// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -79,7 +77,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
 // CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index 12f7e7333ebc973..b8db105f9c601b4 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -95,8 +95,7 @@ func.func @arith_constant_dense_2d_zero_f64() {
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C16:.*]] = arith.constant 16 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -115,8 +114,7 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i64 to vector<[2]x[2]xf64>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
deleted file mode 100644
index 65996e81c42d909..000000000000000
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ /dev/null
@@ -1,51 +0,0 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
-
-// This test verifies the temporary casts that are emitted when lowering to
-// intrinsics to preserve data flow are correct. Canonicalization will remove
-// these.
-
-// CHECK-LABEL: @arm_sme_zero
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: arm_sme.intr.zero
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  %tile = arm_sme.zero : vector<[16]x[16]xi8>
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_load
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-// CHECK: }
-// CHECK: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
-  %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return %tile : vector<[16]x[16]xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_store(
-// CHECK-SAME:                      %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return
-}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index fa62332bc3f5b17..bd88da37bdf966d 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
@@ -9,23 +9,21 @@
 // CHECK-LABEL:   func.func @arm_sme_load_tile_slice_hor_i8(
 // CHECK-SAME:                                              %[[SRC:.*]]: memref<?x?xi8>,
 // CHECK-SAME:                                              %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME:                                              %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index)
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -33,9 +31,10 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -43,9 +42,10 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -53,9 +53,10 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -63,9 +64,10 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -73,9 +75,10 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -83,9 +86,10 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -93,9 +97,10 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -103,9 +108,10 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -113,9 +119,10 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -123,9 +130,10 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -133,9 +141,10 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -143,9 +152,10 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -153,9 +163,10 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -163,9 +174,10 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -173,9 +185,10 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -183,9 +196,10 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -193,9 +207,10 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -207,25 +222,23 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // -----
 
 // CHECK-LABEL:   func.func @arm_sme_store_tile_slice_hor_i8(
-// CHECK-SAME:                                               %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                               %[[TILE_SLICE_INDEX:.*]]: index,
 // CHECK-SAME:                                               %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>)
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -233,9 +246,10 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -243,9 +257,10 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -253,9 +268,10 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -263,9 +279,10 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -273,9 +290,10 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -283,9 +301,10 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -293,9 +312,10 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -303,9 +323,10 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -313,9 +334,10 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i8(%tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -323,9 +345,10 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -333,9 +356,10 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -343,9 +367,10 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -353,9 +378,10 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -363,9 +389,10 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -373,9 +400,10 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -383,9 +411,10 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -393,9 +422,10 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -407,9 +437,10 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
   return
 }
@@ -417,9 +448,10 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return
 }
@@ -431,8 +463,9 @@ func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile_slice_index : index) -> vector<[16]xi8> {
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
@@ -440,8 +473,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile_slice_index : index) -> vector<[8]xi16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
@@ -449,8 +483,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile_slice_index : index) -> vector<[4]xi32> {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
@@ -458,8 +493,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile_slice_index : index) -> vector<[2]xi64> {
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
@@ -467,8 +503,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
@@ -476,8 +513,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile_slice_index : index) -> vector<[8]xf16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
@@ -485,8 +523,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile_slice_index : index) -> vector<[8]xbf16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
@@ -494,8 +533,9 @@ func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile_slice_index : index) -> vector<[4]xf32> {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
@@ -503,8 +543,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile_slice_index : index) -> vector<[2]xf64> {
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
@@ -512,8 +553,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
-// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index 06bbd3050fdecec..b7ba3f728c705a3 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,27 +1,25 @@
 // RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-// -----
-
-// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
-// CHECK-SAME: %[[TILE_ID:.*]]: i8
-func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_id_0 : i8) -> i8 {
-  // CHECK-NOT: arm_sme.cast_tile_to_vector
-  // CHECK-NOT: arm_sme.cast_vector_to_tile
-  // CHECK-NEXT: return %[[TILE_ID]] : i8
-  %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
-  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
-  return %tile_id_1 : i8
-}
+// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
+// once it becomes unused, after lowering to control flow.
 
 // -----
 
-// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
-func.func @cast_tile_to_vector__cast_vector_to_tile(%tile_0 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
-  // CHECK-NOT: arm_sme.cast_vector_to_tile
-  // CHECK-NOT: arm_sme.cast_tile_to_vector
-  // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
-  %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[16]x[16]xi8> to i8
-  %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
-  return %tile_1 : vector<[16]x[16]xi8>
+// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
+// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-NOT: vector<[4]x[4]xf32>
+func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+  %c10 = arith.constant 10 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+  cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
+^bb1(%1: index, %2: vector<[4]x[4]xf32>):  // 2 preds: ^bb0, ^bb2
+  %3 = arith.cmpi slt, %1, %c10 : index
+  cf.cond_br %3, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  %4 = arith.addi %1, %c1 : index
+  cf.br ^bb1(%4, %tile : index, vector<[4]x[4]xf32>)
+^bb3:  // pred: ^bb1
+  return
 }
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
index 734bd9f15c8dea1..74e7293eaeca5fc 100644
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ b/mlir/test/Dialect/ArmSME/cse.mlir
@@ -1,16 +1,30 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
 
-// This test is checking that CSE does not remove 'arm_sme.get_tile_id' ops as
+// This test checks that CSE does not remove 'arm_sme.zero/get_tile' ops as
 // duplicates.
-// CHECK-LABEL: @get_tile_id
-// CHECK: %[[TILE_ID_0:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE_ID_1:.*]] = arm_sme.get_tile_id : i32
-// CHECK: "prevent.dce"(%[[TILE_ID_0]]) : (i32) -> ()
-// CHECK: "prevent.dce"(%[[TILE_ID_1]]) : (i32) -> ()
-func.func @get_tile_id() {
-  %tile_id_1 = arm_sme.get_tile_id : i32
-  %tile_id_2 = arm_sme.get_tile_id : i32
-  "prevent.dce"(%tile_id_1) : (i32) -> ()
-  "prevent.dce"(%tile_id_2) : (i32) -> ()
+
+// CHECK-LABEL: @zero_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @zero_tile() {
+  %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
+  %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// CHECK-LABEL: @get_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @get_tile() {
+  %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index d35eb7f7bccc75e..85b95a8b6cf12b7 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,89 +1,49 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
 //===----------------------------------------------------------------------===//
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
-  // expected-error at +1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
-  return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_rank_1() -> vector<[16]xi8> {
   // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+  %0 = arm_sme.get_tile : vector<[16]xi8>
   return %0 : vector<[16]xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+func.func @arm_sme_get_tile__bad_vector_type_i4() -> vector<[16]x[16]xi4> {
   // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+  %0 = arm_sme.get_tile : vector<[16]x[16]xi4>
   return %0 : vector<[16]x[16]xi4>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_0() -> vector<16x[16]xi8> {
   // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+  %0 = arm_sme.get_tile : vector<16x[16]xi8>
   return %0 : vector<16x[16]xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_1() -> vector<[16]x16xi8> {
   // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+  %0 = arm_sme.get_tile : vector<[16]x16xi8>
   return %0 : vector<[16]x16xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+func.func @arm_sme_get_tile__bad_shape() -> vector<[4]x[16]xi8> {
   // expected-error at +1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+  %0 = arm_sme.get_tile : vector<[4]x[16]xi8>
   return %0 : vector<[4]x[16]xi8>
 }
 
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
-  // expected-error at +1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
-  return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
-  // expected-error at +1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
-  return %0 : i8
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id__bad_type() -> i1 {
-  // expected-error at +1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
-  %0 = arm_sme.get_tile_id : i1
-  return %0 : i1
-}
-
 //===----------------------------------------------------------------------===//
 // arm_sme.move_vector_to_tile_slice
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 1dbcc32e9259fc5..58ff7ef4d8340ec 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,197 +1,78 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
 //===----------------------------------------------------------------------===//
 
-func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
-  return %0 : vector<[16]x[16]xi8>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i16(%tile_id : i16) -> vector<[8]x[8]xi16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
-  return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i32(%tile_id : i32) -> vector<[4]x[4]xi32> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-  return %0 : vector<[4]x[4]xi32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i64(%tile_id : i64) -> vector<[2]x[2]xi64> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xi64>
-  return %0 : vector<[2]x[2]xi64>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i128(%tile_id : i128) -> vector<[1]x[1]xi128> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i128 to vector<[1]x[1]xi128>
-  return %0 : vector<[1]x[1]xi128>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f16(%tile_id : i16) -> vector<[8]x[8]xf16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xf16>
-  return %0 : vector<[8]x[8]xf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_bf16(%tile_id : i16) -> vector<[8]x[8]xbf16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xbf16>
-  return %0 : vector<[8]x[8]xbf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f32(%tile_id : i32) -> vector<[4]x[4]xf32> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xf32>
-  return %0 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xf64>
-  return %0 : vector<[2]x[2]xf64>
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8
-  return %0 : i8
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i16(%vector : vector<[8]x[8]xi16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xi16> to i16
-  return %0 : i16
-}
 
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i32(%vector : vector<[4]x[4]xi32>) -> i32 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xi32> to i32
-  return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i64(%vector : vector<[2]x[2]xi64>) -> i64 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xi64> to i64
-  return %0 : i64
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i128(%vector : vector<[1]x[1]xi128>) -> i128 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i128
-  return %0 : i128
+func.func @arm_sme_get_tile_i8() {
+  // CHECK: arm_sme.get_tile : vector<[16]x[16]xi8>
+  %0 = arm_sme.get_tile : vector<[16]x[16]xi8>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f16(%vector : vector<[8]x[8]xf16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xf16> to i16
-  return %0 : i16
+func.func @arm_sme_get_tile_i16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xi16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xi16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_bf16(%vector : vector<[8]x[8]xbf16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xbf16> to i16
-  return %0 : i16
+func.func @arm_sme_get_tile_i32() {
+  // CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
+  %0 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f32(%vector : vector<[4]x[4]xf32>) -> i32 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xf32> to i32
-  return %0 : i32
+func.func @arm_sme_get_tile_i64() {
+  // CHECK: arm_sme.get_tile : vector<[2]x[2]xi64>
+  %0 = arm_sme.get_tile : vector<[2]x[2]xi64>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xf64> to i64
-  return %0 : i64
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id_i8() -> i8 {
-  // CHECK: arm_sme.get_tile_id : i8
-  %0 = arm_sme.get_tile_id : i8
-  return %0 : i8
+func.func @arm_sme_get_tile_i128() {
+  // CHECK: arm_sme.get_tile : vector<[1]x[1]xi128>
+  %0 = arm_sme.get_tile : vector<[1]x[1]xi128>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i16() -> i16 {
-  // CHECK: arm_sme.get_tile_id : i16
-  %0 = arm_sme.get_tile_id : i16
-  return %0 : i16
+func.func @arm_sme_get_tile_f16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xf16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xf16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i32() -> i32 {
-  // CHECK: arm_sme.get_tile_id : i32
-  %0 = arm_sme.get_tile_id : i32
-  return %0 : i32
+func.func @arm_sme_get_tile_bf16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xbf16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xbf16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i64() -> i64 {
-  // CHECK: arm_sme.get_tile_id : i64
-  %0 = arm_sme.get_tile_id : i64
-  return %0 : i64
+func.func @arm_sme_get_tile_f32() {
+  // CHECK: arm_sme.get_tile : vector<[4]x[4]xf32>
+  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i128() -> i128 {
-  // CHECK: arm_sme.get_tile_id : i128
-  %0 = arm_sme.get_tile_id : i128
-  return %0 : i128
+func.func @arm_sme_get_tile_f64() {
+  // CHECK: arm_sme.get_tile : vector<[2]x[2]xf64>
+  %0 = arm_sme.get_tile : vector<[2]x[2]xf64>
+  return
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
index a481516d4c15f93..1f895e4984ba844 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -6,17 +6,17 @@
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
 func.func @mixed_tiles() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.D ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // ZA15.Q is still free.
   return
 }
@@ -26,28 +26,26 @@ func.func @mixed_tiles() {
 // CHECK-LABEL: za_b
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_b() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_b = arm_sme.get_tile_id : i8
+  // CHECK-NEXT: tile_id = 0
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @za_b__out_of_tiles() {
-  %za0_b = arm_sme.get_tile_id : i8
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i8
+  %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @za_b_overlapping_za_q() {
-  %za0_b = arm_sme.get_tile_id : i8
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -56,8 +54,8 @@ func.func @za_b_overlapping_za_q() {
 // CHECK-LABEL: za0_h
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
 func.func @za0_h() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
@@ -66,21 +64,23 @@ func.func @za0_h() {
 // CHECK-LABEL: za_h
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
-  // CHECK-NEXT: arith.constant 1
-  %za1_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  // CHECK-NEXT: tile_id = 1
+  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_h__out_of_tiles
 func.func @za_h__out_of_tiles() {
-  %za0_h = arm_sme.get_tile_id : i16
-  %za1_h = arm_sme.get_tile_id : i16
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  // CHECK-NEXT: tile_id = 1
+  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i16
+  %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
@@ -90,14 +90,14 @@ func.func @za_h__out_of_tiles() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_s() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 3
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -107,38 +107,37 @@ func.func @za_h_overlapping_za_s() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_d() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA9.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 1
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA5.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 5
-  %za5_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 5
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_h_overlapping_za_q() {
-  %za0_h = arm_sme.get_tile_id : i16
-  %za0_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za8_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -147,8 +146,8 @@ func.func @za_h_overlapping_za_q() {
 // CHECK-LABEL: za0_s
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
 func.func @za0_s() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -157,27 +156,26 @@ func.func @za0_s() {
 // CHECK-LABEL: za_s
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 2
-  %za2_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 3
-  %za3_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 2
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 3
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @za_s__out_of_tiles() {
-  %za0_s = arm_sme.get_tile_id : i32
-  %za1_s = arm_sme.get_tile_id : i32
-  %za2_s = arm_sme.get_tile_id : i32
-  %za3_s = arm_sme.get_tile_id : i32
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i32
+  %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -187,42 +185,41 @@ func.func @za_s__out_of_tiles() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s_overlapping_za_d() {
   // ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 2
-  %za2_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 2
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_s_overlapping_za_q() {
-  %za0_s = arm_sme.get_tile_id : i32
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -231,8 +228,8 @@ func.func @za_s_overlapping_za_q() {
 // CHECK-LABEL: za0_d
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
 func.func @za0_d() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 0
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
@@ -241,63 +238,61 @@ func.func @za0_d() {
 // CHECK-LABEL: za_d
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_d() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 1
-  %za1_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 2
-  %za2_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 4
-  %za4_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 5
-  %za5_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 6
-  %za6_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 0
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 1
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 2
+  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 4
+  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 5
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 6
+  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_d__out_of_tiles() {
-  %za0_d = arm_sme.get_tile_id : i64
-  %za1_d = arm_sme.get_tile_id : i64
-  %za2_d = arm_sme.get_tile_id : i64
-  %za3_d = arm_sme.get_tile_id : i64
-  %za4_d = arm_sme.get_tile_id : i64
-  %za5_d = arm_sme.get_tile_id : i64
-  %za6_d = arm_sme.get_tile_id : i64
-  %za7_d = arm_sme.get_tile_id : i64
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i64
+  %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_d_overlapping_za_q() {
-  %za0_d = arm_sme.get_tile_id : i64
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -306,8 +301,8 @@ func.func @za_d_overlapping_za_q() {
 // CHECK-LABEL: za0_q
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
 func.func @za0_q() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 0
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -316,62 +311,61 @@ func.func @za0_q() {
 // CHECK-LABEL: za_q
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_q() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 1
-  %za1_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 2
-  %za2_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 3
-  %za3_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 4
-  %za4_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 5
-  %za5_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 6
-  %za6_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 7
-  %za7_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 8
-  %za8_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 9
-  %za9_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 10
-  %za10_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 11
-  %za11_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 12
-  %za12_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 13
-  %za13_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 14
-  %za14_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 15
-  %za15_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 0
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 1
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 2
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 4
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 6
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 8
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 10
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 12
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 14
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @za_q__out_of_tiles() {
-  %za0_q = arm_sme.get_tile_id : i128
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za8_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error at +2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error at +1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 2378f4234aef1ef..04412e4db1c5f3a 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,7 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
-// RUN:           -allocate-arm-sme-tiles -canonicalize      \
-// RUN:           -allow-unregistered-dialect                \
-// RUN: | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
 
 // This test verifies the tile mask operand of the zero intrinsic zeroes
 // the correct tiles. Both integer and floating-point datatypes are checked.
@@ -10,13 +7,8 @@
 
 // CHECK-LABEL: zero_za_b
 func.func @zero_za_b() {
-  // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8
-  // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
   %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
-  "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -24,20 +16,10 @@ func.func @zero_za_b() {
 
 // CHECK-LABEL: zero_za_h
 func.func @zero_za_h() {
-  // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16
-  // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
   %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
-  "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
   %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
-  "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -45,32 +27,14 @@ func.func @zero_za_h() {
 
 // CHECK-LABEL: zero_za_s
 func.func @zero_za_s() {
-  // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
   %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
   %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
   %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
   %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
-  "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -78,55 +42,21 @@ func.func @zero_za_s() {
 
 // CHECK-LABEL: zero_za_d
 func.func @zero_za_d() {
-  // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
   %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
   %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
   %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
   %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
   %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
   %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
   %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
   %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
-  "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
   return
 }
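
For readers skimming the patch, here is a quick sketch (not part of the diff) of where the `tile_mask` constants checked above come from: each bit of the i32 mask passed to `arm_sme.intr.zero` selects one of the eight 64-bit sub-tiles ZA0.D–ZA7.D, and a wider-element tile covers several of them (e.g. ZA0.H covers the even ones, mask 0b01010101 = 85, and ZA1.H the odd ones, 170). A hypothetical helper illustrating the pattern:

```c++
#include <cstdint>

// Hypothetical helper (not from the patch): derive the zero-instruction mask
// for tile `tileId` of a given element width, matching the constants above.
static uint32_t zeroMaskFor(unsigned elementBits, unsigned tileId) {
  uint32_t base;
  switch (elementBits) {
  case 8:  base = 0b11111111; break; // ZA0.B covers all eight ZA*.D tiles -> 255
  case 16: base = 0b01010101; break; // ZA0.H -> 85, ZA1.H -> 170
  case 32: base = 0b00010001; break; // ZA0.S..ZA3.S -> 17, 34, 68, 136
  case 64: base = 0b00000001; break; // ZA0.D..ZA7.D -> 1, 2, 4, ..., 128
  default: return 0;
  }
  return base << tileId; // Shifting selects the n-th tile of that width.
}
```

This is the same `base << tile_id` computation the old lowering emitted at runtime (see the removed `arith.shli %c255, %tile_id` below); with the attribute-based scheme the mask is a compile-time constant folded directly into the intrinsic.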
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 77ac071ef67de9a..d16d250b70eb3f6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
@@ -10,13 +10,9 @@
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
-// CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-DAG:  %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
-// CHECK-DAG:  "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
+// CHECK-DAG:  "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
 // CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
@@ -27,8 +23,7 @@
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -49,8 +44,6 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-LABEL: @vector_load_i8_with_offset(
 // CHECK-SAME:                              %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[C123:.*]] = arith.constant 123 : index
@@ -68,10 +61,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   %c123 = arith.constant 123 : index
@@ -84,8 +75,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
 // CHECK-SAME:                                     %[[ARG0:.*]]: memref<?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -98,10 +87,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
@@ -113,12 +100,10 @@ func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16
 
 // CHECK-LABEL: @vector_load_i16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -129,13 +114,10 @@ func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
 
 // CHECK-LABEL: @vector_load_i32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT:   arith.extui %[[TILE_ID]]
-// CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -146,12 +128,10 @@ func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
 
 // CHECK-LABEL: @vector_load_i64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -162,12 +142,10 @@ func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
 
 // CHECK-LABEL: @vector_load_f16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -178,12 +156,10 @@ func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
 
 // CHECK-LABEL: @vector_load_bf16(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -194,13 +170,10 @@ func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
 
 // CHECK-LABEL: @vector_load_f32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT:   arith.extui %[[TILE_ID]]
-// CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -211,12 +184,10 @@ func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
 
 // CHECK-LABEL: @vector_load_f64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -227,10 +198,8 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 
 // CHECK-LABEL: @vector_load_i128(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128>
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32
 // CHECK:       arm_sme.intr.ld1q.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -244,7 +213,6 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // -----
 
 // CHECK-LABEL: @vector_store_i8(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
@@ -256,19 +224,18 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: return
-func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>) {
+func.func @vector_store_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
@@ -276,15 +243,14 @@ func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>)
 // -----
 
 // CHECK-LABEL: @vector_store_i16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xi16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i16(%arg0 : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
@@ -292,16 +258,14 @@ func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>
 // -----
 
 // CHECK-LABEL: @vector_store_i32(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
 // CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
-func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i32(%arg0 : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
@@ -309,15 +273,14 @@ func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>
 // -----
 
 // CHECK-LABEL: @vector_store_i64(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xi64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
 // CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
-func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i64(%arg0 : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
@@ -325,15 +288,14 @@ func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>
 // -----
 
 // CHECK-LABEL: @vector_store_f16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f16(%arg0 : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
@@ -341,31 +303,28 @@ func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>
 // -----
 
 // CHECK-LABEL: @vector_store_bf16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xbf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xbf16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_bf16(%arg0 : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 // -----
 
 // CHECK-LABEL: @vector_store_f32(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xf32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
 // CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
-func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f32(%arg0 : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
@@ -373,15 +332,14 @@ func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>
 // -----
 
 // CHECK-LABEL: @vector_store_f64(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xf64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
 // CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
-func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f64(%arg0 : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
@@ -389,13 +347,12 @@ func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>
 // -----
 
 // CHECK-LABEL: @vector_store_i128(
-// CHECK-SAME:                     %[[TILE:.*]]: vector<[1]x[1]xi128>,
 // CHECK-SAME:                     %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32
 // CHECK:       arm_sme.intr.st1q.horiz
-func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i128(%arg0 : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
@@ -407,12 +364,11 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
-func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>) {
   // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
-  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
-  // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
-  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  // CHECK: "arm_sme.intr.mopa"(%[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
 }
@@ -420,8 +376,9 @@ func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_bf16
-func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
 }
@@ -429,10 +386,9 @@ func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f32
-func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
-  // CHECK-NOT: arith.extui
-  // CHECK-NOT: arith.trunci
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -440,9 +396,9 @@ func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f64
-func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
-  // CHECK: arith.trunci {{.*}} : i64 to i32
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -451,8 +407,8 @@ func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]
 
 // CHECK-LABEL: @vector_outerproduct_no_accumulator
 func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
-  // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
   %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -460,12 +416,12 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %dim0 : index, %dim1 : index) {
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
-  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
-  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  // CHECK: "arm_sme.intr.mopa"(%[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
@@ -474,11 +430,12 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
@@ -487,11 +444,12 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
-func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
@@ -500,11 +458,12 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
-func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
@@ -520,7 +479,8 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 
 // -----
 
-func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
+  %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error at +1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
@@ -529,7 +489,8 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   // expected-error at +1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
@@ -537,8 +498,9 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
   // CHECK: vector.outerproduct
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -550,14 +512,13 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i32(
-// CHECK-SAME:                       %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                       %[[SLICE:.*]]: vector<[4]xi32>,
 // CHECK-SAME:                       %[[INDEX:.*]]: index)
-func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
+func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-  // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
-  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
@@ -565,8 +526,9 @@ func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i8
-func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
@@ -574,8 +536,9 @@ func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[1
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i16
-func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
@@ -583,8 +546,9 @@ func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i64
-func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
@@ -592,8 +556,9 @@ func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i128
-func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
@@ -601,8 +566,9 @@ func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f16
-func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
@@ -610,8 +576,9 @@ func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_bf16
-func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
@@ -619,8 +586,9 @@ func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f32
-func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
@@ -628,8 +596,9 @@ func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f64
-func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
@@ -637,19 +606,18 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i32(
-// CHECK-SAME:                         %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                         %[[EL:.*]]: i32,
 // CHECK-SAME:                         %[[ROW:.*]]: index,
 // CHECK-SAME:                         %[[COL:.*]]: index)
-func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
-  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-  // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
+  // CHECK-DAG:  %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-DAG:  %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-DAG:  %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
   // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
-  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
@@ -657,9 +625,10 @@ func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i8
-func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
@@ -667,9 +636,10 @@ func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i16
-func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
@@ -677,9 +647,10 @@ func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i64
-func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
@@ -687,9 +658,10 @@ func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i128
-func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
@@ -697,9 +669,10 @@ func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %r
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f16
-func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
@@ -707,9 +680,10 @@ func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_bf16
-func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
@@ -717,9 +691,10 @@ func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %r
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f32
-func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
@@ -727,9 +702,10 @@ func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f64
-func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
@@ -741,14 +717,13 @@ func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row:
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i32(
-// CHECK-SAME:                        %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                        %[[INDEX:.*]]: index)
-func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) -> vector<[4]xi32> {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> {
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_SLICE_INDEX]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
@@ -756,8 +731,9 @@ func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i8
-func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
@@ -765,8 +741,9 @@ func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i16
-func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
@@ -774,8 +751,9 @@ func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i64
-func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) -> vector<[2]xi64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
@@ -783,8 +761,9 @@ func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i128
-func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
@@ -792,8 +771,9 @@ func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f16
-func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) -> vector<[8]xf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
@@ -801,8 +781,9 @@ func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_bf16
-func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -> vector<[8]xbf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
@@ -810,8 +791,9 @@ func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f32
-func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
@@ -819,8 +801,9 @@ func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f64
-func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
@@ -828,16 +811,15 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_element(
-// CHECK-SAME:                          %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                          %[[ROW:.*]]: index,
 // CHECK-SAME:                          %[[COL:.*]]: index)
-func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_element(%row: index, %col: index) -> i32 {
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
   return %el : i32
 }
@@ -845,9 +827,10 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i8
-func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
   return %el : i8
 }
@@ -855,9 +838,10 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i16
-func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
   return %el : i16
 }
@@ -865,9 +849,10 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i64
-func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
   return %el : i64
 }
@@ -875,9 +860,10 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i128
-func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
   return %el : i128
 }
@@ -885,9 +871,10 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index,
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f16
-func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
   return %el : f16
 }
@@ -895,9 +882,10 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_bf16
-func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
   return %el : bf16
 }
@@ -905,9 +893,10 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f32
-func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
   return %el : f32
 }
@@ -915,9 +904,10 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f64
-func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }
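
For reference, a sketch of what the `@vector_extract_element` i32 test above lowers to under the new scheme, assembled from its CHECK lines (SSA names are illustrative, not taken from actual pass output):

```mlir
func.func @vector_extract_element(%row: index, %col: index) -> i32 {
  %ptrue = arith.constant dense<true> : vector<[4]xi1>
  %zero_vec = arith.constant dense<0> : vector<[4]xi32>
  %row_i32 = arith.index_cast %row : index to i32
  // The allocated tile (ID 0) is now carried as the `tile_id` attribute
  // rather than as an SSA operand produced by a cast op.
  %slice = "arm_sme.intr.read.horiz"(%zero_vec, %ptrue, %row_i32) <{tile_id = 0 : i32}>
      : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
  %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
  return %el : i32
}
```
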
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ae3b260da83a2c1..2491b2e2468cdaf 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -452,8 +452,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -500,8 +499,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
 // CHECK-LABEL:   func.func @splat_vec2d_from_i32(
 // CHECK-SAME:      %[[SRC:.*]]: i32) {
 // CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK:   arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK:   %[[VSCALE:.*]] = vector.vscale
 // CHECK:   %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
 // CHECK:   scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 18b95cf2fdf843c..ccd984bcb036e82 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -4,9 +4,9 @@
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index f189fd97d66cde7..81f503bec91efcc 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -3,10 +3,10 @@
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-vector-to-scf -allocate-arm-sme-tiles -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
-// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
index 413cb06b0ae1894..2c9edbf560b0110 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@@ -2,11 +2,11 @@
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -canonicalize \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
-// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 0c186cc373a3b32..936163d1cd30d99 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 442a70cacd66508..8c73c24d695cfb2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 74b51dcc9b4df3a..965337c60b9ffdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 82f38b4dbfa9d1f..4ca61a089bdf52d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 3b218aefcd415ff..14dca2d4d7082aa 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index e2cbe735fa4ff06..2751c2d136485e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 6e33a421bf799a2..b45f24f6c8fddaf 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
 // RUN:  -e entry -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 961bb274d1e3352..5d2d0a73992f1e0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = za0_d_f64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 25ef1799e63adb1..af5fe236e5bd577 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,8 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-arm-sme-to-llvm \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-arm-sme-to-llvm -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
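
The integration tests above all share the same compile step; consolidated from the updated DEFINE/RUN lines (using only the flags that appear in those diffs, with the run command and per-test variations elided), it reads:

```mlir
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
// DEFINE:   -test-lower-to-llvm
```
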
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index b3202b26f8e1e3d..7c9976bed912734 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -4,10 +4,9 @@
 llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
                                                 %nxv4i1 : vector<[4]xi1>,
                                                 %nxv16i8 : vector<[16]xi8>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
-      (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -16,10 +15,9 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
 llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
 ) -> vector<[3]xf32> {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
-  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
-      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[16]xi8>, vector<[4]xi1>, i32) -> vector<[3]xf32>
   llvm.return %res : vector<[3]xf32>
 }
 
@@ -28,9 +26,8 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
 llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
 ) -> vector<[3]xi32> {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
-  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
-      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   llvm.return %res : vector<[4]xi32>
 }
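
For contrast with the invalid cases above, a well-formed `arm_sme.intr.write.horiz` in the new attribute form looks roughly as follows (a sketch mirroring the valid uses in the arm-sme.mlir target test below; tile ID 0 is assumed):

```mlir
llvm.func @arm_sme_write_horiz_valid(%tileslice : i32,
                                     %nxv16i1 : vector<[16]xi1>,
                                     %nxv16i8 : vector<[16]xi8>) {
  // Predicate and vector have matching shapes, so verification succeeds.
  "arm_sme.intr.write.horiz"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
  llvm.return
}
```
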
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 767d89a75eec326..edc1f749130440c 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -2,9 +2,8 @@
 
 // CHECK-LABEL: @arm_sme_zero
 llvm.func @arm_sme_zero() {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.zero(i32 0)
-  "arm_sme.intr.zero"(%c0) : (i32) -> ()
+  "arm_sme.intr.zero"() <{tile_mask = 0 : i32}> : () -> ()
   llvm.return
 }
 
@@ -18,19 +17,18 @@ llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>,
                          %nxv2i1  : vector<[2]xi1>,
                          %nxv4i1  : vector<[4]xi1>,
                          %nxv8i1  : vector<[8]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64
-  "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
-    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.mopa"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+    (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32
-  "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
-    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.mopa"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+    (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16
-  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16
-  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
   llvm.return
 }
 
@@ -41,31 +39,30 @@ llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
                          %nxv16i8 : vector<[16]xi8>,
                          %nxv8i1  : vector<[8]xi1>,
                          %nxv16i1 : vector<[16]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16
-  "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.smopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16
-  "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.umopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16
-  "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.sumopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16
-  "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.usmopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8
-  "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.smopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8
-  "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.umopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8
-  "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.sumopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
-  "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.usmopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -79,19 +76,18 @@ llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>,
                          %nxv2i1  : vector<[2]xi1>,
                          %nxv4i1  : vector<[4]xi1>,
                          %nxv8i1  : vector<[8]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.mops.nxv2f64
-  "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
-    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.mops"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+    (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.nxv4f32
-  "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
-    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.mops"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+    (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16
-  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16
-  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
   llvm.return
 }
 
@@ -102,31 +98,30 @@ llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
                          %nxv16i8 : vector<[16]xi8>,
                          %nxv8i1  : vector<[8]xi1>,
                          %nxv16i1 : vector<[16]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16
-  "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.smops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16
-  "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.umops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16
-  "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.sumops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16
-  "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.usmops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8
-  "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.smops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8
-  "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.umops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8
-  "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.sumops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
-  "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.usmops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -141,35 +136,35 @@ llvm.func @arm_sme_load(%nxv1i1  : vector<[1]xi1>,
                         %ptr    : !llvm.ptr) {
   %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.ld1q.horiz
-  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1d.horiz
-  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1w.horiz
-  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1h.horiz
-  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1b.horiz
-  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1q.vert
-  "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1d.vert
-  "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1w.vert
-  "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1h.vert
-  "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1b.vert
-  "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   llvm.return
 }
 
@@ -184,35 +179,35 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
                          %ptr    : !llvm.ptr) {
   %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.st1q.horiz
-  "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1d.horiz
-  "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1w.horiz
-  "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1h.horiz
-  "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1b.horiz
-  "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1q.vert
-  "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1d.vert
-  "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1w.vert
-  "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1h.vert
-  "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1b.vert
-  "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.str
   "arm_sme.intr.str"(%c0, %ptr, %c0) : (i32, !llvm.ptr, i32) -> ()
   llvm.return
@@ -236,34 +231,33 @@ llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
                                         %nxv8bf16 : vector<[8]xbf16>,
                                         %nxv4f32 : vector<[4]xf32>,
                                         %nxv2f64 : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
-      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
-      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+      (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
 
@@ -285,34 +279,33 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
                                        %nxv8bf16 : vector<[8]xbf16>,
                                        %nxv4f32 : vector<[4]xf32>,
                                        %nxv2f64 : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
-      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
-      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+      (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
 
@@ -334,34 +327,33 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
                                               %nxv8bf16  : vector<[8]xbf16>,
                                               %nxv4f32   : vector<[4]xf32>,
                                               %nxv2f64   : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
-  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
-    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
-  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
-  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
-  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
-  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
-    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
-  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
-  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
-  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
-  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   llvm.return
 }
 
@@ -382,33 +374,32 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
                                               %nxv8bf16 : vector<[8]xbf16>,
                                               %nxv4f32  : vector<[4]xf32>,
                                               %nxv2f64  : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
-  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
-    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
-  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
-  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
-  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
-  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
-    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
-  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
-  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
-  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
-  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   llvm.return
 }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b09..1e18d23c0044159 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -21,9 +21,9 @@
 using namespace mlir;
 
 // This is needed because these matchers are defined as overloaded functions.
-using HasOpAttrName = detail::AttrOpMatcher(StringRef);
-using HasOpName = detail::NameOpMatcher(StringRef);
-using IsConstantOp = detail::constant_op_matcher();
+using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef);
+using HasOpName = mlir::detail::NameOpMatcher(StringRef);
+using IsConstantOp = mlir::detail::constant_op_matcher();
 
 namespace test {
 #ifdef MLIR_INCLUDE_TESTS
