[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Nov 24 04:04:24 PST 2023


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/73253

From f6e3b88ee412489f95546e6748d9b036cb3e0fe6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 23 Nov 2023 14:48:32 +0000
Subject: [PATCH 1/4] [mlir][ArmSME] Switch to an attribute-based tile
 allocation scheme

This reworks the ArmSME dialect to use attributes for tile allocation.
This has a number of advantages and corrects some issues with the
previous approach:

 * Tile allocation can now be done ASAP (i.e. immediately after
   `-convert-vector-to-arm-sme`)
 * SSA form for control flow is now supported (e.g. `scf.for` loops that
   yield tiles)
 * ArmSME ops can be converted to intrinsics very late (i.e. after
   lowering to control flow)
 * Tests are simplified by removing constants and casts
 * Avoids correctness issues with representing LLVM `immargs` as MLIR
   values
    - The tile ID on the SME intrinsics is an `immarg` (so is required
      to be a compile-time constant), `immargs` should be mapped to
      MLIR attributes (this is already the case for intrinsics in the
      LLVM dialect)
    - Using MLIR values for `immargs` can lead to invalid LLVM IR being
      generated (and passes such as -cse making incorrect optimizations)

As part of this patch we bid farewell to the following operations:

```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```

These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```
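
Since SSA form for control flow is now supported, tiles can be carried
through loops directly. A rough sketch (not from this patch; the ops that
update the tile are elided) of an `scf.for` loop that yields a tile:

```mlir
// Sketch: a tile created outside the loop is carried as an iter_arg and
// yielded back each iteration, all in SSA form.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%init = arm_sme.zero : vector<[4]x[4]xi32>
%result = scf.for %i = %c0 to %c4 step %c1
    iter_args(%tile = %init) -> (vector<[4]x[4]xi32>) {
  // ... ops that update %tile (e.g. loading a tile slice) ...
  scf.yield %tile : vector<[4]x[4]xi32>
}
```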

The new tile allocation works via operations implementing the
`ArmSMETileOpInterface`. This interface indicates that an operation needs
to be assigned a tile ID, and may conditionally allocate a new SME tile.

Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning what type of tile the op allocates (ZAB, ZAH, etc.).

Operations that don't allocate a tile return `std::nullopt` (which
is the default behaviour).
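
For example, `arm_sme.outerproduct` (from its `extraClassDeclaration` in
this patch) only allocates a tile when no accumulator is supplied:

```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
  // The outerproduct op allocates a new tile if no accumulator is passed.
  if (!getAcc())
    return arm_sme::getSMETileType(getResultType());
  return std::nullopt;
}
```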

Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```

Allocating operations become the roots for the tile allocation pass,
which currently just (naively) assigns all transitive uses of a root
operation the same tile ID. However, this is enough to handle current
use cases.
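
As a rough sketch (exact syntax and attribute placement are illustrative),
after allocation a root and its transitive uses carry the same `tile_id`
attribute:

```mlir
// Before tile allocation:
%tile = arm_sme.zero : vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>

// After tile allocation (sketch): root and user share the same tile ID.
%tile = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
arm_sme.tile_store %tile, %dest[%c0, %c0] {tile_id = 0 : i32}
  : memref<?x?xi32>, vector<[4]x[4]xi32>
```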

Once tile IDs have been allocated, subsequent rewrites can forward the
tile IDs to any newly created operations.
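
For example, the ArmSME-to-SCF lowering in this patch uses the interface's
`createOpAndForwardTileId` helper so that replacement ops inherit the tile
ID of the op being rewritten:

```c++
// From TileLoadOpConversion (ArmSMEToSCF.cpp): the new arm_sme.get_tile
// is created with the same tile ID as the arm_sme.tile_load being lowered.
auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
    rewriter, loc, tileType);
```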
---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |   4 +-
 .../mlir/Dialect/ArmSME/IR/ArmSMEEnums.h      |  16 +
 .../Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td   |  52 ++-
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 283 ++++++------
 .../mlir/Dialect/ArmSME/IR/CMakeLists.txt     |   6 +
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  16 +-
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 197 ++++-----
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  47 +-
 .../VectorToArmSME/VectorToArmSME.cpp         |  23 +-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |  22 +-
 mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt     |   1 +
 .../ArmSME/Transforms/TileAllocation.cpp      | 105 +++--
 mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt  |   2 -
 mlir/lib/Dialect/ArmSME/Utils/Utils.cpp       |  44 +-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |  10 +-
 .../test/Dialect/ArmSME/arith-ops-to-sme.mlir |   6 +-
 .../Dialect/ArmSME/arm-sme-to-llvm-casts.mlir |  51 ---
 mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir | 252 ++++++-----
 mlir/test/Dialect/ArmSME/canonicalize.mlir    |  40 +-
 mlir/test/Dialect/ArmSME/cse.mlir             |  36 +-
 mlir/test/Dialect/ArmSME/invalid.mlir         |  62 +--
 mlir/test/Dialect/ArmSME/roundtrip.mlir       | 193 ++-------
 mlir/test/Dialect/ArmSME/tile-allocation.mlir | 376 ++++++++--------
 mlir/test/Dialect/ArmSME/tile-zero-masks.mlir | 102 +----
 .../Dialect/ArmSME/vector-ops-to-llvm.mlir    | 410 +++++++++---------
 .../Dialect/ArmSME/vector-ops-to-sme.mlir     |   6 +-
 .../Dialect/Linalg/CPU/ArmSME/fill-2d.mlir    |   4 +-
 .../Linalg/CPU/ArmSME/matmul-transpose-a.mlir |   4 +-
 .../Dialect/Linalg/CPU/ArmSME/matmul.mlir     |   4 +-
 .../Vector/CPU/ArmSME/test-load-vertical.mlir |   4 +-
 .../CPU/ArmSME/test-outerproduct-f32.mlir     |   4 +-
 .../CPU/ArmSME/test-outerproduct-f64.mlir     |   4 +-
 .../CPU/ArmSME/test-transfer-read-2d.mlir     |   4 +-
 .../CPU/ArmSME/test-transfer-write-2d.mlir    |   4 +-
 .../Vector/CPU/ArmSME/test-transpose.mlir     |   4 +-
 .../Dialect/Vector/CPU/ArmSME/tile_fill.mlir  |   4 +-
 .../Vector/CPU/ArmSME/vector-load-store.mlir  |   4 +-
 .../Dialect/Vector/CPU/ArmSME/vector-ops.mlir |   5 +-
 mlir/test/Target/LLVMIR/arm-sme-invalid.mlir  |  15 +-
 mlir/test/Target/LLVMIR/arm-sme.mlir          | 331 +++++++-------
 mlir/tools/mlir-query/mlir-query.cpp          |   6 +-
 41 files changed, 1271 insertions(+), 1492 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
 delete mode 100644 mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..1da8e488a4c4647 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..430f3571001c8f4
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEEnums.h - Arm SME Dialect Enums --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..2a0167afa8bae9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
   }];
 }
 
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = [],
+                    list<int> overloadedOperands = [],
                     list<Trait> traits = [], int numResults = 0,
                     list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
@@ -64,16 +67,26 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
           /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
+          /*int numResults=*/numResults,
+          /*bit requiresAccessGroup=*/0,
+          /*bit requiresAliasAnalysis=*/0,
+          /*bit requiresFastmath=*/0,
+          /*list<int> immArgPositions=*/immArgPositions,
+          /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 // Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero",
+                            /*immArgPositions=*/[0],
+                            /*immArgAttrNames=*/["tile_mask"]>,
+                            Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
 
 // MOP's
 class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[0],
+      /*immArgAttrNames=*/["tile_id"],
+      /*overloadedOperands=*/[4]>,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
                  Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
                  Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +105,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic,
+      /*immArgPositions=*/[2],
+      /*immArgAttrNames=*/["tile_id"]>;
+
 // Loads
 class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Load address">:$load_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +131,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
 
 // Stores
 class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +156,28 @@ def LLVM_aarch64_sme_str
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+    : ArmSME_IntrOp<"write." # direction,
+                    /*immArgPositions=*/[0],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[3],
                     [AllShapesMatch<["predicate", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
 // Tile slice to vector
 class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+    : ArmSME_IntrOp<"read." # direction,
+                    /*immArgPositions=*/[2],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[],
                     [AllShapesMatch<["vector", "predicate", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
                     /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
-                     Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..abcc9b649c4a530 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,99 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+    [
+      I32EnumAttrCase<"ZAB", 0, "za.b">,
+      I32EnumAttrCase<"ZAH", 1, "za.h">,
+      I32EnumAttrCase<"ZAS", 2, "za.s">,
+      I32EnumAttrCase<"ZAD", 3, "za.d">,
+      I32EnumAttrCase<"ZAQ", 4, "za.q">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+  let description = [{
+    An interface for operations that use or allocate Arm SME tiles. These
+    operations need to be assigned a tile ID (an i32 attribute), which specifies
+    which virtual tile within the ZA storage to use. The number of tiles
+    available depends on the type of the tile. This is summarized below:
+
+    | Tile Vector Types                                                       | Possible Tile IDs   |
+    |-------------------------------------------------------------------------|---------------------|
+    | `vector<[16]x[16]xi8>`                                                  | 0                   |
+    | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1             |
+    | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
+    | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
+    | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
+
+    Operations that allocate a new tile (such as arm_sme.get_tile) are used as
+    the roots for tile allocation, with all operations that (transitively)
+    depend on a root being assigned the same tile ID.
+  }];
+  let methods = [
+    InterfaceMethod<
+      "Sets the tile ID for this operation.",
+      /*returnType=*/"void",
+      /*methodName=*/"setTileId",
+      /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        if (!tileId)
+          return;
+        ::mlir::Operation* op = this->getOperation();
+        op->setAttr("tile_id", tileId);
+      }]
+    >,
+    InterfaceMethod<
+      "Returns the (possibly null) tile ID assigned to this operation.",
+      /*returnType=*/"mlir::IntegerAttr",
+      /*methodName=*/"getTileId",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        ::mlir::Operation* op = this->getOperation();
+        return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+      }]
+    >,
+    InterfaceMethod<
+      "The type of tile this operation allocates (or none)",
+      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+      /*methodName=*/"getAllocatedTileType",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        // Do not allocate a new tile.
+        return std::nullopt;
+      }]
+    >
+  ];
+
+  let extraSharedClassDeclaration = [{
+    // A helper to create a new operation and propagate this operation's tile ID.
+    template<typename T, typename... Args>
+    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+        tileOp.setTileId($_op.getTileId());
+      return op;
+    }
+
+    // A helper to replace this operation and forward any tile ID.
+    template<typename T, typename... Args>
+    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+      rewriter.replaceOp($_op, newOp);
+      return newOp;
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -44,7 +137,8 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
-                        "a vector type that fits into a SME tile">
+                        "a vector type that fits into a SME tile",
+                        "VectorType">
 {
   let description = [{
     Possible vector types:
@@ -66,40 +160,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
   }];
 }
 
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
-                       "an identifier of a virtual tile (of a size) within the ZA storage">
-{
-  let description = [{
-    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
-    the integer indicates the tile to use, and the bit size indicates the size
-    of tile. The number of tiles available and the element types of those depend
-    on the size. This is summarised below:
-
-    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
-    |--------------|---------------------|-------------------------------------------------------------------------|
-    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
-    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
-    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
-    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
-    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
-  }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
-  "`tile_id` has the same number of bits as elements in `vector`",
-  "vector", "tile_id",
-  "IntegerType::get("
-      "$_self.getContext(),"
-      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
-          "? ::llvm::cast<IntegerType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth()"
-          ": ::llvm::cast<FloatType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth())">;
-
 class HasMatchingMaskTypeConstraint<string vector, string mask> :
   OptionalTypesMatchWith<
     mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +222,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+  let summary = "Returns a SME virtual tile";
   let description = [{
-    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
-    scalable vector type, which represents an SME "virtual tile". This would
-    normally be used when lowering operations that return "virtual tile" vector
-    types to model the output. This is required to preserve dataflow as SME
-    intrinsics have no return values.
+    Allocates a new SME "virtual tile" within a function. The contents of the
+    tile returned from this operation are undefined.
 
-    Example:
+    Example 1:
 
-    Input:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 8-bit element "virtual tile"
+    %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
-    After lowering `vector.load`:
+    Example 2:
+
     ```mlir
-    %tile_id = arm_sme.get_tile_id : i32
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-    }
-    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate two 16-bit element "virtual tiles"
+    %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+    %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
-    In the example above, the `vector.load` can't be replaced with an SME
-    intrinsic that has no outputs since it is used by the `vector.store`.
-    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
-    the `vector.load` can be replaced. This enables "local" rewrites on
-    individual vector ops, rather than "global" rewrites that would have to
-    look at the vector op uses and also lower them.
-
-    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
-    the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
-  }];
-  let arguments = (ins TileID:$tile_id);
-  let results = (outs SMETile:$vector);
-  let assemblyFormat =
-    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
-  let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from 2-d scalable vector type to tile id";
-  let description = [{
-    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
-    type, which represents an SME "virtual tile", to a tile id. This is
-    required to preserve dataflow as the SME intrinsics have no return values.
-
-    Example:
-
-    Input:
+    Example 3:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate a 128-bit element "virtual tile"
+    %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
+  }];
 
-    After lowering `vector.store`:
-    ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
-      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
+
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
     }
-    ```
 
-    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
-    the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getTileType());
+    }
   }];
-  let arguments = (ins SMETile:$vector);
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat =
-    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
-  let hasCanonicalizeMethod = 1;
 }
 
-def GetTileID : ArmSME_Op<"get_tile_id"> {
-  let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+  let summary = "SME tile placeholder";
   let description = [{
-    A `get_tile_id` operation returns a scalar integer representing an SME
-    "virtual tile" id. The bitwidth of the scalar indicates the element
-    bitwidth of the "virtual tile".
-
-    The scope of a tile id is a function and cannot be passed or returned from
-    functions.
+    A placeholder to preserve dataflow while lowering to SME intrinsics (which
+    do not take or return tile values). This operation is intended to be DCE'd
+    once all ArmSME operations have been lowered.
 
-    Example:
-    ```mlir
-    // Allocate and return an 8-bit element "virtual tile" id
-    %za0_b = arm_sme.get_tile_id : i8
-    ```
-
-    Example:
-    ```
-    // Allocate and return two 16-bit element "virtual tile" ids
-    %za0_h = arm_sme.get_tile_id : i16
-    %za1_h = arm_sme.get_tile_id : i16
-    ```
-
-    Example:
-    ```
-    // Allocate and return an 128-bit element "virtual tile" id
-    %za0_q = arm_sme.get_tile_id : i128
-    ```
+    This operation is not intended to be used outside of the ArmSME -> LLVM
+    conversion.
   }];
-
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat = "attr-dict `:` type($tile_id)";
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
 //
 // Tile reset.
 //
 
-def ZeroOp : ArmSME_Op<"zero", [Pure]> {
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let summary = "Initialize the two-dimensional ZA array with 0s";
   let results = (outs SMETile:$res);
   let description = [{
@@ -303,11 +305,15 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getVectorType());
+    }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
 def TileLoadOp : ArmSME_Op<"tile_load", [
+  ArmSMETileOpInterface,
   AttrSizedOperandSegments,
   OptionalTypesMatchWith<
     "padding type matches element type of result",
@@ -375,6 +381,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getVectorType());
+    }
   }];
 
   let builders = [
@@ -394,6 +403,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store", [
+  ArmSMETileOpInterface,
   AttrSizedOperandSegments,
   HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
 ]> {
@@ -457,6 +467,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
+  ArmSMETileOpInterface,
   AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
 ]> {
   let summary = "Tile slice load and update operation";
@@ -515,6 +526,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 }
 
 def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
+  ArmSMETileOpInterface,
   TileSliceMaskConstraint<"tile", "mask">
 ]> {
   let summary = "Tile slice store operation";
@@ -570,6 +582,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
+    ArmSMETileOpInterface,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -617,7 +630,8 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 }
 
-def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
+def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
+    ArmSMETileOpInterface,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -670,7 +684,8 @@ class OuterProductResultTileTypeConstraint<string operand> :
     "}()">;
 
 def OuterProductOp :
-  ArmSME_Op<"outerproduct", [Pure,
+  ArmSME_Op<"outerproduct", [
+    ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
     HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
@@ -715,7 +730,7 @@ def OuterProductOp :
     ```
   }];
 
-  let arguments = (ins
+let arguments = (ins
     SVEVector:$lhs, SVEVector:$rhs,
     Optional<SVEPredicate>:$lhsMask,
     Optional<SVEPredicate>:$rhsMask,
@@ -736,6 +751,12 @@ def OuterProductOp :
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      // The outerproduct op allocates a new tile if no accumulator is passed.
+      if (!getAcc())
+        return arm_sme::getSMETileType(getResultType());
+      return std::nullopt;
+    }
   }];
 }
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 9319a8042c93f4a..9801d8b099e3f0a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,6 +15,12 @@ set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
 mlir_tablegen(ArmSMEOpsConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
 
+# Generate op interface declarations and definitions
+set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
+mlir_tablegen(ArmSMEOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ArmSMEOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRArmSMEOpInterfaces)
+
 # Generate declarations and definitions of ArmSME intrinsic Ops
 set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
 mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 0941592497beaae..954371cb9d5a0b1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -15,10 +15,11 @@
 #ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include <optional>
 
-namespace mlir {
-namespace arm_sme {
+namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
 
@@ -34,13 +35,8 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+std::optional<ArmSMETileType> getSMETileType(VectorType);
 
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..4f4b090dd10e3c0 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,6 +32,26 @@ using namespace mlir;
 
 namespace {
 
+IntegerAttr getTileIdOrError(ArmSMETileOpInterface op) {
+  auto tileId = op.getTileId();
+  if (!tileId)
+    op.emitOpError(
+        "expected tile ID to be allocated before conversion to LLVM");
+  return tileId;
+}
+
+struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTile> {
+  using ConvertOpToLLVMPattern<arm_sme::GetTile>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::GetTile getTile, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+        getTile, getTile.getTileType());
+    return success();
+  }
+};
+
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
@@ -61,9 +81,9 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     unsigned tileElementWidth =
         zero.getVectorType().getElementType().getIntOrFloatBitWidth();
 
-    // Get Tile ID for the `zero` intrinsic.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
+    auto tileId = getTileIdOrError(zero);
+    if (!tileId)
+      return failure();
 
     // Get the base mask for tile based on the element size.
     // The base mask is just the mask to zero the first tile (of a size).
@@ -93,9 +113,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
         llvm_unreachable("bad element size");
       }
     }();
-    auto maskType = rewriter.getI32Type();
-    auto baseMask = rewriter.create<arith::ConstantOp>(
-        loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));
 
     // The actual mask is just the base mask shifted by the tile ID.
     // This will be folded to a constant after tile allocation.
@@ -118,13 +135,13 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     //  * Mask    -> 10001000 = (00010001 << 3)
     //
     // This holds for all tile sizes.
-    auto tileMask = rewriter.create<arith::ShLIOp>(
-        loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
-    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
+    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
+    rewriter.create<arm_sme::aarch64_sme_zero>(
+        loc, rewriter.getI32IntegerAttr(zeroMask));
 
-    // Create `CastTileToVectorOp` to use as the output.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
-                                                           tileId);
+    // Create a placeholder op to preserve dataflow.
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+        zero, zero.getVectorType());
 
     return success();
   }
@@ -141,15 +158,9 @@ struct LoadTileSliceConversion
                   arm_sme::LoadTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = loadTileSliceOp.getLoc();
-    auto tileType = loadTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
-
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
-    // loaded to.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        loadTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(loadTileSliceOp);
+    if (!tileId)
+      return failure();
 
     Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
                                            adaptor.getBase(),
@@ -164,7 +175,9 @@ struct LoadTileSliceConversion
     // Create all active predicate mask.
     auto maskOp = loadTileSliceOp.getMask();
 
-    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+    auto tileType = loadTileSliceOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
 
     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
@@ -174,23 +187,23 @@ struct LoadTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 16:
         rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 32:
         rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 64:
         rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 128:
         rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       }
     } else {
@@ -199,31 +212,30 @@ struct LoadTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 16:
         rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 32:
         rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 64:
         rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 128:
         rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       }
     }
 
     // The load intrinsics have no result, replace 'arm_sme.tile_load' with
-    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
-                                                           tileType, tile);
+    // the input tile to preserve dataflow.
+    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
 
     return success();
   }
@@ -244,11 +256,9 @@ struct StoreTileSliceConversion
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
-    // being stored.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        storeTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(storeTileSliceOp);
+    if (!tileId)
+      return failure();
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
     Value ptr = this->getStridedElementPtr(
@@ -263,7 +273,6 @@ struct StoreTileSliceConversion
 
     auto maskOp = storeTileSliceOp.getMask();
 
-    Value tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
 
     if (layout == arm_sme::TileSliceLayout::Horizontal) {
@@ -272,23 +281,23 @@ struct StoreTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       }
     } else {
@@ -297,23 +306,23 @@ struct StoreTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       }
     }
@@ -334,14 +343,10 @@ struct MoveVectorToTileSliceConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveVectorToTileSliceOp.getLoc();
     auto tileType = moveVectorToTileSliceOp.getTileType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
-    // loaded to.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        moveVectorToTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
+    if (!tileId)
+      return failure();
 
     auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
 
@@ -357,26 +362,24 @@ struct MoveVectorToTileSliceConversion
                                   /*scalableDims=*/{true});
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
-    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-
     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
     switch (moveVectorToTileSliceOp.getLayout()) {
     case arm_sme::TileSliceLayout::Horizontal:
       rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-          loc, tileI32, tileSliceI32, allActiveMask,
+          loc, tileId, tileSliceI32, allActiveMask,
           moveVectorToTileSliceOp.getVector());
       break;
     case arm_sme::TileSliceLayout::Vertical:
       rewriter.create<arm_sme::aarch64_sme_write_vert>(
-          loc, tileI32, tileSliceI32, allActiveMask,
+          loc, tileId, tileSliceI32, allActiveMask,
           moveVectorToTileSliceOp.getVector());
       break;
     }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
-    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
-        moveVectorToTileSliceOp, tileType, tile);
+    // the input tile to preserve dataflow.
+    rewriter.replaceOp(moveVectorToTileSliceOp,
+                       moveVectorToTileSliceOp.getTile());
 
     return success();
   }
@@ -394,12 +397,11 @@ struct MoveTileSliceToVectorConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveTileSliceToVector.getLoc();
     auto sliceType = moveTileSliceToVector.getSliceType();
-    auto tile = moveTileSliceToVector.getTile();
     auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
 
-    // Cast tile to i32 tile ID.
-    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
-    Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+    auto tileId = getTileIdOrError(moveTileSliceToVector);
+    if (!tileId)
+      return failure();
 
     // Create an 'all true' predicate for the tile slice.
     auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
@@ -419,12 +421,12 @@ struct MoveTileSliceToVectorConversion
     case arm_sme::TileSliceLayout::Horizontal:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
           moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-          tileIdI32, sliceIndexI32);
+          tileId, sliceIndexI32);
       break;
     case arm_sme::TileSliceLayout::Vertical:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
           moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-          tileIdI32, sliceIndexI32);
+          tileId, sliceIndexI32);
       break;
     }
 
@@ -454,6 +456,10 @@ struct OuterProductOpConversion
   matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                   arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto tileId = getTileIdOrError(outerProductOp);
+    if (!tileId)
+      return failure();
+
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
       // different features [1]:
@@ -498,13 +504,8 @@ struct OuterProductOpConversion
     Value acc = outerProductOp.getAcc();
     if (!acc)
       // Initalize accumulator with zero.
-      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
-
-    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
-    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(elementWidth), acc);
-
-    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+      acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+          rewriter, loc, resultVectorType);
 
     Value lhsMask = outerProductOp.getLhsMask();
     Value rhsMask = outerProductOp.getRhsMask();
@@ -519,13 +520,13 @@ struct OuterProductOpConversion
     }
 
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                                outerProductOp.getLhs(),
                                                outerProductOp.getRhs());
 
-    // Create `CastTileToVectorOp` to use as the output.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
-        outerProductOp, resultVectorType, tileId);
+    // The outerproduct intrinsics have no result, replace
+    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
+    rewriter.replaceOp(outerProductOp, acc);
 
     return success();
   }
@@ -557,21 +558,20 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
-      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
-      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
-      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
-      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
-      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
-      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
-      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
-      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
-      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
-      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
-      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
-      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
-      arm_sme::aarch64_sme_mopa>();
+      arm_sme::MaterializeSSATile, arm_sme::aarch64_sme_zero,
+      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
@@ -580,7 +580,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(
     ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-               OuterProductOpConversion, ZeroOpConversion>(converter);
+               OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
+      converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index a4dfd4b7edc77c5..541d711fbd95f29 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -87,16 +87,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
     auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.load_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+    // Allocate a new SME tile.
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+        rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -128,8 +122,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+        rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
         memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -209,7 +203,8 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // Initialize tile with zero to satisfy padding. Inactive cols will be
     // zeroed anyway since the loads use zeroing predication. For inactive rows
     // however, no load will occur so these need to be zeroed.
-    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+        rewriter, loc, tileType);
 
     // Create a loop to load the active tile slices from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -226,9 +221,9 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
-        tileSliceIndex, tileLoadOp.getLayout());
+    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+        rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
+        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -276,7 +271,6 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
     auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
     auto maskOp = tileLoadOp.getMask();
     if (!maskOp)
@@ -304,14 +298,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), numCols);
 
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.load_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+    // Allocate a new SME tile.
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+        rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -356,8 +345,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
+    tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+        rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
         tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -450,8 +439,9 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+
+    tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+        rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
         tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
 
     return success();
@@ -506,6 +496,7 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
       rewriter.setInsertionPointToStart(forOp.getBody());
       // Extract the current row from the tile.
       Value rowIndex = forOp.getInductionVar();
+      // FIXME: Forward tile IDs.
       auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
           loc, printOp.getSource(), rowIndex);
       // Print the row with a 1D vector.print.
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index f1e92d1ac9708b7..109b04ce34a88b7 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -44,20 +44,6 @@ static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
   return forOp;
 }
 
-/// Returns a tile of the given vector type.
-static arm_sme::CastTileToVector
-getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
-                          VectorType type) {
-  unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
-
-  // Create 'arm_sme.get_tile' op.
-  auto tileId = rewriter.create<arm_sme::GetTileID>(
-      loc, rewriter.getIntegerType(tileElementWidth));
-
-  // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
-  return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
-}
-
 namespace {
 
 /// Conversion pattern for vector.transfer_read.
@@ -267,8 +253,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
 
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
     auto tileSliceIndex = forOp.getInductionVar();
@@ -330,8 +315,7 @@ struct BroadcastOpToArmSMELowering
     else
       return failure();
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
 
     // Create a loop over ZA tile slices.
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
@@ -387,8 +371,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
         loc, tileSliceType, splatOp.getInput());
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 9df15420b9c9b6f..29fa9085a0a963d 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -28,6 +28,8 @@ using namespace mlir::arm_sme;
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOps.cpp.inc"
 
@@ -54,23 +56,3 @@ void ArmSMEDialect::initialize() {
 #include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
       >();
 }
-
-// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id
-LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
-                                             PatternRewriter &rewriter) {
-  if (auto castTileToVectorOp = op.getVector().getDefiningOp<CastTileToVector>()) {
-    op.replaceAllUsesWith(castTileToVectorOp.getTileId());
-    return success();
-  }
-  return failure();
-}
-
-// cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile
-LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
-                                             PatternRewriter &rewriter) {
-  if (auto castVectorToTileOp = op.getTileId().getDefiningOp<CastVectorToTile>()) {
-    op.replaceAllUsesWith(castVectorToTileOp.getVector());
-    return success();
-  }
-  return failure();
-}
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 3e448ec4fb1e04d..66062335fa842a4 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRVectorDialect
+  MLIRArmSMEUtils
 )
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e0462a6dc124131..0a65efeb266d15b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -41,10 +41,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "allocate-arm-sme-tiles"
 
@@ -107,73 +109,85 @@ enum class TileMask : unsigned {
 };
 
 /// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
-  static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
-  static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
-                                                   TileMask::kZA1H};
-  static const SmallVector<TileMask> ZA_S_MASKS = {
-      TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
-  static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+                                            TileMask::kZA2S, TileMask::kZA3S};
+  static constexpr std::array ZA_D_MASKS = {
       TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
       TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
-  static const SmallVector<TileMask> ZA_Q_MASKS = {
+  static constexpr std::array ZA_Q_MASKS = {
       TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
       TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
       TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
       TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
-  switch (cast<IntegerType>(type).getWidth()) {
-  default:
-    llvm_unreachable("unexpected type!");
-  case 8:
+  switch (type) {
+  case ArmSMETileType::ZAB:
     return ZA_B_MASKS;
-  case 16:
+  case ArmSMETileType::ZAH:
     return ZA_H_MASKS;
-  case 32:
+  case ArmSMETileType::ZAS:
     return ZA_S_MASKS;
-  case 64:
+  case ArmSMETileType::ZAD:
     return ZA_D_MASKS;
-  case 128:
+  case ArmSMETileType::ZAQ:
     return ZA_Q_MASKS;
   }
 }
 
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
-                             unsigned &tileID) {
-  auto masks = getMasks(tileIDOp.getType());
-  for (const auto &it : llvm::enumerate(masks)) {
-    const auto tileMask = it.value();
+/// Allocates and returns a tile ID, or failure if there are no tiles left.
+static FailureOr<unsigned> getTile(ArmSMETileType tileType,
+                                   TileMask &tilesInUse) {
+  auto masks = getMasks(tileType);
+  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
     if ((tilesInUse & tileMask) == TileMask::kNone) {
       tilesInUse |= tileMask;
-      tileID = it.index();
-      return success();
+      return tileId;
     }
   }
-  return tileIDOp.emitError("ran out of SME virtual tiles!");
+  return failure();
 }
 
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+struct AssignTileIDsPattern
+    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
                                 PatternRewriter &rewriter) const override {
-    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+    if (tileOp.getTileId())
+      return failure();
+
+    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+    if (!tileType)
+      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
     TileMask tilesInUse;
-    if (auto tilesInUseAttr =
-            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+    if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
       tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
     else
       tilesInUse = TileMask::kNone;
 
-    unsigned tileID;
-    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
-      return failure();
+    auto tileId = getTile(*tileType, tilesInUse);
+    if (failed(tileId))
+      return tileOp.emitError("ran out of SME virtual tiles!");
 
-    funcOp->setAttr(kTilesInUseAttr,
-                    rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+    func->setAttr(kTilesInUseAttr,
+                  rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+
+    // Find all the ops that (transitively) depend on this tile.
+    SetVector<Operation *> dependentOps;
+    getForwardSlice(tileOp.getOperation(), &dependentOps);
+
+    // Set all operations to use the same tile ID.
+    // This is a naive tile allocation scheme, but it works for common cases.
+    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
+    tileOp.setTileId(tileIDAttr);
+    for (auto *op : dependentOps) {
+      if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op))
+        tileOp.setTileId(tileIDAttr);
+    }
 
-    auto tileType = tileIDOp.getType();
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-        tileIDOp, tileType, rewriter.getIntegerAttr(tileType, tileID));
     return success();
   }
 };
@@ -182,13 +196,14 @@ struct TileAllocationPass
     : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    ConversionTarget target(getContext());
-    patterns.add<GetTileIDConversion>(patterns.getContext());
-    target.addLegalOp<arith::ConstantOp>();
-    target.addIllegalOp<GetTileID>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    patterns.add<AssignTileIDsPattern>(patterns.getContext());
+    GreedyRewriteConfig config;
+    // This ensures tiles are allocated in program order.
+    config.useTopDownTraversal = true;
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(patterns), config))) {
       signalPassFailure();
+    }
   }
 };
 } // namespace
diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
index da8517aaf80a9fa..ecf774a215d24f8 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
@@ -5,7 +5,5 @@ add_mlir_dialect_library(MLIRArmSMEUtils
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
 
   LINK_LIBS PUBLIC
-  MLIRArmSMEDialect
-  MLIRDialect
   MLIRIR
   )
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index f17077ff8565d59..c3cdf5703bcbc28 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,24 +12,20 @@
 
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+namespace mlir::arm_sme {
 
-using namespace mlir;
-using namespace mlir::arm_sme;
-
-unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+unsigned getSMETileSliceMinNumElts(Type type) {
   assert(isValidSMETileElementType(type) && "invalid tile type!");
   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
 }
 
-bool mlir::arm_sme::isValidSMETileElementType(Type type) {
+bool isValidSMETileElementType(Type type) {
   return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
          type.isInteger(64) || type.isInteger(128) || type.isF16() ||
          type.isBF16() || type.isF32() || type.isF64() || type.isF128();
 }
 
-bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
+bool isValidSMETileVectorType(VectorType vType) {
   if ((vType.getRank() != 2) || !vType.allDimsScalable())
     return false;
 
@@ -37,22 +33,30 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
   if (!isValidSMETileElementType(elemType))
     return false;
 
-  unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
   if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
     return false;
 
   return true;
 }
 
-Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
-                                     RewriterBase &rewriter) {
-  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
-             tile.getDefiningOp())) &&
-         "expected ArmSME GetTileID or CastVectorToTile op!");
-  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
-  if (tileElementWidth < 32)
-    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
-  if (tileElementWidth > 32)
-    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
-  return tile;
+std::optional<ArmSMETileType> getSMETileType(VectorType type) {
+  if (!isValidSMETileVectorType(type))
+    return {};
+  switch (type.getElementTypeBitWidth()) {
+  case 8:
+    return ArmSMETileType::ZAB;
+  case 16:
+    return ArmSMETileType::ZAH;
+  case 32:
+    return ArmSMETileType::ZAS;
+  case 64:
+    return ArmSMETileType::ZAD;
+  case 128:
+    return ArmSMETileType::ZAQ;
+  default:
+    llvm_unreachable("unknown SME tile type");
+  }
 }
+
+} // namespace mlir::arm_sme
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 38a2332ffd5e0a3..fc28645a7acf7c0 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,8 +6,7 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -16,7 +15,7 @@
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
 // CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -60,8 +59,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
 // CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
 // CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
-// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -79,7 +77,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
 // CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
diff --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index 12f7e7333ebc973..b8db105f9c601b4 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -95,8 +95,7 @@ func.func @arith_constant_dense_2d_zero_f64() {
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C16:.*]] = arith.constant 16 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -115,8 +114,7 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i64 to vector<[2]x[2]xf64>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
deleted file mode 100644
index 65996e81c42d909..000000000000000
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ /dev/null
@@ -1,51 +0,0 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
-
-// This test verifies the temporary casts that are emitted when lowering to
-// intrinsics to preserve data flow are correct. Canonicalization will remove
-// these.
-
-// CHECK-LABEL: @arm_sme_zero
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: arm_sme.intr.zero
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  %tile = arm_sme.zero : vector<[16]x[16]xi8>
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_load
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-// CHECK: }
-// CHECK: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
-  %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return %tile : vector<[16]x[16]xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_store(
-// CHECK-SAME:                      %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return
-}
diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index fa62332bc3f5b17..bd88da37bdf966d 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
@@ -9,23 +9,21 @@
 // CHECK-LABEL:   func.func @arm_sme_load_tile_slice_hor_i8(
 // CHECK-SAME:                                              %[[SRC:.*]]: memref<?x?xi8>,
 // CHECK-SAME:                                              %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME:                                              %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index)
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -33,9 +31,10 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -43,9 +42,10 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -53,9 +53,10 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -63,9 +64,10 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -73,9 +75,10 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -83,9 +86,10 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -93,9 +97,10 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -103,9 +108,10 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -113,9 +119,10 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -123,9 +130,10 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -133,9 +141,10 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -143,9 +152,10 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -153,9 +163,10 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -163,9 +174,10 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -173,9 +185,10 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -183,9 +196,10 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -193,9 +207,10 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -207,25 +222,23 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // -----
 
 // CHECK-LABEL:   func.func @arm_sme_store_tile_slice_hor_i8(
-// CHECK-SAME:                                               %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                               %[[TILE_SLICE_INDEX:.*]]: index,
 // CHECK-SAME:                                               %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>)
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -233,9 +246,10 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -243,9 +257,10 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -253,9 +268,10 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -263,9 +279,10 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -273,9 +290,10 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -283,9 +301,10 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -293,9 +312,10 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -303,9 +323,10 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -313,9 +334,10 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i8(%tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -323,9 +345,10 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -333,9 +356,10 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -343,9 +367,10 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -353,9 +378,10 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -363,9 +389,10 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -373,9 +400,10 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -383,9 +411,10 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -393,9 +422,10 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -407,9 +437,10 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
   return
 }
@@ -417,9 +448,10 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return
 }
@@ -431,8 +463,9 @@ func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile_slice_index : index) -> vector<[16]xi8> {
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
@@ -440,8 +473,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile_slice_index : index) -> vector<[8]xi16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
@@ -449,8 +483,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile_slice_index : index) -> vector<[4]xi32> {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
@@ -458,8 +493,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile_slice_index : index) -> vector<[2]xi64> {
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
@@ -467,8 +503,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
@@ -476,8 +513,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile_slice_index : index) -> vector<[8]xf16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
@@ -485,8 +523,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile_slice_index : index) -> vector<[8]xbf16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
@@ -494,8 +533,9 @@ func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile_slice_index : index) -> vector<[4]xf32> {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
@@ -503,8 +543,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile_slice_index : index) -> vector<[2]xf64> {
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
@@ -512,8 +553,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
-// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index 06bbd3050fdecec..b7ba3f728c705a3 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,27 +1,25 @@
 // RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-// -----
-
-// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
-// CHECK-SAME: %[[TILE_ID:.*]]: i8
-func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_id_0 : i8) -> i8 {
-  // CHECK-NOT: arm_sme.cast_tile_to_vector
-  // CHECK-NOT: arm_sme.cast_vector_to_tile
-  // CHECK-NEXT: return %[[TILE_ID]] : i8
-  %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
-  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
-  return %tile_id_1 : i8
-}
+// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
+// once it becomes unused, after lowering to control flow.
 
 // -----
 
-// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
-func.func @cast_tile_to_vector__cast_vector_to_tile(%tile_0 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
-  // CHECK-NOT: arm_sme.cast_vector_to_tile
-  // CHECK-NOT: arm_sme.cast_tile_to_vector
-  // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
-  %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[16]x[16]xi8> to i8
-  %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
-  return %tile_1 : vector<[16]x[16]xi8>
+// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
+// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-NOT: vector<[4]x[4]xf32>
+func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+  %c10 = arith.constant 10 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+  cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
+^bb1(%1: index, %2: vector<[4]x[4]xf32>):  // 2 preds: ^bb0, ^bb2
+  %3 = arith.cmpi slt, %1, %c10 : index
+  cf.cond_br %3, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  %4 = arith.addi %1, %c1 : index
+  cf.br ^bb1(%4, %tile : index, vector<[4]x[4]xf32>)
+^bb3:  // pred: ^bb1
+  return
 }
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
index 734bd9f15c8dea1..74e7293eaeca5fc 100644
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ b/mlir/test/Dialect/ArmSME/cse.mlir
@@ -1,16 +1,30 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
 
-// This test is checking that CSE does not remove 'arm_sme.get_tile_id' ops as
+// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
 // duplicates.
-// CHECK-LABEL: @get_tile_id
-// CHECK: %[[TILE_ID_0:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE_ID_1:.*]] = arm_sme.get_tile_id : i32
-// CHECK: "prevent.dce"(%[[TILE_ID_0]]) : (i32) -> ()
-// CHECK: "prevent.dce"(%[[TILE_ID_1]]) : (i32) -> ()
-func.func @get_tile_id() {
-  %tile_id_1 = arm_sme.get_tile_id : i32
-  %tile_id_2 = arm_sme.get_tile_id : i32
-  "prevent.dce"(%tile_id_1) : (i32) -> ()
-  "prevent.dce"(%tile_id_2) : (i32) -> ()
+
+// CHECK-LABEL: @zero_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @zero_tile() {
+  %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
+  %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// CHECK-LABEL: @get_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @get_tile() {
+  %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index d35eb7f7bccc75e..85b95a8b6cf12b7 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,89 +1,49 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
 //===----------------------------------------------------------------------===//
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
-  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
-  return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_rank_1() -> vector<[16]xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+  %0 = arm_sme.get_tile : vector<[16]xi8>
   return %0 : vector<[16]xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+func.func @arm_sme_get_tile__bad_vector_type_i4() -> vector<[16]x[16]xi4> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+  %0 = arm_sme.get_tile : vector<[16]x[16]xi4>
   return %0 : vector<[16]x[16]xi4>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_0() -> vector<16x[16]xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+  %0 = arm_sme.get_tile : vector<16x[16]xi8>
   return %0 : vector<16x[16]xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_1() -> vector<[16]x16xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+  %0 = arm_sme.get_tile : vector<[16]x16xi8>
   return %0 : vector<[16]x16xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+func.func @arm_sme_get_tile__bad_shape() -> vector<[4]x[16]xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+  %0 = arm_sme.get_tile : vector<[4]x[16]xi8>
   return %0 : vector<[4]x[16]xi8>
 }
 
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
-  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
-  return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
-  // expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
-  return %0 : i8
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id__bad_type() -> i1 {
-  // expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
-  %0 = arm_sme.get_tile_id : i1
-  return %0 : i1
-}
-
 //===----------------------------------------------------------------------===//
 // arm_sme.move_vector_to_tile_slice
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 1dbcc32e9259fc5..58ff7ef4d8340ec 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,197 +1,78 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
 //===----------------------------------------------------------------------===//
 
-func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
-  return %0 : vector<[16]x[16]xi8>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i16(%tile_id : i16) -> vector<[8]x[8]xi16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
-  return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i32(%tile_id : i32) -> vector<[4]x[4]xi32> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-  return %0 : vector<[4]x[4]xi32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i64(%tile_id : i64) -> vector<[2]x[2]xi64> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xi64>
-  return %0 : vector<[2]x[2]xi64>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i128(%tile_id : i128) -> vector<[1]x[1]xi128> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i128 to vector<[1]x[1]xi128>
-  return %0 : vector<[1]x[1]xi128>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f16(%tile_id : i16) -> vector<[8]x[8]xf16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xf16>
-  return %0 : vector<[8]x[8]xf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_bf16(%tile_id : i16) -> vector<[8]x[8]xbf16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xbf16>
-  return %0 : vector<[8]x[8]xbf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f32(%tile_id : i32) -> vector<[4]x[4]xf32> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xf32>
-  return %0 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xf64>
-  return %0 : vector<[2]x[2]xf64>
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8
-  return %0 : i8
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i16(%vector : vector<[8]x[8]xi16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xi16> to i16
-  return %0 : i16
-}
 
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i32(%vector : vector<[4]x[4]xi32>) -> i32 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xi32> to i32
-  return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i64(%vector : vector<[2]x[2]xi64>) -> i64 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xi64> to i64
-  return %0 : i64
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i128(%vector : vector<[1]x[1]xi128>) -> i128 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i128
-  return %0 : i128
+func.func @arm_sme_get_tile_i8() {
+  // CHECK: arm_sme.get_tile : vector<[16]x[16]xi8>
+  %0 = arm_sme.get_tile : vector<[16]x[16]xi8>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f16(%vector : vector<[8]x[8]xf16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xf16> to i16
-  return %0 : i16
+func.func @arm_sme_get_tile_i16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xi16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xi16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_bf16(%vector : vector<[8]x[8]xbf16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xbf16> to i16
-  return %0 : i16
+func.func @arm_sme_get_tile_i32() {
+  // CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
+  %0 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f32(%vector : vector<[4]x[4]xf32>) -> i32 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xf32> to i32
-  return %0 : i32
+func.func @arm_sme_get_tile_i64() {
+  // CHECK: arm_sme.get_tile : vector<[2]x[2]xi64>
+  %0 = arm_sme.get_tile : vector<[2]x[2]xi64>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xf64> to i64
-  return %0 : i64
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id_i8() -> i8 {
-  // CHECK: arm_sme.get_tile_id : i8
-  %0 = arm_sme.get_tile_id : i8
-  return %0 : i8
+func.func @arm_sme_get_tile_i128() {
+  // CHECK: arm_sme.get_tile : vector<[1]x[1]xi128>
+  %0 = arm_sme.get_tile : vector<[1]x[1]xi128>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i16() -> i16 {
-  // CHECK: arm_sme.get_tile_id : i16
-  %0 = arm_sme.get_tile_id : i16
-  return %0 : i16
+func.func @arm_sme_get_tile_f16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xf16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xf16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i32() -> i32 {
-  // CHECK: arm_sme.get_tile_id : i32
-  %0 = arm_sme.get_tile_id : i32
-  return %0 : i32
+func.func @arm_sme_get_tile_bf16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xbf16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xbf16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i64() -> i64 {
-  // CHECK: arm_sme.get_tile_id : i64
-  %0 = arm_sme.get_tile_id : i64
-  return %0 : i64
+func.func @arm_sme_get_tile_f32() {
+  // CHECK: arm_sme.get_tile : vector<[4]x[4]xf32>
+  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i128() -> i128 {
-  // CHECK: arm_sme.get_tile_id : i128
-  %0 = arm_sme.get_tile_id : i128
-  return %0 : i128
+func.func @arm_sme_get_tile_f64() {
+  // CHECK: arm_sme.get_tile : vector<[2]x[2]xf64>
+  %0 = arm_sme.get_tile : vector<[2]x[2]xf64>
+  return
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
index a481516d4c15f93..1f895e4984ba844 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -6,17 +6,17 @@
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
 func.func @mixed_tiles() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.D ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // ZA15.Q is still free.
   return
 }
@@ -26,28 +26,26 @@ func.func @mixed_tiles() {
 // CHECK-LABEL: za_b
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_b() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_b = arm_sme.get_tile_id : i8
+  // CHECK-NEXT: tile_id = 0
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @za_b__out_of_tiles() {
-  %za0_b = arm_sme.get_tile_id : i8
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i8
+  %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @za_b_overlapping_za_q() {
-  %za0_b = arm_sme.get_tile_id : i8
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -56,8 +54,8 @@ func.func @za_b_overlapping_za_q() {
 // CHECK-LABEL: za0_h
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
 func.func @za0_h() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
@@ -66,21 +64,23 @@ func.func @za0_h() {
 // CHECK-LABEL: za_h
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
-  // CHECK-NEXT: arith.constant 1
-  %za1_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  // CHECK-NEXT: tile_id = 1
+  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_h__out_of_tiles
 func.func @za_h__out_of_tiles() {
-  %za0_h = arm_sme.get_tile_id : i16
-  %za1_h = arm_sme.get_tile_id : i16
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  // CHECK-NEXT: tile_id = 1
+  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i16
+  %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
@@ -90,14 +90,14 @@ func.func @za_h__out_of_tiles() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_s() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 3
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -107,38 +107,37 @@ func.func @za_h_overlapping_za_s() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_d() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA9.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 1
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA5.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 5
-  %za5_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 5
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_h_overlapping_za_q() {
-  %za0_h = arm_sme.get_tile_id : i16
-  %za0_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za8_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -147,8 +146,8 @@ func.func @za_h_overlapping_za_q() {
 // CHECK-LABEL: za0_s
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
 func.func @za0_s() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -157,27 +156,26 @@ func.func @za0_s() {
 // CHECK-LABEL: za_s
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 2
-  %za2_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 3
-  %za3_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 2
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 3
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @za_s__out_of_tiles() {
-  %za0_s = arm_sme.get_tile_id : i32
-  %za1_s = arm_sme.get_tile_id : i32
-  %za2_s = arm_sme.get_tile_id : i32
-  %za3_s = arm_sme.get_tile_id : i32
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i32
+  %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -187,42 +185,41 @@ func.func @za_s__out_of_tiles() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s_overlapping_za_d() {
   // ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 2
-  %za2_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 2
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_s_overlapping_za_q() {
-  %za0_s = arm_sme.get_tile_id : i32
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -231,8 +228,8 @@ func.func @za_s_overlapping_za_q() {
 // CHECK-LABEL: za0_d
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
 func.func @za0_d() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 0
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
@@ -241,63 +238,61 @@ func.func @za0_d() {
 // CHECK-LABEL: za_d
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_d() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 1
-  %za1_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 2
-  %za2_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 4
-  %za4_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 5
-  %za5_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 6
-  %za6_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 0
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 1
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 2
+  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 4
+  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 5
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 6
+  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_d__out_of_tiles() {
-  %za0_d = arm_sme.get_tile_id : i64
-  %za1_d = arm_sme.get_tile_id : i64
-  %za2_d = arm_sme.get_tile_id : i64
-  %za3_d = arm_sme.get_tile_id : i64
-  %za4_d = arm_sme.get_tile_id : i64
-  %za5_d = arm_sme.get_tile_id : i64
-  %za6_d = arm_sme.get_tile_id : i64
-  %za7_d = arm_sme.get_tile_id : i64
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i64
+  %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_d_overlapping_za_q() {
-  %za0_d = arm_sme.get_tile_id : i64
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -306,8 +301,8 @@ func.func @za_d_overlapping_za_q() {
 // CHECK-LABEL: za0_q
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
 func.func @za0_q() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 0
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -316,62 +311,61 @@ func.func @za0_q() {
 // CHECK-LABEL: za_q
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_q() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 1
-  %za1_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 2
-  %za2_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 3
-  %za3_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 4
-  %za4_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 5
-  %za5_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 6
-  %za6_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 7
-  %za7_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 8
-  %za8_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 9
-  %za9_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 10
-  %za10_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 11
-  %za11_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 12
-  %za12_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 13
-  %za13_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 14
-  %za14_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 15
-  %za15_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 0
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 1
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 2
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 4
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 6
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 8
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 10
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 12
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 14
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @za_q__out_of_tiles() {
-  %za0_q = arm_sme.get_tile_id : i128
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za8_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
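
The `arm_sme.tiles_in_use` values checked in the tests above are consistent with a 16-bit mask that records one bit per 128-bit quad-tile, with ZA0.Q in the most significant bit and ZA15.Q in the least significant bit. A minimal C++ sketch under that assumption (inferred from the test values, not stated by the patch):

```c++
#include <cassert>
#include <cstdint>

// Bit for ZA<quad>.Q, assuming ZA0.Q maps to the most significant bit.
constexpr uint16_t quadBit(unsigned quad) {
  return static_cast<uint16_t>(0x8000u >> quad);
}

int main() {
  // za0_h: a 16-bit element tile occupies the even quad-tiles ZA0.Q, ZA2.Q, ..., ZA14.Q.
  uint16_t za0_h = 0;
  for (unsigned q = 0; q < 16; q += 2)
    za0_h |= quadBit(q);
  assert(za0_h == 43690); // matches {arm_sme.tiles_in_use = 43690 : i32}

  // za0_s: a 32-bit element tile occupies ZA0.Q, ZA4.Q, ZA8.Q and ZA12.Q.
  assert((quadBit(0) | quadBit(4) | quadBit(8) | quadBit(12)) == 34952);

  // za0_d occupies ZA0.Q and ZA8.Q; za0_q occupies only ZA0.Q.
  assert((quadBit(0) | quadBit(8)) == 32896);
  assert(quadBit(0) == 32768);
  return 0;
}
```

Under the same reading, `mixed_tiles` leaves only the ZA15.Q bit clear, giving 65534, which agrees with the "ZA15.Q is still free" comment in that test.
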
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 2378f4234aef1ef..04412e4db1c5f3a 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,7 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
-// RUN:           -allocate-arm-sme-tiles -canonicalize      \
-// RUN:           -allow-unregistered-dialect                \
-// RUN: | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
 
 // This test verifies the tile mask operand of the zero intrinsic zeroes
 // the correct tiles. Both integer and floating-point datatypes are checked.
@@ -10,13 +7,8 @@
 
 // CHECK-LABEL: zero_za_b
 func.func @zero_za_b() {
-  // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8
-  // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
   %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
-  "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -24,20 +16,10 @@ func.func @zero_za_b() {
 
 // CHECK-LABEL: zero_za_h
 func.func @zero_za_h() {
-  // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16
-  // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
   %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
-  "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
   %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
-  "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -45,32 +27,14 @@ func.func @zero_za_h() {
 
 // CHECK-LABEL: zero_za_s
 func.func @zero_za_s() {
-  // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
   %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
   %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
   %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
   %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
-  "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -78,55 +42,21 @@ func.func @zero_za_s() {
 
 // CHECK-LABEL: zero_za_d
 func.func @zero_za_d() {
-  // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
   %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
   %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
   %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
   %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
   %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
   %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
   %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
   %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
-  "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
   return
 }
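
The `tile_mask` values asserted above follow from how SME tiles alias the eight 64-bit tiles ZA0.D-ZA7.D: a 32-bit tile ZAn.S occupies ZAn.D and ZA(n+4).D, so zeroing ZA0.S needs mask 0b00010001 = 17, ZA1.S needs 34, and so on, while each 64-bit tile maps to a single bit (1, 2, ..., 128) and the byte tile ZA0.B covers all eight bits (255, as checked in the next file). A minimal sketch of that mask computation follows; the helper name `zeroMaskFor` and its parameters are illustrative only and not taken from this patch:

```c++
#include <cstdint>

// Illustrative sketch (hypothetical helper, not this patch's implementation):
// compute the ZERO tile mask for tile `tileId` whose elements are
// `elementBits` wide (8, 16, 32, or 64). The mask has one bit per 64-bit
// tile ZA0.D-ZA7.D; wider tiles alias several of them,
// e.g. ZA0.S = {ZA0.D, ZA4.D} -> mask 0b00010001 = 17.
static uint32_t zeroMaskFor(unsigned tileId, unsigned elementBits) {
  unsigned numAliasedDTiles = 64 / elementBits; // .D -> 1, .S -> 2, .B -> 8
  unsigned numTilesOfThisSize = 8 / numAliasedDTiles;
  uint32_t mask = 0;
  for (unsigned i = 0; i < numAliasedDTiles; ++i)
    mask |= uint32_t(1) << (tileId + i * numTilesOfThisSize);
  return mask; // zeroMaskFor(0, 32) == 17, zeroMaskFor(3, 32) == 136,
               // zeroMaskFor(7, 64) == 128, zeroMaskFor(0, 8) == 255
}
```
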
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 77ac071ef67de9a..d16d250b70eb3f6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
@@ -10,13 +10,9 @@
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
-// CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-DAG:  %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
-// CHECK-DAG:  "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
+// CHECK-DAG:  "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
 // CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
@@ -27,8 +23,7 @@
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -49,8 +44,6 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-LABEL: @vector_load_i8_with_offset(
 // CHECK-SAME:                              %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[C123:.*]] = arith.constant 123 : index
@@ -68,10 +61,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   %c123 = arith.constant 123 : index
@@ -84,8 +75,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
 // CHECK-SAME:                                     %[[ARG0:.*]]: memref<?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -98,10 +87,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
@@ -113,12 +100,10 @@ func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16
 
 // CHECK-LABEL: @vector_load_i16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -129,13 +114,10 @@ func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
 
 // CHECK-LABEL: @vector_load_i32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT:   arith.extui %[[TILE_ID]]
-// CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -146,12 +128,10 @@ func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
 
 // CHECK-LABEL: @vector_load_i64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -162,12 +142,10 @@ func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
 
 // CHECK-LABEL: @vector_load_f16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -178,12 +156,10 @@ func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
 
 // CHECK-LABEL: @vector_load_bf16(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -194,13 +170,10 @@ func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
 
 // CHECK-LABEL: @vector_load_f32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT:   arith.extui %[[TILE_ID]]
-// CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -211,12 +184,10 @@ func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
 
 // CHECK-LABEL: @vector_load_f64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -227,10 +198,8 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 
 // CHECK-LABEL: @vector_load_i128(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128>
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32
 // CHECK:       arm_sme.intr.ld1q.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -244,7 +213,6 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // -----
 
 // CHECK-LABEL: @vector_store_i8(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
@@ -256,19 +224,18 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: return
-func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>) {
+func.func @vector_store_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
@@ -276,15 +243,14 @@ func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>)
 // -----
 
 // CHECK-LABEL: @vector_store_i16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xi16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i16(%arg0 : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
@@ -292,16 +258,14 @@ func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>
 // -----
 
 // CHECK-LABEL: @vector_store_i32(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
 // CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
-func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i32(%arg0 : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
@@ -309,15 +273,14 @@ func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>
 // -----
 
 // CHECK-LABEL: @vector_store_i64(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xi64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
 // CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
-func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i64(%arg0 : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
@@ -325,15 +288,14 @@ func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>
 // -----
 
 // CHECK-LABEL: @vector_store_f16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f16(%arg0 : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
@@ -341,31 +303,28 @@ func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>
 // -----
 
 // CHECK-LABEL: @vector_store_bf16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xbf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xbf16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_bf16(%arg0 : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 // -----
 
 // CHECK-LABEL: @vector_store_f32(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xf32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
 // CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
-func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f32(%arg0 : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
@@ -373,15 +332,14 @@ func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>
 // -----
 
 // CHECK-LABEL: @vector_store_f64(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xf64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
 // CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
-func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f64(%arg0 : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
@@ -389,13 +347,12 @@ func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>
 // -----
 
 // CHECK-LABEL: @vector_store_i128(
-// CHECK-SAME:                     %[[TILE:.*]]: vector<[1]x[1]xi128>,
 // CHECK-SAME:                     %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32
 // CHECK:       arm_sme.intr.st1q.horiz
-func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i128(%arg0 : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
@@ -407,12 +364,11 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
-func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>) {
   // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
-  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
-  // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
-  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  // CHECK: "arm_sme.intr.mopa"(%[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
 }
@@ -420,8 +376,9 @@ func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_bf16
-func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
 }
@@ -429,10 +386,9 @@ func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f32
-func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
-  // CHECK-NOT: arith.extui
-  // CHECK-NOT: arith.trunci
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -440,9 +396,9 @@ func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f64
-func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
-  // CHECK: arith.trunci {{.*}} : i64 to i32
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -451,8 +407,8 @@ func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]
 
 // CHECK-LABEL: @vector_outerproduct_no_accumulator
 func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
-  // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
   %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -460,12 +416,12 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %dim0 : index, %dim1 : index) {
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
-  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
-  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  // CHECK: "arm_sme.intr.mopa"(%[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
@@ -474,11 +430,12 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
@@ -487,11 +444,12 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
-func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
@@ -500,11 +458,12 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
-func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
@@ -520,7 +479,8 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 
 // -----
 
-func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
+  %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error at +1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
@@ -529,7 +489,8 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   // expected-error at +1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
@@ -537,8 +498,9 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
   // CHECK: vector.outerproduct
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -550,14 +512,13 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i32(
-// CHECK-SAME:                       %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                       %[[SLICE:.*]]: vector<[4]xi32>,
 // CHECK-SAME:                       %[[INDEX:.*]]: index)
-func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
+func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-  // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
-  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
@@ -565,8 +526,9 @@ func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i8
-func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
@@ -574,8 +536,9 @@ func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[1
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i16
-func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
@@ -583,8 +546,9 @@ func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i64
-func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
@@ -592,8 +556,9 @@ func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i128
-func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
@@ -601,8 +566,9 @@ func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f16
-func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
@@ -610,8 +576,9 @@ func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_bf16
-func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
@@ -619,8 +586,9 @@ func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f32
-func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
@@ -628,8 +596,9 @@ func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f64
-func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
@@ -637,19 +606,18 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i32(
-// CHECK-SAME:                         %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                         %[[EL:.*]]: i32,
 // CHECK-SAME:                         %[[ROW:.*]]: index,
 // CHECK-SAME:                         %[[COL:.*]]: index)
-func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
-  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-  // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
+  // CHECK-DAG:  %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-DAG:  %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-DAG:  %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
   // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
-  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
@@ -657,9 +625,10 @@ func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i8
-func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
@@ -667,9 +636,10 @@ func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i16
-func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
@@ -677,9 +647,10 @@ func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i64
-func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
@@ -687,9 +658,10 @@ func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i128
-func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
@@ -697,9 +669,10 @@ func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %r
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f16
-func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
@@ -707,9 +680,10 @@ func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_bf16
-func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
@@ -717,9 +691,10 @@ func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %r
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f32
-func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
@@ -727,9 +702,10 @@ func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f64
-func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
@@ -741,14 +717,13 @@ func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row:
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i32(
-// CHECK-SAME:                        %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                        %[[INDEX:.*]]: index)
-func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) -> vector<[4]xi32> {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> {
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_SLICE_INDEX]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
@@ -756,8 +731,9 @@ func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i8
-func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
@@ -765,8 +741,9 @@ func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i16
-func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
@@ -774,8 +751,9 @@ func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i64
-func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) -> vector<[2]xi64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
@@ -783,8 +761,9 @@ func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i128
-func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
@@ -792,8 +771,9 @@ func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f16
-func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) -> vector<[8]xf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
@@ -801,8 +781,9 @@ func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_bf16
-func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -> vector<[8]xbf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
@@ -810,8 +791,9 @@ func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f32
-func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
@@ -819,8 +801,9 @@ func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f64
-func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
@@ -828,16 +811,15 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_element(
-// CHECK-SAME:                          %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                          %[[ROW:.*]]: index,
 // CHECK-SAME:                          %[[COL:.*]]: index)
-func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_element(%row: index, %col: index) -> i32 {
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
   return %el : i32
 }
@@ -845,9 +827,10 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i8
-func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
   return %el : i8
 }
@@ -855,9 +838,10 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i16
-func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
   return %el : i16
 }
@@ -865,9 +849,10 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i64
-func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
   return %el : i64
 }
@@ -875,9 +860,10 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i128
-func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
   return %el : i128
 }
@@ -885,9 +871,10 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index,
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f16
-func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
   return %el : f16
 }
@@ -895,9 +882,10 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_bf16
-func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
   return %el : bf16
 }
@@ -905,9 +893,10 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f32
-func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
   return %el : f32
 }
@@ -915,9 +904,10 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f64
-func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ae3b260da83a2c1..2491b2e2468cdaf 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -452,8 +452,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -500,8 +499,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
 // CHECK-LABEL:   func.func @splat_vec2d_from_i32(
 // CHECK-SAME:      %[[SRC:.*]]: i32) {
 // CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK:   arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK:   %[[VSCALE:.*]] = vector.vscale
 // CHECK:   %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
 // CHECK:   scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 18b95cf2fdf843c..ccd984bcb036e82 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -4,9 +4,9 @@
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index f189fd97d66cde7..81f503bec91efcc 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -3,10 +3,10 @@
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
 // RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-vector-to-scf -allocate-arm-sme-tiles -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
-// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
index 413cb06b0ae1894..2c9edbf560b0110 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@@ -2,11 +2,11 @@
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -canonicalize \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
-// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 0c186cc373a3b32..936163d1cd30d99 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 442a70cacd66508..8c73c24d695cfb2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 74b51dcc9b4df3a..965337c60b9ffdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 82f38b4dbfa9d1f..4ca61a089bdf52d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 3b218aefcd415ff..14dca2d4d7082aa 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index e2cbe735fa4ff06..2751c2d136485e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 6e33a421bf799a2..b45f24f6c8fddaf 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
 // RUN:  -e entry -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 961bb274d1e3352..5d2d0a73992f1e0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = za0_d_f64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 25ef1799e63adb1..af5fe236e5bd577 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,8 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-arm-sme-to-llvm \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-arm-sme-to-llvm -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \
diff --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index b3202b26f8e1e3d..7c9976bed912734 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -4,10 +4,9 @@
 llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
                                                 %nxv4i1 : vector<[4]xi1>,
                                                 %nxv16i8 : vector<[16]xi8>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
-      (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -16,10 +15,9 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
 llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
 ) -> vector<[3]xf32> {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
-  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
-      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[16]xi8>, vector<[4]xi1>, i32) -> vector<[3]xf32>
   llvm.return %res : vector<[3]xf32>
 }
 
@@ -28,9 +26,8 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
 llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
 ) -> vector<[3]xi32> {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
-  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
-      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   llvm.return %res : vector<[4]xi32>
 }
diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 767d89a75eec326..edc1f749130440c 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -2,9 +2,8 @@
 
 // CHECK-LABEL: @arm_sme_zero
 llvm.func @arm_sme_zero() {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.zero(i32 0)
-  "arm_sme.intr.zero"(%c0) : (i32) -> ()
+  "arm_sme.intr.zero"() <{tile_mask = 0 : i32}> : () -> ()
   llvm.return
 }
 
@@ -18,19 +17,18 @@ llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>,
                          %nxv2i1  : vector<[2]xi1>,
                          %nxv4i1  : vector<[4]xi1>,
                          %nxv8i1  : vector<[8]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64
-  "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
-    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.mopa"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+    (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32
-  "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
-    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.mopa"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+    (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16
-  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16
-  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
   llvm.return
 }
 
@@ -41,31 +39,30 @@ llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
                          %nxv16i8 : vector<[16]xi8>,
                          %nxv8i1  : vector<[8]xi1>,
                          %nxv16i1 : vector<[16]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16
-  "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.smopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16
-  "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.umopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16
-  "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.sumopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16
-  "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.usmopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8
-  "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.smopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8
-  "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.umopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8
-  "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.sumopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
-  "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.usmopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -79,19 +76,18 @@ llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>,
                          %nxv2i1  : vector<[2]xi1>,
                          %nxv4i1  : vector<[4]xi1>,
                          %nxv8i1  : vector<[8]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.mops.nxv2f64
-  "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
-    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.mops"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+    (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.nxv4f32
-  "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
-    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.mops"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+    (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16
-  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16
-  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
   llvm.return
 }
 
@@ -102,31 +98,30 @@ llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
                          %nxv16i8 : vector<[16]xi8>,
                          %nxv8i1  : vector<[8]xi1>,
                          %nxv16i1 : vector<[16]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16
-  "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.smops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16
-  "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.umops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16
-  "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.sumops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16
-  "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.usmops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8
-  "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.smops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8
-  "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.umops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8
-  "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.sumops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
-  "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.usmops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -141,35 +136,35 @@ llvm.func @arm_sme_load(%nxv1i1  : vector<[1]xi1>,
                         %ptr    : !llvm.ptr) {
   %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.ld1q.horiz
-  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1d.horiz
-  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1w.horiz
-  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1h.horiz
-  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1b.horiz
-  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1q.vert
-  "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1d.vert
-  "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1w.vert
-  "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1h.vert
-  "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1b.vert
-  "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   llvm.return
 }
 
@@ -184,35 +179,35 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
                          %ptr    : !llvm.ptr) {
   %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.st1q.horiz
-  "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1d.horiz
-  "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1w.horiz
-  "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1h.horiz
-  "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1b.horiz
-  "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1q.vert
-  "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1d.vert
-  "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1w.vert
-  "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1h.vert
-  "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1b.vert
-  "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.str
   "arm_sme.intr.str"(%c0, %ptr, %c0) : (i32, !llvm.ptr, i32) -> ()
   llvm.return
@@ -236,34 +231,33 @@ llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
                                         %nxv8bf16 : vector<[8]xbf16>,
                                         %nxv4f32 : vector<[4]xf32>,
                                         %nxv2f64 : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
-      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
-      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+      (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
 
@@ -285,34 +279,33 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
                                        %nxv8bf16 : vector<[8]xbf16>,
                                        %nxv4f32 : vector<[4]xf32>,
                                        %nxv2f64 : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
-      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
-      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+      (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
 
@@ -334,34 +327,33 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
                                               %nxv8bf16  : vector<[8]xbf16>,
                                               %nxv4f32   : vector<[4]xf32>,
                                               %nxv2f64   : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
-  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
-    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
-  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
-  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
-  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
-  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
-    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
-  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
-  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
-  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
-  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   llvm.return
 }
 
@@ -382,33 +374,32 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
                                               %nxv8bf16 : vector<[8]xbf16>,
                                               %nxv4f32  : vector<[4]xf32>,
                                               %nxv2f64  : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
-  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
-    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
-  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
-  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
-  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
-  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
-    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
-  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
-  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
-  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
-  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   llvm.return
 }
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 0ed4f94d5802b09..1e18d23c0044159 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -21,9 +21,9 @@
 using namespace mlir;
 
 // This is needed because these matchers are defined as overloaded functions.
-using HasOpAttrName = detail::AttrOpMatcher(StringRef);
-using HasOpName = detail::NameOpMatcher(StringRef);
-using IsConstantOp = detail::constant_op_matcher();
+using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef);
+using HasOpName = mlir::detail::NameOpMatcher(StringRef);
+using IsConstantOp = mlir::detail::constant_op_matcher();
 
 namespace test {
 #ifdef MLIR_INCLUDE_TESTS

>From 36f675e30491f685126e30f298f2046c21dc767a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 24 Nov 2023 11:30:29 +0000
Subject: [PATCH 2/4] Fixups

---
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |  2 ++
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 30 ++++++++++++-------
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  4 +++
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  | 14 ++++-----
 .../Conversion/ArmSMEToLLVM/CMakeLists.txt    |  1 -
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  4 +--
 .../lib/Conversion/ArmSMEToSCF/CMakeLists.txt |  1 -
 .../Conversion/VectorToArmSME/CMakeLists.txt  |  1 -
 .../VectorToArmSME/VectorToArmSME.cpp         |  6 ++--
 mlir/lib/Dialect/ArmSME/CMakeLists.txt        |  1 -
 mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt     |  2 +-
 .../Dialect/ArmSME/{Utils => IR}/Utils.cpp    | 14 +++++++++
 .../Dialect/ArmSME/Transforms/CMakeLists.txt  |  1 -
 .../ArmSME/Transforms/TileAllocation.cpp      |  9 +++---
 mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt  |  9 ------
 mlir/tools/mlir-query/mlir-query.cpp          |  6 ++--
 16 files changed, 60 insertions(+), 45 deletions(-)
 rename mlir/lib/Dialect/ArmSME/{Utils => IR}/Utils.cpp (77%)
 delete mode 100644 mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index 1da8e488a4c4647..9982d4278b6033e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -24,7 +24,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
+namespace mlir::arm_sme {
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+}
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index abcc9b649c4a530..f7cc1d3fe7517f4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -40,7 +40,7 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
 def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   let description = [{
     An interface for operations that use or allocate Arm SME tiles. These
-    operations need to be assigned a tile ID an i32 attribute, which specifies
+    operations need to be assigned a tile ID, an i32 attribute, which specifies
     which virtual tile within the ZA storage to use. The number of tiles
     available depends on the type of the tile. This is summarized below:
 
@@ -52,7 +52,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
     | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
     | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
 
-    Operations that allocate a new tiles (such as arm_sme.get_tile), are used as
+    Operations that allocate a new tile (such as arm_sme.get_tile) are used as
     the roots for tile allocation, with all operations that (transitively)
     depend on a root being assigned the same tile ID.
   }];
@@ -71,7 +71,10 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
       }]
     >,
     InterfaceMethod<
-      "Returns the (possibly null) tile ID assigned to this operation.",
+      [{
+        Returns the tile ID assigned to this operation. This will be null before
+        tile allocation.
+      }],
       /*returnType=*/"mlir::IntegerAttr",
       /*methodName=*/"getTileId",
       /*arguments=*/(ins),
@@ -82,13 +85,16 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
       }]
     >,
     InterfaceMethod<
-      "The type of tile this operation allocates (or none)",
+      [{
+        The type of tile this operation allocates. Returns none (std::nullopt)
+        if this operation does not allocate a tile.
+      }],
       /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
       /*methodName=*/"getAllocatedTileType",
       /*arguments=*/(ins),
       /*methodBody=*/[{}],
       /*defaultImpl=*/ [{
-        // Do not allocate a new tile.
+        // This operation does not allocate a tile.
         return std::nullopt;
       }]
     >
@@ -104,7 +110,7 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
       return op;
     }
 
-    // A helper to replace this operation and forward any tile ID.
+    // A helper to replace this operation and forward its tile ID (if present).
     template<typename T, typename... Args>
     T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
       auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
@@ -112,6 +118,8 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
       return newOp;
     }
   }];
+
+  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -222,11 +230,11 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
   let summary = "Returns a SME virtual tile";
   let description = [{
     Allocates a new SME "virtual tile" within a function. The contents of the
-    tile returned from this operation undefined.
+    tile returned from this operation are undefined.
 
     Example 1:
 
@@ -264,12 +272,12 @@ def GetTile : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
   }];
 }
 
-def MaterializeSSATile : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
   let summary = "SME tile placeholder";
   let description = [{
     A placeholder to preserve dataflow while lowering to SME intrinsics (which
-    do not take or return tile values). This operation is intended to be DCE'd
-    once all ArmSME operations have been lowered.
+    do not take or return SME virtual tile values). This operation is intended
+    to be DCE'd once all ArmSME operations have been lowered.
 
     This operation is not intended to be used outside of the ArmSME -> LLVM
     conversion.
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 954371cb9d5a0b1..d512c41634bfb53 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -35,8 +35,12 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+/// Returns the type of SME tile this vector type corresponds to or none.
 std::optional<ArmSMETileType> getSMETileType(VectorType);
 
+/// Verifies the tile ID (if set) on this tile operation is valid.
+LogicalResult verifyOperationHasValidTileId(Operation *);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 4f4b090dd10e3c0..1e7579706a34071 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,7 +32,7 @@ using namespace mlir;
 
 namespace {
 
-IntegerAttr getTileIdOrError(ArmSMETileOpInterface op) {
+IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
   auto tileId = op.getTileId();
   if (!tileId)
     op.emitOpError(
@@ -40,13 +40,13 @@ IntegerAttr getTileIdOrError(ArmSMETileOpInterface op) {
   return tileId;
 }
 
-struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTile> {
-  using ConvertOpToLLVMPattern<arm_sme::GetTile>::ConvertOpToLLVMPattern;
+struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
+  using ConvertOpToLLVMPattern<arm_sme::GetTileOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(arm_sme::GetTile getTile, OpAdaptor,
+  matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
         getTile, getTile.getTileType());
     return success();
   }
@@ -140,7 +140,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
         loc, rewriter.getI32IntegerAttr(zeroMask));
 
     // Create a placeholder op to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATile>(
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
         zero, zero.getVectorType());
 
     return success();
@@ -558,7 +558,7 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::MaterializeSSATile, arm_sme::aarch64_sme_zero,
+      arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
       arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
       arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
       arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
index 9914f39e17a1a91..d0a921296668d31 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_conversion_library(MLIRArmSMEToLLVM
   LINK_LIBS PUBLIC
   MLIRArmSMETransforms
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRTransforms
   MLIRLLVMCommonConversion
   MLIRLLVMDialect)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 541d711fbd95f29..837bf10d8b95d58 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -89,7 +89,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto tileElementType = tileType.getElementType();
 
     // Allocate a new SME tile.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
         rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
@@ -299,7 +299,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         loc, rewriter.getI32Type(), numCols);
 
     // Allocate a new SME tile.
-    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTile>(
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
         rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
index 3bf4d7082afe422..467c3b942e754f4 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
@@ -9,6 +9,5 @@ add_mlir_conversion_library(MLIRArmSMEToSCF
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRTransforms
   )
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index 715816a90128cdb..b062f65e914e8b9 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,6 +10,5 @@ add_mlir_conversion_library(MLIRVectorToArmSME
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRLLVMCommonConversion
   )
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 109b04ce34a88b7..3016c7b0a84772d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -253,7 +253,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
     auto tileSliceIndex = forOp.getInductionVar();
@@ -315,7 +315,7 @@ struct BroadcastOpToArmSMELowering
     else
       return failure();
 
-    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop over ZA tile slices.
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
@@ -371,7 +371,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
         loc, tileSliceType, splatOp.getInput());
 
-    auto tile = rewriter.create<arm_sme::GetTile>(loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.
diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
index 31167e6af908b9f..9f57627c321fb0c 100644
--- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
@@ -1,3 +1,2 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
-add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 66062335fa842a4..ce0c2f4986fc661 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRArmSMEDialect
   ArmSME.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME
@@ -15,5 +16,4 @@ add_mlir_dialect_library(MLIRArmSMEDialect
   MLIRSCFDialect
   MLIRSideEffectInterfaces
   MLIRVectorDialect
-  MLIRArmSMEUtils
 )
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
similarity index 77%
rename from mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
rename to mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index c3cdf5703bcbc28..6105cd622528303 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 
 namespace mlir::arm_sme {
 
@@ -59,4 +60,17 @@ std::optional<ArmSMETileType> getSMETileType(VectorType type) {
   }
 }
 
+LogicalResult verifyOperationHasValidTileId(Operation *op) {
+  auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
+  if (!tileOp)
+    return success(); // Not a tile op (no need to check).
+  auto tileId = tileOp.getTileId();
+  if (!tileId)
+    return success(); // Not having a tile ID (yet) is okay.
+  if (!tileId.getType().isSignlessInteger(32))
+    return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
+  // TODO: Verify value of tile ID is in range.
+  return success();
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..7b6b2e77dcebfe0 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 0a65efeb266d15b..228b3777ffb23da 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -136,9 +136,10 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates a tile to 'tileId' or returns an error if there are no tiles left.
-static FailureOr<unsigned> getTile(ArmSMETileType tileType,
-                                   TileMask &tilesInUse) {
+/// Allocates and returns a tile ID. Returns an error if there are no tiles
+/// left.
+static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
+                                          TileMask &tilesInUse) {
   auto masks = getMasks(tileType);
   for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
     if ((tilesInUse & tileMask) == TileMask::kNone) {
@@ -168,7 +169,7 @@ struct AssignTileIDsPattern
     else
       tilesInUse = TileMask::kNone;
 
-    auto tileId = getTile(*tileType, tilesInUse);
+    auto tileId = allocateTileId(*tileType, tilesInUse);
     if (failed(tileId))
       return tileOp.emitError("ran out of SME virtual tiles!");
 
diff --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
deleted file mode 100644
index ecf774a215d24f8..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
+++ /dev/null
@@ -1,9 +0,0 @@
-add_mlir_dialect_library(MLIRArmSMEUtils
-  Utils.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
-
-  LINK_LIBS PUBLIC
-  MLIRIR
-  )
diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp
index 1e18d23c0044159..0ed4f94d5802b09 100644
--- a/mlir/tools/mlir-query/mlir-query.cpp
+++ b/mlir/tools/mlir-query/mlir-query.cpp
@@ -21,9 +21,9 @@
 using namespace mlir;
 
 // This is needed because these matchers are defined as overloaded functions.
-using HasOpAttrName = mlir::detail::AttrOpMatcher(StringRef);
-using HasOpName = mlir::detail::NameOpMatcher(StringRef);
-using IsConstantOp = mlir::detail::constant_op_matcher();
+using HasOpAttrName = detail::AttrOpMatcher(StringRef);
+using HasOpName = detail::NameOpMatcher(StringRef);
+using IsConstantOp = detail::constant_op_matcher();
 
 namespace test {
 #ifdef MLIR_INCLUDE_TESTS
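
As a rough sketch of what the new `verifyOperationHasValidTileId` hook above
checks (illustrative MLIR only, not taken from the tests in this patch; op
syntax follows the examples elsewhere in this change):

```mlir
// Fine: no tile ID yet — tile allocation simply has not run.
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>

// Fine: the tile ID is a 32-bit signless integer attribute.
%zero = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>

// Rejected with "tile ID should be a 32-bit signless integer".
%bad = arm_sme.zero {tile_id = 0 : i64} : vector<[4]x[4]xi32>
```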

>From e3b0e3d56a4e61369a9c9fbcad320922887b8267 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 24 Nov 2023 11:44:13 +0000
Subject: [PATCH 3/4] Update comments that reference removed operations

---
 .../mlir/Dialect/ArmSME/Transforms/Passes.h     |  2 +-
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp    | 17 +++++++----------
 mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp |  3 +--
 .../ArmSME/Transforms/TileAllocation.cpp        | 11 +++++------
 4 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..11a7385fe311dd3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,7 +29,7 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled);
 
-/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
+/// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 1e7579706a34071..a28b8ef7f7fceb3 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -56,20 +56,17 @@ struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
 ///
 ///  BEFORE:
 ///  ```mlir
-///     %v = arm_sme.zero : vector<[4]x[4]xi32>
+///     %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
 ///  ```
 ///
 ///  AFTER:
 ///  ```mlir
-///     %tile_id = arm_sme.get_tile_id : i32
-///     %zero_mask = arith.shli %c17_i32, %tile_id : i32
-///     "arm_sme.intr.zero"(%zero_mask) : (i32) -> ()
-///     %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///     "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
+///     %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
 ///  ```
 ///
-///  The 'arm_sme.cast_tile_to_vector' (which models the return) and the
-///  'arith.shli' (which generates the mask) will be folded away after tile
-///  allocation and canonization.
+///  The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
+///  once all ArmSME ops have been converted to LLVM intrinsics.
 struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
   using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;
 
@@ -443,8 +440,8 @@ struct MoveTileSliceToVectorConversion
 ///
 /// is converted to:
 ///
-///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
-///     : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
+///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 837bf10d8b95d58..28d4c3192e6ab84 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -61,8 +61,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  AFTER:
 ///  ```mlir
 ///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-///  %tile_id = arm_sme.get_tile_id : i32
-///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 228b3777ffb23da..01f0494c3a29511 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,10 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass allocates SME tiles at the 'func.func' op level for
-// 'arm_sme.get_tile_id' ops. It does this using a 16-bit tile mask that has a
-// bit for each 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile
-// granule.
+// This pass allocates SME tiles at the 'func.func' op level for ArmSME
+// operations. It does this using a 16-bit tile mask that has a bit for each
+// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
 //
 // The 128-bit tiles overlap with other element tiles as follows (see section
 // B2.3.2 of SME spec [1]):
@@ -34,8 +33,8 @@
 //   ZA7.D   ZA7.Q, ZA15.Q
 //
 // The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initalized during the first 'arm_sme.get_tile_id' rewrite and
-// updated on each subsequent rewrite.
+// that is initialized during the first tile allocation within a function and
+// updated on each subsequent allocation.
 //
 // [1] https://developer.arm.com/documentation/ddi0616/aa
 //

>From 874b12c9596f9d5de5cc3d8e645d9be6373b153f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 24 Nov 2023 12:02:54 +0000
Subject: [PATCH 4/4] Add explanation

---
 .../Dialect/ArmSME/Transforms/TileAllocation.cpp  | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 01f0494c3a29511..354f3236bb07b3d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -180,7 +180,20 @@ struct AssignTileIDsPattern
     getForwardSlice(tileOp.getOperation(), &dependantOps);
 
     // Set all operations to use the same tile ID.
-    // This is a navie tile allocation scheme, but works for common cases.
+    // This is a naive tile allocation scheme, but works for common cases. For
+    // example, as this only allocates tile IDs to existing ops, it can't solve
+    // cases like:
+    //
+    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
+    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
+    // } else {
+    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
+    // }
+    //
+    // Where %tileA and %tileB come from different root operations. This case
+    // would require allocating a new tile for the result of the scf.if, and
+    // moving the contents of %tileA or %tileB to the result tile (based on the
+    // %some_cond).
     auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
     tileOp.setTileId(tileIDAttr);
     for (auto *op : dependantOps) {
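
To make the trade-off described in the comment above concrete, the common case
the naive scheme does handle looks roughly like the sketch below (illustrative
only; `%c0`, `%c1`, and `%n` are assumed to be defined elsewhere):

```mlir
// A single root operation allocates the tile and is assigned an ID...
%init = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
// ...and every ArmSME op in its forward slice is tagged with the same tile
// ID, e.g. a loop that threads the tile through iter_args:
%result = scf.for %i = %c0 to %n step %c1
    iter_args(%tile = %init) -> (vector<[4]x[4]xi32>) {
  // (ArmSME ops updating %tile here would all receive tile_id = 0)
  scf.yield %tile : vector<[4]x[4]xi32>
}
```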


