[Mlir-commits] [mlir] eaff02f - [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (#73253)

llvmlistbot at llvm.org
Thu Nov 30 02:22:30 PST 2023


Author: Benjamin Maxwell
Date: 2023-11-30T10:22:22Z
New Revision: eaff02f28e194b497ec2d9a6bbecc33b54f6df27

URL: https://github.com/llvm/llvm-project/commit/eaff02f28e194b497ec2d9a6bbecc33b54f6df27
DIFF: https://github.com/llvm/llvm-project/commit/eaff02f28e194b497ec2d9a6bbecc33b54f6df27.diff

LOG: [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (#73253)

This reworks the ArmSME dialect to use attributes for tile allocation.
This has a number of advantages and corrects some issues with the
previous approach:

* Tile allocation can now be done ASAP (i.e. immediately after
`-convert-vector-to-arm-sme`)
* SSA form for control flow is now supported (e.g. `scf.for` loops that
yield tiles)
* ArmSME ops can be converted to intrinsics very late (i.e. after
lowering to control flow)
* Tests are simplified by removing constants and casts
* Avoids correctness issues with representing LLVM `immargs` as MLIR
values
  - The tile ID on the SME intrinsics is an `immarg` (so it must be a
    compile-time constant), and `immargs` should be mapped to MLIR
    attributes (this is already the case for intrinsics in the LLVM dialect)
  - Using MLIR values for `immargs` can lead to invalid LLVM IR being
    generated (and to passes such as `-cse` making incorrect optimizations)

As part of this patch we bid farewell to the following operations:

```mlir
arm_sme.get_tile_id : i32
arm_sme.cast_tile_to_vector : i32 to vector<[4]x[4]xi32>
arm_sme.cast_vector_to_tile : vector<[4]x[4]xi32> to i32
```

These are now replaced with:
```mlir
// Allocates a new tile with (indeterminate) state:
arm_sme.get_tile : vector<[4]x[4]xi32>
// A placeholder operation for lowering ArmSME ops to intrinsics:
arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
```

The new tile allocation scheme is driven by operations that implement the
`ArmSMETileOpInterface`. This interface indicates that an operation needs
to be assigned a tile ID, and that it may conditionally allocate a new SME
tile.
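The following C++ sketch shows how the tile type determines which tile IDs are valid. It models the mapping from element bit width to tile type and the per-type ID ranges listed in the `ArmSMETileOpInterface` description further down: one 8-bit tile, two 16-bit tiles, and so on up to sixteen 128-bit tiles. The names `tileTypeForElementWidth`, `numTiles` and `isValidTileId` are hypothetical. The patch's own helpers are `arm_sme::getSMETileType`, which takes a `VectorType`, and `arm_sme::verifyOperationHasValidTileId`, and their exact checks may differ.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for ArmSMETileType; the enumerator values follow the
// enum cases added in this patch (ZAB = 0, ..., ZAQ = 4).
enum class ArmSMETileType { ZAB = 0, ZAH = 1, ZAS = 2, ZAD = 3, ZAQ = 4 };

// Simplified model of arm_sme::getSMETileType: only the element bit width is
// considered here (the real helper takes a VectorType and returns none for
// vector types that do not fit within an SME tile).
std::optional<ArmSMETileType> tileTypeForElementWidth(unsigned bits) {
  switch (bits) {
  case 8:
    return ArmSMETileType::ZAB;
  case 16:
    return ArmSMETileType::ZAH;
  case 32:
    return ArmSMETileType::ZAS;
  case 64:
    return ArmSMETileType::ZAD;
  case 128:
    return ArmSMETileType::ZAQ;
  default:
    return std::nullopt;
  }
}

// 1, 2, 4, 8 or 16 virtual tiles, per the table in the interface docs.
unsigned numTiles(ArmSMETileType type) {
  return 1u << static_cast<unsigned>(type);
}

// A tile ID is valid if it lies in [0, numTiles(type)).
bool isValidTileId(ArmSMETileType type, std::int64_t id) {
  return id >= 0 && id < static_cast<std::int64_t>(numTiles(type));
}
```

For example, an op producing `vector<[8]x[8]xi16>` maps to `ZAH`, so only tile IDs 0 and 1 would pass this check.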

Operations allocate a new tile by implementing...
```c++
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType()
```
...and returning the type of tile the op allocates (ZAB, ZAH, etc.).

Operations that don't allocate a tile return `std::nullopt` (which is
the default behaviour).

Currently the following ops are defined as allocating:
```mlir
arm_sme.get_tile
arm_sme.zero
arm_sme.tile_load
arm_sme.outerproduct // (if no accumulator is specified)
```
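The allocation hook can be pictured with a plain C++ analogy. MLIR op interfaces are not implemented with virtual functions, so the class hierarchy below is only a model. `TileOp` stands in for `ArmSMETileOpInterface` with its default `std::nullopt` implementation. The derived structs mimic `arm_sme.zero`, `arm_sme.outerproduct` (which allocates only without an accumulator) and `arm_sme.tile_store` (which uses a tile but never allocates one). All names in the sketch are hypothetical.

```cpp
#include <optional>

// Hypothetical stand-in for the ArmSMETileType enum.
enum class ArmSMETileType { ZAB, ZAH, ZAS, ZAD, ZAQ };

// Model of the interface: by default an op does not allocate a tile.
struct TileOp {
  virtual ~TileOp() = default;
  virtual std::optional<ArmSMETileType> getAllocatedTileType() const {
    return std::nullopt;
  }
};

// Like arm_sme.zero: always allocates a tile of its result type.
struct ZeroOp : TileOp {
  ArmSMETileType resultTileType;
  explicit ZeroOp(ArmSMETileType t) : resultTileType(t) {}
  std::optional<ArmSMETileType> getAllocatedTileType() const override {
    return resultTileType;
  }
};

// Like arm_sme.outerproduct: allocates only when no accumulator is given.
struct OuterProductOp : TileOp {
  ArmSMETileType resultTileType;
  bool hasAccumulator;
  OuterProductOp(ArmSMETileType t, bool acc)
      : resultTileType(t), hasAccumulator(acc) {}
  std::optional<ArmSMETileType> getAllocatedTileType() const override {
    if (!hasAccumulator)
      return resultTileType;
    return std::nullopt;
  }
};

// Like arm_sme.tile_store: keeps the default (non-allocating) behaviour.
struct TileStoreOp : TileOp {};
```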

Allocating operations become the roots for the tile allocation pass,
which currently (naively) assigns the same tile ID to all transitive uses
of a root operation. This simple scheme is nonetheless enough to handle
current use cases.
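This is a minimal sketch of the naive scheme. It assumes a toy IR in which each op simply lists the indices of the ops that use its result. Each root takes the next unused ID for its tile type, and a worklist then copies that ID to every transitive user. The visited check also makes the walk terminate on cycles, such as tiles carried around a loop. `Op`, `allocateTiles` and `numTiles` are hypothetical and are not the pass in `TileAllocation.cpp`. The sketch also does not model how tiles of different element widths alias within ZA.

```cpp
#include <map>
#include <optional>
#include <queue>
#include <vector>

// Hypothetical stand-in for the ArmSMETileType enum (ZAB = 0, ..., ZAQ = 4).
enum class ArmSMETileType { ZAB, ZAH, ZAS, ZAD, ZAQ };

// Hypothetical, simplified IR node.
struct Op {
  std::optional<ArmSMETileType> allocatedTileType; // set => allocation root
  std::vector<int> users;                           // indices of using ops
  std::optional<int> tileId;                        // filled in below
};

// 1, 2, 4, 8 or 16 virtual tiles depending on the tile type.
int numTiles(ArmSMETileType type) { return 1 << static_cast<int>(type); }

// Naive allocation: each root takes the next free ID of its tile type and
// that ID is copied to every op that transitively uses the root.
// Returns false if the tiles of some type are exhausted.
bool allocateTiles(std::vector<Op> &ops) {
  std::map<ArmSMETileType, int> nextId;
  for (Op &root : ops) {
    if (!root.allocatedTileType)
      continue;
    ArmSMETileType type = *root.allocatedTileType;
    int id = nextId[type]++;
    if (id >= numTiles(type))
      return false;
    std::queue<Op *> worklist;
    worklist.push(&root);
    while (!worklist.empty()) {
      Op *op = worklist.front();
      worklist.pop();
      if (op->tileId)
        continue; // Already visited (also stops on cycles).
      op->tileId = id;
      for (int user : op->users)
        worklist.push(&ops[user]);
    }
  }
  return true;
}
```

For example, two independent 32-bit roots, each with its own chain of users, would receive IDs 0 and 1. A second 8-bit root would fail because only one `ZAB` tile exists.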

Once tile IDs have been allocated, subsequent rewrites can forward the
tile IDs to any newly created operations.
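The forwarding step can be sketched in plain C++. `TileOp` and `createAndForwardTileId` are hypothetical stand-ins for an ArmSME op and for the `createOpAndForwardTileId` helper that this patch adds to the interface. The real helper creates the op through a `RewriterBase` and sets `tile_id` only when the source op already has one and the new op implements `ArmSMETileOpInterface`.

```cpp
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-in for an op carrying an optional `tile_id` attribute.
struct TileOp {
  std::string name;
  std::optional<int> tileId; // unset until tile allocation has run
};

// Builds the replacement op from `args`, then copies the tile ID (if any)
// from the op being rewritten, so allocation does not need to run again.
template <typename... Args>
TileOp createAndForwardTileId(const TileOp &from, Args &&...args) {
  TileOp newOp{std::forward<Args>(args)...};
  newOp.tileId = from.tileId; // Stays unset if `from` was never allocated.
  return newOp;
}
```

Copying an unset ID leaves the new op unallocated, which matches the early return in the default `setTileId` implementation when it is given a null attribute.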

Added: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
    mlir/lib/Dialect/ArmSME/IR/Utils.cpp
    mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
    mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
    mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
    mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
    mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
    mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
    mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
    mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
    mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
    mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
    mlir/lib/Dialect/ArmSME/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
    mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
    mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
    mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
    mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
    mlir/test/Dialect/ArmSME/canonicalize.mlir
    mlir/test/Dialect/ArmSME/cse.mlir
    mlir/test/Dialect/ArmSME/invalid.mlir
    mlir/test/Dialect/ArmSME/roundtrip.mlir
    mlir/test/Dialect/ArmSME/tile-allocation.mlir
    mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
    mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
    mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
    mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
    mlir/test/Target/LLVMIR/arm-sme.mlir

Removed: 
    mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
    mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index fe1f9062a37ef51..9982d4278b6033e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -22,7 +24,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+namespace mlir::arm_sme {
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+}
 
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
new file mode 100644
index 000000000000000..9ad112a5d354a2f
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEEnums.h
@@ -0,0 +1,16 @@
+//===- ArmSMEEnums.h - Arm SME Dialect Enums --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_ENUMS_H
+#define MLIR_DIALECT_ARMSME_ENUMS_H
+
+#include "mlir/IR/Dialect.h"
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h.inc"
+
+#endif

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index b75918ebf2f6d9c..4d96e04c886fa31 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -54,7 +54,10 @@ def MOPVector : ScalableVectorOfRankAndLengthAndType<[1], [16, 8, 4, 2],
   }];
 }
 
-class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
+class ArmSME_IntrOp<string mnemonic,
+                    list<int> immArgPositions = [],
+                    list<string> immArgAttrNames = [],
+                    list<int> overloadedOperands = [],
                     list<Trait> traits = [], int numResults = 0,
                     list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
@@ -64,16 +67,27 @@ class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
           /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/numResults>;
+          /*int numResults=*/numResults,
+          /*bit requiresAccessGroup=*/0,
+          /*bit requiresAliasAnalysis=*/0,
+          /*bit requiresFastmath=*/0,
+          /*list<int> immArgPositions=*/immArgPositions,
+          /*list<string> immArgAttrNames=*/immArgAttrNames>;
 
 // Zero
-def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
-                            Arguments<(ins Arg<I32, "Tile mask">:$tile_mask)>;
+def LLVM_aarch64_sme_zero
+   : ArmSME_IntrOp<"zero",
+                   /*immArgPositions=*/[0],
+                   /*immArgAttrNames=*/["tile_mask"]>,
+     Arguments<(ins Arg<I32Attr, "Tile mask">:$tile_mask)>;
 
 // MOP's
 class ArmSME_IntrMopOverloadedOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic, [4]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+    : ArmSME_IntrOp<mnemonic,
+                    /*immArgPositions=*/[0],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[4]>,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<MOPPredicate, "LHS predicate">:$lhs_predicate,
                  Arg<MOPPredicate, "RHS predicate">:$rhs_predicate,
                  Arg<MOPVector, "LHS vector operand">:$lhs_vector,
@@ -92,12 +106,17 @@ def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
 
+class ArmSME_IntrLoadStoreOp<string mnemonic>
+    : ArmSME_IntrOp<mnemonic,
+                    /*immArgPositions=*/[2],
+                    /*immArgAttrNames=*/["tile_id"]>;
+
 // Loads
 class ArmSME_IntrLoadOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Load address">:$load_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_ld1b_horiz : ArmSME_IntrLoadOp<"ld1b.horiz">;
@@ -113,10 +132,10 @@ def LLVM_aarch64_sme_ld1q_vert : ArmSME_IntrLoadOp<"ld1q.vert">;
 
 // Stores
 class ArmSME_IntrStoreOp<string mnemonic>
-    : ArmSME_IntrOp<mnemonic>,
+    : ArmSME_IntrLoadStoreOp<mnemonic>,
       Arguments<(ins Arg<SVEPredicate, "Vector predicate">:$predicate,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>:$store_address,
-                 Arg<I32, "Virtual tile ID">:$tile_id,
+                 Arg<I32Attr, "Virtual tile ID">:$tile_id,
                  Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_st1b_horiz : ArmSME_IntrStoreOp<"st1b.horiz">;
@@ -138,22 +157,28 @@ def LLVM_aarch64_sme_str
 
 // Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
-    : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
+    : ArmSME_IntrOp<"write." # direction,
+                    /*immArgPositions=*/[0],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[3],
                     [AllShapesMatch<["predicate", "vector"]>]>,
-      Arguments<(ins Arg<I32, "Virtual tile ID">:$tile_id,
+      Arguments<(ins Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
 // Tile slice to vector
 class LLVM_aarch64_sme_read<string direction>
-    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+    : ArmSME_IntrOp<"read." # direction,
+                    /*immArgPositions=*/[2],
+                    /*immArgAttrNames=*/["tile_id"],
+                    /*overloadedOperands=*/[],
                     [AllShapesMatch<["vector", "predicate", "res"]>,
                      AllElementTypesMatch<["vector", "res"]>],
                     /*numResults=*/1, /*overloadedResults=*/[0]>,
       Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
                      Arg<SVEPredicate, "Vector predicate">:$predicate,
-                     Arg<I32, "Virtual tile ID">:$tile_id,
+                     Arg<I32Attr, "Virtual tile ID">:$tile_id,
                      Arg<I32, "Tile slice">:$tile_slice_index)>;
 
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index ba33a2826e6ca4b..f7cc1d3fe7517f4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -21,6 +21,107 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 
+//===----------------------------------------------------------------------===//
+// ArmSME op interfaces
+//===----------------------------------------------------------------------===//
+
+def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
+    [
+      I32EnumAttrCase<"ZAB", 0, "za.b">,
+      I32EnumAttrCase<"ZAH", 1, "za.h">,
+      I32EnumAttrCase<"ZAS", 2, "za.s">,
+      I32EnumAttrCase<"ZAD", 3, "za.d">,
+      I32EnumAttrCase<"ZAQ", 4, "za.q">,
+    ]>{
+  let cppNamespace = "mlir::arm_sme";
+  let genSpecializedAttr = 0;
+}
+
+def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
+  let description = [{
+    An interface for operations that use or allocate Arm SME tiles. These
+    operations need to be assigned a tile ID, an i32 attribute, which specifies
+    which virtual tile within the ZA storage to use. The number of tiles
+    available depends on the type of the tile. This is summarized below:
+
+    | Tile Vector Types                                                       | Possible Tile IDs   |
+    |-------------------------------------------------------------------------|---------------------|
+    | `vector<[16]x[16]xi8>`                                                  | 0                   |
+    | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` | 0 and 1             |
+    | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
+    | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
+    | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
+
+    Operations that allocate a new tile (such as arm_sme.get_tile), are used as
+    the roots for tile allocation, with all operations that (transitively)
+    depend on a root being assigned the same tile ID.
+  }];
+  let methods = [
+    InterfaceMethod<
+      "Sets the tile ID for this operation.",
+      /*returnType=*/"void",
+      /*methodName=*/"setTileId",
+      /*arguments=*/(ins "mlir::IntegerAttr":$tileId),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        if (!tileId)
+          return;
+        ::mlir::Operation* op = this->getOperation();
+        op->setAttr("tile_id", tileId);
+      }]
+    >,
+    InterfaceMethod<
+      [{
+        Returns the tile ID assigned to this operation. This will be null before
+        tile allocation.
+      }],
+      /*returnType=*/"mlir::IntegerAttr",
+      /*methodName=*/"getTileId",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        ::mlir::Operation* op = this->getOperation();
+        return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
+      }]
+    >,
+    InterfaceMethod<
+      [{
+        The type of tile this operation allocates. Returns none (std::nullopt)
+        if this operation does not allocate a tile.
+      }],
+      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
+      /*methodName=*/"getAllocatedTileType",
+      /*arguments=*/(ins),
+      /*methodBody=*/[{}],
+      /*defaultImpl=*/ [{
+        // This operation does not allocate a tile.
+        return std::nullopt;
+      }]
+    >
+  ];
+
+  let extraSharedClassDeclaration = [{
+    // A helper to create a new operation and propagate this operations tile ID.
+    template<typename T, typename... Args>
+    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
+      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
+      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
+        tileOp.setTileId($_op.getTileId());
+      return op;
+    }
+
+    // A helper to replace this operation and forward its tile ID (if present).
+    template<typename T, typename... Args>
+    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
+      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
+      rewriter.replaceOp($_op, newOp);
+      return newOp;
+    }
+  }];
+
+  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME type definitions
 //===----------------------------------------------------------------------===//
@@ -44,7 +145,8 @@ def nxnxv2f64  : SMETileType<F64,  [2,  2 ], "vector<[2]x[2]xf64>">;
 
 def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                          nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64],
-                        "a vector type that fits into a SME tile">
+                        "a vector type that fits into a SME tile",
+                        "VectorType">
 {
   let description = [{
     Possible vector types:
@@ -66,40 +168,6 @@ def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
   }];
 }
 
-def TileID : AnyTypeOf<[I8, I16, I32, I64, I128],
-                       "an identifier of a virtual tile (of a size) within the ZA storage">
-{
-  let description = [{
-    The tile ID is an 8, 16, 32, 64, or 128-bit signless integer. The value of
-    the integer indicates the tile to use, and the bit size indicates the size
-    of tile. The number of tiles available and the element types of those depend
-    on the size. This is summarised below:
-
-    | Tile ID Type | Possible Tile IDs   | Tile Vector Types                                                       |
-    |--------------|---------------------|-------------------------------------------------------------------------|
-    | `i8`         | 0                   | `vector<[16]x[16]xi8>`                                                  |
-    | `i16`        | 0 and 1             | `vector<[8]x[8]xi16>`, `vector<[8]x[8]xf16>`, or `vector<[8]x[8]xbf16>` |
-    | `i32`        | 0 to 3 (inclusive)  | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          |
-    | `i64`        | 0 to 7 (inclusive)  | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          |
-    | `i128`       | 0 to 15 (inclusive) | `vector<[1]x[1]xi128>`                                                  |
-  }];
-}
-
-// A type constraint that verifies the bitwidth of the scalar integer returned
-// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
-def TileElementWidthMatchesTileID : TypesMatchWith<
-  "`tile_id` has the same number of bits as elements in `vector`",
-  "vector", "tile_id",
-  "IntegerType::get("
-      "$_self.getContext(),"
-      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
-          "? ::llvm::cast<IntegerType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth()"
-          ": ::llvm::cast<FloatType>("
-                  "::llvm::cast<VectorType>($_self).getElementType())"
-                  ".getWidth())">;
-
 class HasMatchingMaskTypeConstraint<string vector, string mask> :
   OptionalTypesMatchWith<
     mask # " has i1 element type and same shape as " # vector,
@@ -162,125 +230,67 @@ def ArmSME_CombiningKindAttr : EnumAttr<ArmSME_Dialect, CombiningKind,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from tile id to 2-d scalable vector type";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
+  let summary = "Returns a SME virtual tile";
   let description = [{
-    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
-    scalable vector type, which represents an SME "virtual tile". This would
-    normally be used when lowering operations that return "virtual tile" vector
-    types to model the output. This is required to preserve dataflow as SME
-    intrinsics have no return values.
+    Allocates a new SME "virtual tile" within a function. The contents of the
+    tile returned from this operation are undefined.
 
-    Example:
+    Example 1:
 
-    Input:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 8-bit element "virtual tile"
+    %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
-    After lowering `vector.load`:
+    Example 2:
+
     ```mlir
-    %tile_id = arm_sme.get_tile_id : i32
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-    }
-    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate two 16-bit element "virtual tiles"
+    %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+    %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
-    In the example above, the `vector.load` can't be replaced with an SME
-    intrinsic that has no outputs since it is used by the `vector.store`.
-    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
-    the `vector.load` can be replaced. This enables "local" rewrites on
-    individual vector ops, rather than "global" rewrites that would have to
-    look at the vector op uses and also lower them.
-
-    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
-    the cast away if it comes from a `arm_sme.cast_vector_to_tile`.
-  }];
-  let arguments = (ins TileID:$tile_id);
-  let results = (outs SMETile:$vector);
-  let assemblyFormat =
-    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
-  let hasCanonicalizeMethod = 1;
-}
-
-def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
-  let summary = "Cast from 2-d scalable vector type to tile id";
-  let description = [{
-    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
-    type, which represents an SME "virtual tile", to a tile id. This is
-    required to preserve dataflow as the SME intrinsics have no return values.
-
-    Example:
-
-    Input:
+    Example 3:
     ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    // Allocate an 128-bit element "virtual tile"
+    %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
+  }];
 
-    After lowering `vector.store`:
-    ```mlir
-    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-    scf.for %vnum = %c0 to %num_vectors step %c1 {
-      // ...
-      %tile_id = arm_sme.cast_vector_to_tile %tile : (vector<[4]x[4]xi32>) -> i32
-      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
+
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getTile().getType());
     }
-    ```
 
-    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
-    the cast away if it comes from a `arm_sme.cast_tile_to_vector`.
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getTileType());
+    }
   }];
-  let arguments = (ins SMETile:$vector);
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat =
-    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
-  let hasCanonicalizeMethod = 1;
 }
 
-def GetTileID : ArmSME_Op<"get_tile_id"> {
-  let summary = "Returns an SME \"virtual tile\" id";
+def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
+  let summary = "SME tile placeholder";
   let description = [{
-    A `get_tile_id` operation returns a scalar integer representing an SME
-    "virtual tile" id. The bitwidth of the scalar indicates the element
-    bitwidth of the "virtual tile".
+    A placeholder to preserve dataflow while lowering to SME intrinsics (which
+    do not take or return SME virtual tile values). This operation is intended
+    to be DCE'd once all ArmSME operations have been lowered.
 
-    The scope of a tile id is a function and cannot be passed or returned from
-    functions.
-
-    Example:
-    ```mlir
-    // Allocate and return an 8-bit element "virtual tile" id
-    %za0_b = arm_sme.get_tile_id : i8
-    ```
-
-    Example:
-    ```
-    // Allocate and return two 16-bit element "virtual tile" ids
-    %za0_h = arm_sme.get_tile_id : i16
-    %za1_h = arm_sme.get_tile_id : i16
-    ```
-
-    Example:
-    ```
-    // Allocate and return an 128-bit element "virtual tile" id
-    %za0_q = arm_sme.get_tile_id : i128
-    ```
+    This operation is not intended to be used outside of the ArmSME -> LLVM
+    conversion.
   }];
-
-  let results = (outs TileID:$tile_id);
-  let assemblyFormat = "attr-dict `:` type($tile_id)";
+  let results = (outs SMETile:$tile);
+  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
 //
 // Tile reset.
 //
 
-def ZeroOp : ArmSME_Op<"zero", [Pure]> {
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let summary = "Initialize the two-dimensional ZA array with 0s";
   let results = (outs SMETile:$res);
   let description = [{
@@ -303,11 +313,15 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getVectorType());
+    }
   }];
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
 def TileLoadOp : ArmSME_Op<"tile_load", [
+  ArmSMETileOpInterface,
   AttrSizedOperandSegments,
   OptionalTypesMatchWith<
     "padding type matches element type of result",
@@ -375,6 +389,9 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      return arm_sme::getSMETileType(getVectorType());
+    }
   }];
 
   let builders = [
@@ -394,6 +411,7 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
 }
 
 def TileStoreOp : ArmSME_Op<"tile_store", [
+  ArmSMETileOpInterface,
   AttrSizedOperandSegments,
   HasMatchingMaskTypeConstraint<"valueToStore", "mask">,
 ]> {
@@ -457,6 +475,7 @@ def TileStoreOp : ArmSME_Op<"tile_store", [
 }
 
 def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
+  ArmSMETileOpInterface,
   AllTypesMatch<["tile", "result"]>, TileSliceMaskConstraint<"result", "mask">
 ]> {
   let summary = "Tile slice load and update operation";
@@ -515,6 +534,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
 }
 
 def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
+  ArmSMETileOpInterface,
   TileSliceMaskConstraint<"tile", "mask">
 ]> {
   let summary = "Tile slice store operation";
@@ -570,6 +590,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
+    ArmSMETileOpInterface,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -617,7 +638,8 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
   }];
 }
 
-def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [Pure,
+def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
+    ArmSMETileOpInterface,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -670,7 +692,8 @@ class OuterProductResultTileTypeConstraint<string operand> :
     "}()">;
 
 def OuterProductOp :
-  ArmSME_Op<"outerproduct", [Pure,
+  ArmSME_Op<"outerproduct", [
+    ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
     HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
@@ -715,7 +738,7 @@ def OuterProductOp :
     ```
   }];
 
-  let arguments = (ins
+let arguments = (ins
     SVEVector:$lhs, SVEVector:$rhs,
     Optional<SVEPredicate>:$lhsMask,
     Optional<SVEPredicate>:$rhsMask,
@@ -736,6 +759,12 @@ def OuterProductOp :
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      // The outerproduct op allocates a new tile if no accumulator is passed.
+      if (!getAcc())
+        return arm_sme::getSMETileType(getResultType());
+      return std::nullopt;
+    }
   }];
 }
 

diff  --git a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
index 9319a8042c93f4a..9801d8b099e3f0a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt
@@ -15,6 +15,12 @@ set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
 mlir_tablegen(ArmSMEOpsConversions.inc -gen-llvmir-conversions)
 add_public_tablegen_target(MLIRArmSMEConversionsIncGen)
 
+# Generate op interface declarations and definitions
+set(LLVM_TARGET_DEFINITIONS ArmSMEOps.td)
+mlir_tablegen(ArmSMEOpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ArmSMEOpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRArmSMEOpInterfaces)
+
 # Generate declarations and definitions of ArmSME intrinsic Ops
 set(LLVM_TARGET_DEFINITIONS ArmSMEIntrinsicOps.td)
 mlir_tablegen(ArmSMEIntrinsicOps.h.inc -gen-op-decls)

diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..11a7385fe311dd3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,7 +29,7 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled);
 
-/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
+/// Pass that allocates tile IDs to ArmSME operations.
 std::unique_ptr<Pass> createTileAllocationPass();
 
 //===----------------------------------------------------------------------===//

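The tile allocation pass now assigns tile IDs directly to ArmSME ops. The sketch below shows one simplified way such an allocator could pick IDs. It is an assumption for illustration, not the pass's actual implementation. ZA is treated as eight 64-bit element tiles (ZA0.D..ZA7.D), each represented by one bit. A tile of a wider-than-8-bit... rather, of each element width overlaps a fixed set of those bits, using the same layout as the zero masks in `ZeroOpConversion` (for example, ZA3.S covers `0b00010001 << 3`). 128-bit (ZAQ) tiles are ignored, and the lowest-free-ID-first policy is assumed.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Simplified model (assumption): one bit per 64-bit tile ZA0.D..ZA7.D.
struct TileAllocatorModel {
  uint8_t usedMask = 0;

  // Mask of 64-bit tiles covered by tile `id` of the given element width.
  static std::optional<uint8_t> maskFor(unsigned elementBits, unsigned id) {
    switch (elementBits) {
    case 8:  if (id < 1) return uint8_t(0xFF); break;      // ZA0.B
    case 16: if (id < 2) return uint8_t(0x55 << id); break; // ZA0.H, ZA1.H
    case 32: if (id < 4) return uint8_t(0x11 << id); break; // ZA0.S..ZA3.S
    case 64: if (id < 8) return uint8_t(0x01 << id); break; // ZA0.D..ZA7.D
    }
    return std::nullopt;
  }

  // Greedily picks the lowest-numbered tile that does not overlap any
  // already-allocated tile. Returns nullopt if none is free.
  std::optional<unsigned> allocate(unsigned elementBits) {
    for (unsigned id = 0; id < 8; ++id) {
      auto mask = maskFor(elementBits, id);
      if (!mask)
        break; // no more tiles of this width
      if ((usedMask & *mask) == 0) {
        usedMask |= *mask;
        return id;
      }
    }
    return std::nullopt;
  }
};
```

For example, after allocating ZA0.S and ZA1.S, neither 16-bit tile is free, because ZA0.H and ZA1.H each overlap one of them.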
diff  --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 0941592497beaae..b7d90195d49d761 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -15,10 +15,11 @@
 #ifndef MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
-#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include <optional>
 
-namespace mlir {
-namespace arm_sme {
+namespace mlir::arm_sme {
 
 constexpr unsigned MinStreamingVectorLengthInBits = 128;
 
@@ -34,13 +35,13 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+/// Returns the type of SME tile this vector type corresponds to, or none if the
+/// vector type does not fit within an SME tile.
+std::optional<ArmSMETileType> getSMETileType(VectorType);
 
-} // namespace arm_sme
-} // namespace mlir
+/// Verifies the tile ID (if set) on this tile operation is valid.
+LogicalResult verifyOperationHasValidTileId(Operation *);
+
+} // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_

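`getSMETileType` maps a vector type to the SME tile it fits in, or to none. The following is a standalone C++ sketch of that mapping. It assumes that a vector fits a tile only when it is 2-D, both dimensions are scalable, and each dimension holds exactly one 128-bit minimum streaming vector of elements, as in `vector<[4]x[4]xi32>`. `VectorShapeModel` and `TileTypeModel` are hypothetical stand-ins for MLIR's `VectorType` and `ArmSMETileType`. `MinStreamingVectorLengthInBits` is taken from the header above.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for ArmSMETileType and the relevant parts of a
// 2-D VectorType.
enum class TileTypeModel { ZAB, ZAH, ZAS, ZAD, ZAQ };

struct VectorShapeModel {
  unsigned rank;
  unsigned dim0, dim1;     // minimum sizes (i.e. at vscale = 1)
  bool scalable0, scalable1;
  unsigned elementBits;
};

constexpr unsigned MinStreamingVectorLengthInBits = 128;

std::optional<TileTypeModel> smeTileType(const VectorShapeModel &v) {
  // Only 8/16/32/64/128-bit elements can live in an SME tile.
  if (v.elementBits != 8 && v.elementBits != 16 && v.elementBits != 32 &&
      v.elementBits != 64 && v.elementBits != 128)
    return std::nullopt;
  const unsigned minNumElts = MinStreamingVectorLengthInBits / v.elementBits;
  // Must be [minNumElts]x[minNumElts], with both dims scalable.
  if (v.rank != 2 || !v.scalable0 || !v.scalable1 || v.dim0 != minNumElts ||
      v.dim1 != minNumElts)
    return std::nullopt;
  switch (v.elementBits) {
  case 8:  return TileTypeModel::ZAB;
  case 16: return TileTypeModel::ZAH;
  case 32: return TileTypeModel::ZAS;
  case 64: return TileTypeModel::ZAD;
  default: return TileTypeModel::ZAQ; // 128-bit
  }
}
```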
diff  --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..a28b8ef7f7fceb3 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -32,24 +32,41 @@ using namespace mlir;
 
 namespace {
 
+IntegerAttr getTileIdOrError(arm_sme::ArmSMETileOpInterface op) {
+  auto tileId = op.getTileId();
+  if (!tileId)
+    op.emitOpError(
+        "expected tile ID to be allocated before conversion to LLVM");
+  return tileId;
+}
+
+struct GetTileConversion : public ConvertOpToLLVMPattern<arm_sme::GetTileOp> {
+  using ConvertOpToLLVMPattern<arm_sme::GetTileOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
+        getTile, getTile.getTileType());
+    return success();
+  }
+};
+
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
 ///  ```mlir
-///     %v = arm_sme.zero : vector<[4]x[4]xi32>
+///     %v = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xi32>
 ///  ```
 ///
 ///  AFTER:
 ///  ```mlir
-///     %tile_id = arm_sme.get_tile_id : i32
-///     %zero_mask = arith.shli %c17_i32, %tile_id : i32
-///     "arm_sme.intr.zero"(%zero_mask) : (i32) -> ()
-///     %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///     "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
+///     %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
 ///  ```
 ///
-///  The 'arm_sme.cast_tile_to_vector' (which models the return) and the
-///  'arith.shli' (which generates the mask) will be folded away after tile
-///  allocation and canonization.
+///  The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
+///  once all ArmSME ops have been converted to LLVM intrinsics.
 struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
   using ConvertOpToLLVMPattern<arm_sme::ZeroOp>::ConvertOpToLLVMPattern;
 
@@ -61,9 +78,9 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     unsigned tileElementWidth =
         zero.getVectorType().getElementType().getIntOrFloatBitWidth();
 
-    // Get Tile ID for the `zero` intrinsic.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
+    auto tileId = getTileIdOrError(zero);
+    if (!tileId)
+      return failure();
 
     // Get the base mask for tile based on the element size.
     // The base mask is just the mask to zero the first tile (of a size).
@@ -93,9 +110,6 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
         llvm_unreachable("bad element size");
       }
     }();
-    auto maskType = rewriter.getI32Type();
-    auto baseMask = rewriter.create<arith::ConstantOp>(
-        loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize));
 
     // The actual mask is just the base mask shifted by the tile ID.
     // This will be folded to a constant after tile allocation.
@@ -118,13 +132,13 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern<arm_sme::ZeroOp> {
     //  * Mask    -> 10001000 = (00010001 << 3)
     //
     // This holds for all tile sizes.
-    auto tileMask = rewriter.create<arith::ShLIOp>(
-        loc, baseMask, castTileIDToI32(tileId, loc, rewriter));
-    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tileMask);
+    int32_t zeroMask = baseMaskForSize << int32_t(tileId.getInt());
+    rewriter.create<arm_sme::aarch64_sme_zero>(
+        loc, rewriter.getI32IntegerAttr(zeroMask));
 
-    // Create `CastTileToVectorOp` to use as the output.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(zero, zero.getType(),
-                                                           tileId);
+    // Create a placeholder op to preserve dataflow.
+    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
+        zero, zero.getVectorType());
 
     return success();
   }
@@ -141,15 +155,9 @@ struct LoadTileSliceConversion
                   arm_sme::LoadTileSliceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = loadTileSliceOp.getLoc();
-    auto tileType = loadTileSliceOp.getVectorType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
-
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
-    // loaded to.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        loadTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(loadTileSliceOp);
+    if (!tileId)
+      return failure();
 
     Value ptr = this->getStridedElementPtr(loc, loadTileSliceOp.getMemRefType(),
                                            adaptor.getBase(),
@@ -164,7 +172,9 @@ struct LoadTileSliceConversion
     // Create all active predicate mask.
     auto maskOp = loadTileSliceOp.getMask();
 
-    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
+    auto tileType = loadTileSliceOp.getVectorType();
+    auto tileElementType = tileType.getElementType();
+    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
     arm_sme::TileSliceLayout layout = loadTileSliceOp.getLayout();
 
     // Create 'arm_sme.intr.ld1*.(horiz|vert)' intrinsic to load ZA tile slice.
@@ -174,23 +184,23 @@ struct LoadTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.create<arm_sme::aarch64_sme_ld1b_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 16:
         rewriter.create<arm_sme::aarch64_sme_ld1h_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 32:
         rewriter.create<arm_sme::aarch64_sme_ld1w_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 64:
         rewriter.create<arm_sme::aarch64_sme_ld1d_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       case 128:
         rewriter.create<arm_sme::aarch64_sme_ld1q_horiz>(loc, maskOp, ptr,
-                                                         tileI32, tileSliceI32);
+                                                         tileId, tileSliceI32);
         break;
       }
     } else {
@@ -199,31 +209,30 @@ struct LoadTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.create<arm_sme::aarch64_sme_ld1b_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 16:
         rewriter.create<arm_sme::aarch64_sme_ld1h_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 32:
         rewriter.create<arm_sme::aarch64_sme_ld1w_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 64:
         rewriter.create<arm_sme::aarch64_sme_ld1d_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       case 128:
         rewriter.create<arm_sme::aarch64_sme_ld1q_vert>(loc, maskOp, ptr,
-                                                        tileI32, tileSliceI32);
+                                                        tileId, tileSliceI32);
         break;
       }
     }
 
     // The load intrinsics have no result; replace 'arm_sme.load_tile_slice' with
-    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(loadTileSliceOp,
-                                                           tileType, tile);
+    // the input tile to preserve dataflow.
+    rewriter.replaceOp(loadTileSliceOp, loadTileSliceOp.getTile());
 
     return success();
   }
@@ -244,11 +253,9 @@ struct StoreTileSliceConversion
     auto tileElementType = tileType.getElementType();
     unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the vector
-    // being stored.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        storeTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(storeTileSliceOp);
+    if (!tileId)
+      return failure();
 
     // Create 'arm_sme.intr.st1*.horiz' intrinsic to store ZA tile slice.
     Value ptr = this->getStridedElementPtr(
@@ -263,7 +270,6 @@ struct StoreTileSliceConversion
 
     auto maskOp = storeTileSliceOp.getMask();
 
-    Value tileI32 = castTileIDToI32(tile, loc, rewriter);
     arm_sme::TileSliceLayout layout = storeTileSliceOp.getLayout();
 
     if (layout == arm_sme::TileSliceLayout::Horizontal) {
@@ -272,23 +278,23 @@ struct StoreTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_horiz>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       }
     } else {
@@ -297,23 +303,23 @@ struct StoreTileSliceConversion
         llvm_unreachable("unexpected element type!");
       case 8:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1b_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 16:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1h_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 32:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1w_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 64:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1d_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       case 128:
         rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_st1q_vert>(
-            storeTileSliceOp, maskOp, ptr, tileI32, tileSliceI32);
+            storeTileSliceOp, maskOp, ptr, tileId, tileSliceI32);
         break;
       }
     }
@@ -334,14 +340,10 @@ struct MoveVectorToTileSliceConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveVectorToTileSliceOp.getLoc();
     auto tileType = moveVectorToTileSliceOp.getTileType();
-    auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.cast_vector_to_tile' to get a tile ID for the tile being
-    // loaded to.
-    auto tile = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(tileElementWidth),
-        moveVectorToTileSliceOp.getTile());
+    auto tileId = getTileIdOrError(moveVectorToTileSliceOp);
+    if (!tileId)
+      return failure();
 
     auto tileSlice = moveVectorToTileSliceOp.getTileSliceIndex();
 
@@ -357,26 +359,24 @@ struct MoveVectorToTileSliceConversion
                                   /*scalableDims=*/{true});
     auto allActiveMask = rewriter.create<vector::SplatOp>(loc, predTy, one);
 
-    auto tileI32 = castTileIDToI32(tile, loc, rewriter);
-
     // Create 'arm_sme.intr.write.(horiz|vert)' to write vector to tile slice.
     switch (moveVectorToTileSliceOp.getLayout()) {
     case arm_sme::TileSliceLayout::Horizontal:
       rewriter.create<arm_sme::aarch64_sme_write_horiz>(
-          loc, tileI32, tileSliceI32, allActiveMask,
+          loc, tileId, tileSliceI32, allActiveMask,
           moveVectorToTileSliceOp.getVector());
       break;
     case arm_sme::TileSliceLayout::Vertical:
       rewriter.create<arm_sme::aarch64_sme_write_vert>(
-          loc, tileI32, tileSliceI32, allActiveMask,
+          loc, tileId, tileSliceI32, allActiveMask,
           moveVectorToTileSliceOp.getVector());
       break;
     }
 
     // Intrinsic has no result, replace 'arm_sme.move_vector_to_tile_slice' with
-    // 'arm_sme.cast_tile_to_vector' to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
-        moveVectorToTileSliceOp, tileType, tile);
+    // the input tile to preserve dataflow.
+    rewriter.replaceOp(moveVectorToTileSliceOp,
+                       moveVectorToTileSliceOp.getTile());
 
     return success();
   }
@@ -394,12 +394,11 @@ struct MoveTileSliceToVectorConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = moveTileSliceToVector.getLoc();
     auto sliceType = moveTileSliceToVector.getSliceType();
-    auto tile = moveTileSliceToVector.getTile();
     auto sliceIndex = moveTileSliceToVector.getTileSliceIndex();
 
-    // Cast tile to i32 tile ID.
-    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(loc, tile);
-    Value tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+    auto tileId = getTileIdOrError(moveTileSliceToVector);
+    if (!tileId)
+      return failure();
 
     // Create an 'all true' predicate for the tile slice.
     auto predicateType = sliceType.cloneWith({}, rewriter.getI1Type());
@@ -419,12 +418,12 @@ struct MoveTileSliceToVectorConversion
     case arm_sme::TileSliceLayout::Horizontal:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_horiz>(
           moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-          tileIdI32, sliceIndexI32);
+          tileId, sliceIndexI32);
       break;
     case arm_sme::TileSliceLayout::Vertical:
       rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_read_vert>(
           moveTileSliceToVector, sliceType, zeroVector, allTruePredicate,
-          tileIdI32, sliceIndexI32);
+          tileId, sliceIndexI32);
       break;
     }
 
@@ -441,8 +440,8 @@ struct MoveTileSliceToVectorConversion
 ///
 /// is converted to:
 ///
-///   "arm_sme.intr.mopa"(%tile_id, %ptrue_s, %ptrue_s, %lhs, %rhs)
-///     : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
+///   "arm_sme.intr.mopa"(%ptrue_s, %ptrue_s, %lhs, %rhs) <{tile_id = 0 : i32}>
+///     : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>,
 ///        vector<[4]xf32>) -> ()
 ///
 /// Currently only supports FMOPA and BFMOPA (non-widening).
@@ -454,6 +453,10 @@ struct OuterProductOpConversion
   matchAndRewrite(arm_sme::OuterProductOp outerProductOp,
                   arm_sme::OuterProductOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto tileId = getTileIdOrError(outerProductOp);
+    if (!tileId)
+      return failure();
+
     auto isSupportedType = [](VectorType vectorType) {
       // TODO: the FP outer product instruction variants are predicated on
       // different features [1]:
@@ -498,13 +501,8 @@ struct OuterProductOpConversion
     Value acc = outerProductOp.getAcc();
     if (!acc)
       // Initialize accumulator with zero.
-      acc = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
-
-    unsigned elementWidth = resultVectorType.getElementTypeBitWidth();
-    auto tileId = rewriter.create<arm_sme::CastVectorToTile>(
-        loc, rewriter.getIntegerType(elementWidth), acc);
-
-    auto tileI32 = castTileIDToI32(tileId, loc, rewriter);
+      acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+          rewriter, loc, resultVectorType);
 
     Value lhsMask = outerProductOp.getLhsMask();
     Value rhsMask = outerProductOp.getRhsMask();
@@ -519,13 +517,13 @@ struct OuterProductOpConversion
     }
 
     // Create 'arm_sme.intr.mopa' outer product intrinsic.
-    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileI32, lhsMask, rhsMask,
+    rewriter.create<arm_sme::aarch64_sme_mopa>(loc, tileId, lhsMask, rhsMask,
                                                outerProductOp.getLhs(),
                                                outerProductOp.getRhs());
 
-    // Create `CastTileToVectorOp` to use as the output.
-    rewriter.replaceOpWithNewOp<arm_sme::CastTileToVector>(
-        outerProductOp, resultVectorType, tileId);
+    // The outerproduct intrinsics have no result, replace
+    // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
+    rewriter.replaceOp(outerProductOp, acc);
 
     return success();
   }
@@ -557,21 +555,20 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
-      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
-      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
-      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
-      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
-      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
-      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
-      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
-      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
-      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
-      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
-      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
-      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
-      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
-      arm_sme::aarch64_sme_mopa>();
+      arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
+      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
+      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
+      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
+      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
+      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
+      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
+      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
+      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
+      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
+      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
+      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
+      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
+      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
 }
@@ -580,7 +577,8 @@ void mlir::populateArmSMEToLLVMConversionPatterns(
     ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
-               OuterProductOpConversion, ZeroOpConversion>(converter);
+               OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
+      converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {

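Because the tile ID is now an attribute, `ZeroOpConversion` can compute the `tile_mask` immediate at compile time instead of emitting an `arith.shli`. The following is a standalone C++ sketch of that arithmetic. Each mask bit names one 64-bit tile (ZA0.D..ZA7.D). The 32-bit base mask and shift match the comments in the patch (`ZA3.S -> 0b10001000 == 0b00010001 << 3`). The 8-, 16- and 64-bit base masks are assumed from the SME ZERO instruction's tile layout rather than copied from the hunk. `zeroMaskFor` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical helper: the ZERO mask for tile `tileId` of `elementBits`.
int32_t zeroMaskFor(unsigned elementBits, int32_t tileId) {
  int32_t baseMaskForSize;
  switch (elementBits) {
  case 8:  baseMaskForSize = 0b1111'1111; break; // ZA0.B covers every ZAn.D
  case 16: baseMaskForSize = 0b0101'0101; break; // ZA0.H = ZA0/2/4/6.D
  case 32: baseMaskForSize = 0b0001'0001; break; // ZA0.S = ZA0.D + ZA4.D
  case 64: baseMaskForSize = 0b0000'0001; break; // ZAn.D is a single bit
  default: throw std::invalid_argument("bad element size");
  }
  // The base mask selects tile 0 of this size; shifting by the tile ID
  // selects tile N.
  return baseMaskForSize << tileId;
}
```

For example, `zeroMaskFor(32, 0)` is 17, which matches the `tile_mask = 17` in the `arm_sme.zero {tile_id = 0 : i32}` example in the `ZeroOpConversion` doc comment.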
diff  --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
index 9914f39e17a1a91..d0a921296668d31 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt
@@ -10,7 +10,6 @@ add_mlir_conversion_library(MLIRArmSMEToLLVM
   LINK_LIBS PUBLIC
   MLIRArmSMETransforms
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRTransforms
   MLIRLLVMCommonConversion
   MLIRLLVMDialect)

diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index a4dfd4b7edc77c5..69c68663070b6d5 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -61,8 +61,7 @@ void getMemrefIndices(ValueRange indices, unsigned rank, Value tileSliceIndex,
 ///  AFTER:
 ///  ```mlir
 ///  %ptrue_s = arith.constant dense<true> : vector<[4]xi1>
-///  %tile_id = arm_sme.get_tile_id : i32
-///  %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+///  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
 ///  %vscale = vector.vscale
 ///  %c0 = arith.constant 0 : index
 ///  %c1 = arith.constant 1 : index
@@ -87,16 +86,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
     auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.load_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+    // Allocate a new SME tile.
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+        rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -128,8 +121,8 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      numTileSlices, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
+    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+        rewriter, loc, tileType, tileLoadOp.getBase(), allTruePredicate, tile,
         memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -209,7 +202,8 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     // Initialize tile with zero to satisfy padding. Inactive cols will be
     // zeroed anyway since the loads use zeroing predication. For inactive rows
     // however, no load will occur so these need to be zeroed.
-    auto tile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
+        rewriter, loc, tileType);
 
     // Create a loop to load the active tile slices from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -226,9 +220,9 @@ struct TileLoadOpWithMaskAndPadZeroConversion
     getMemrefIndices(tileLoadOp.getIndices(),
                      tileLoadOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    rewriter.create<arm_sme::LoadTileSliceOp>(
-        loc, tileType, tileLoadOp.getBase(), numColsOp, tile, memrefIndices,
-        tileSliceIndex, tileLoadOp.getLayout());
+    tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
+        rewriter, loc, tileType, tileLoadOp.getBase(), numColsOp, tile,
+        memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
 
@@ -276,7 +270,6 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto loc = tileLoadOp.getLoc();
     auto tileType = tileLoadOp.getVectorType();
     auto tileElementType = tileType.getElementType();
-    unsigned tileElementWidth = tileElementType.getIntOrFloatBitWidth();
 
     auto maskOp = tileLoadOp.getMask();
     if (!maskOp)
@@ -304,14 +297,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), numCols);
 
-    // Create 'arm_sme.get_tile' op.
-    auto tileId = rewriter.create<arm_sme::GetTileID>(
-        loc, rewriter.getIntegerType(tileElementWidth));
-
-    // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type to
-    // use as input tile to 'arm_sme.load_tile_slice' ops.
-    auto tile =
-        rewriter.create<arm_sme::CastTileToVector>(loc, tileType, tileId);
+    // Allocate a new SME tile.
+    auto tile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
+        rewriter, loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -356,8 +344,8 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
-        loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
+    tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
+        rewriter, loc, tileType, loadSlice->getResult(0), tile, tileSliceIndex,
         tileLoadOp.getLayout());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -450,8 +438,9 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
     getMemrefIndices(tileStoreOp.getIndices(),
                      tileStoreOp.getMemRefType().getRank(), tileSliceIndex,
                      upperBound, memrefIndices, loc, rewriter);
-    rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
-        tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
+
+    tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
+        rewriter, tileStoreOp.getValueToStore(), tileSliceIndex, maskCols,
         tileStoreOp.getBase(), memrefIndices, tileStoreOp.getLayout());
 
     return success();
@@ -506,6 +495,9 @@ struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
       rewriter.setInsertionPointToStart(forOp.getBody());
       // Extract the current row from the tile.
       Value rowIndex = forOp.getInductionVar();
+      // FIXME: Forward tile IDs.
+      // For now, if you vector.print an SME tile you need to run
+      // -allocate-arm-sme-tiles after -convert-arm-sme-to-scf.
       auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
           loc, printOp.getSource(), rowIndex);
       // Print the row with a 1D vector.print.

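In the `ArmSMEToSCF` patterns above, `createOpAndForwardTileId` and `replaceWithAndForwardTileId` let the ops produced by decomposing an already-allocated op (for example, `arm_sme.tile_load` becoming `arm_sme.get_tile` plus a loop of `arm_sme.load_tile_slice`) inherit its tile ID. This is what allows allocation to run before `-convert-arm-sme-to-scf`. The following is a toy C++ model of that idea only. `OpModel` and `createAndForwardTileId` are hypothetical and are not the MLIR API. The `vector.print` lowering is the exception: it creates `arm_sme.move_tile_slice_to_vector` without forwarding, hence the FIXME above.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model of an ArmSME op that may carry a `tile_id` attribute.
struct OpModel {
  std::string name;
  std::optional<int> tileId; // unset until tile allocation has run
};

// Appends a new op to `block` that copies `source`'s tile ID (if any).
// The returned reference is valid only until the next push_back.
const OpModel &createAndForwardTileId(std::vector<OpModel> &block,
                                      const OpModel &source,
                                      std::string name) {
  block.push_back(OpModel{std::move(name), source.tileId});
  return block.back();
}
```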
diff  --git a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
index 3bf4d7082afe422..467c3b942e754f4 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
+++ b/mlir/lib/Conversion/ArmSMEToSCF/CMakeLists.txt
@@ -9,6 +9,5 @@ add_mlir_conversion_library(MLIRArmSMEToSCF
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRTransforms
   )

diff  --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index 715816a90128cdb..b062f65e914e8b9 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,6 +10,5 @@ add_mlir_conversion_library(MLIRVectorToArmSME
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRLLVMCommonConversion
   )

diff  --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index f1e92d1ac9708b7..3016c7b0a84772d 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -44,20 +44,6 @@ static scf::ForOp getLoopOverTileSlices(PatternRewriter &rewriter, Location loc,
   return forOp;
 }
 
-/// Returns a tile of the given vector type.
-static arm_sme::CastTileToVector
-getSMETileAndCastToVector(PatternRewriter &rewriter, Location loc,
-                          VectorType type) {
-  unsigned tileElementWidth = type.getElementType().getIntOrFloatBitWidth();
-
-  // Create 'arm_sme.get_tile' op.
-  auto tileId = rewriter.create<arm_sme::GetTileID>(
-      loc, rewriter.getIntegerType(tileElementWidth));
-
-  // Create `arm_sme.cast_tile_to_vector` to cast tile ID to a vector type.
-  return rewriter.create<arm_sme::CastTileToVector>(loc, type, tileId);
-}
-
 namespace {
 
 /// Conversion pattern for vector.transfer_read.
@@ -267,8 +253,7 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
         tileSliceType, denseAttr.getSplatValue<Attribute>());
     auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
     auto tileSliceIndex = forOp.getInductionVar();
@@ -330,8 +315,7 @@ struct BroadcastOpToArmSMELowering
     else
       return failure();
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop over ZA tile slices.
     auto forOp = getLoopOverTileSlices(rewriter, loc, tileElementType);
@@ -387,8 +371,7 @@ struct SplatOpToArmSMELowering : public OpRewritePattern<vector::SplatOp> {
     Value broadcastOp1D = rewriter.create<vector::BroadcastOp>(
         loc, tileSliceType, splatOp.getInput());
 
-    arm_sme::CastTileToVector tile =
-        getSMETileAndCastToVector(rewriter, loc, tileType);
+    auto tile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Next, create a loop over ZA tile slices and "move" the generated 1-d
     // vector to each slice.

diff  --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
index 31167e6af908b9f..9f57627c321fb0c 100644
--- a/mlir/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt
@@ -1,3 +1,2 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
-add_subdirectory(Utils)

diff  --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 9df15420b9c9b6f..29fa9085a0a963d 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -28,6 +28,8 @@ using namespace mlir::arm_sme;
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.cpp.inc"
 
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.cpp.inc"
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOps.cpp.inc"
 
@@ -54,23 +56,3 @@ void ArmSMEDialect::initialize() {
 #include "mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.cpp.inc"
       >();
 }
-
-// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id
-LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
-                                             PatternRewriter &rewriter) {
-  if (auto castTileToVectorOp = op.getVector().getDefiningOp<CastTileToVector>()) {
-    op.replaceAllUsesWith(castTileToVectorOp.getTileId());
-    return success();
-  }
-  return failure();
-}
-
-// cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile
-LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
-                                             PatternRewriter &rewriter) {
-  if (auto castVectorToTileOp = op.getTileId().getDefiningOp<CastVectorToTile>()) {
-    op.replaceAllUsesWith(castVectorToTileOp.getVector());
-    return success();
-  }
-  return failure();
-}

diff  --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
index 3e448ec4fb1e04d..ce0c2f4986fc661 100644
--- a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRArmSMEDialect
   ArmSME.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME

diff  --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
similarity index 51%
rename from mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
rename to mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index f17077ff8565d59..6105cd622528303 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -11,25 +11,22 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 
-using namespace mlir;
-using namespace mlir::arm_sme;
+namespace mlir::arm_sme {
 
-unsigned mlir::arm_sme::getSMETileSliceMinNumElts(Type type) {
+unsigned getSMETileSliceMinNumElts(Type type) {
   assert(isValidSMETileElementType(type) && "invalid tile type!");
   return MinStreamingVectorLengthInBits / type.getIntOrFloatBitWidth();
 }
 
-bool mlir::arm_sme::isValidSMETileElementType(Type type) {
+bool isValidSMETileElementType(Type type) {
   return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
          type.isInteger(64) || type.isInteger(128) || type.isF16() ||
          type.isBF16() || type.isF32() || type.isF64() || type.isF128();
 }
 
-bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
+bool isValidSMETileVectorType(VectorType vType) {
   if ((vType.getRank() != 2) || !vType.allDimsScalable())
     return false;
 
@@ -37,22 +34,43 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
   if (!isValidSMETileElementType(elemType))
     return false;
 
-  unsigned minNumElts = arm_sme::getSMETileSliceMinNumElts(elemType);
+  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
   if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
     return false;
 
   return true;
 }
 
-Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
-                                     RewriterBase &rewriter) {
-  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
-             tile.getDefiningOp())) &&
-         "expected ArmSME GetTileID or CastVectorToTile op!");
-  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
-  if (tileElementWidth < 32)
-    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
-  if (tileElementWidth > 32)
-    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
-  return tile;
+std::optional<ArmSMETileType> getSMETileType(VectorType type) {
+  if (!isValidSMETileVectorType(type))
+    return {};
+  switch (type.getElementTypeBitWidth()) {
+  case 8:
+    return ArmSMETileType::ZAB;
+  case 16:
+    return ArmSMETileType::ZAH;
+  case 32:
+    return ArmSMETileType::ZAS;
+  case 64:
+    return ArmSMETileType::ZAD;
+  case 128:
+    return ArmSMETileType::ZAQ;
+  default:
+    llvm_unreachable("unknown SME tile type");
+  }
+}
+
+LogicalResult verifyOperationHasValidTileId(Operation *op) {
+  auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
+  if (!tileOp)
+    return success(); // Not a tile op (no need to check).
+  auto tileId = tileOp.getTileId();
+  if (!tileId)
+    return success(); // Not having a tile ID (yet) is okay.
+  if (!tileId.getType().isSignlessInteger(32))
+    return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
+  // TODO: Verify value of tile ID is in range.
+  return success();
 }
+
+} // namespace mlir::arm_sme
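
To make the new `getSMETileType` helper in Utils.cpp easier to follow outside MLIR, here is its logic, together with the shape check in `isValidSMETileVectorType`, as plain C++. This is a minimal sketch, not the MLIR code. `TileKind` and `getTileKind` are hypothetical names, and the vector type is reduced to an element bit width, two dimensions and an "all dims scalable" flag. `kMinSVLBits = 128` is an assumed value for `MinStreamingVectorLengthInBits`; it is consistent with the shapes used in the tests, such as `vector<[4]x[4]xi32>` and `vector<[16]x[16]xi8>`. The real helper only switches on the width after `isValidSMETileVectorType` has passed, so an unknown width there is `llvm_unreachable`. The sketch instead just returns `std::nullopt`.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for ArmSMETileType (not the MLIR enum).
enum class TileKind { ZAB, ZAH, ZAS, ZAD, ZAQ };

// Assumed value of MinStreamingVectorLengthInBits.
constexpr unsigned kMinSVLBits = 128;

// Sketch of isValidSMETileVectorType + getSMETileType for a rank-2 vector
// <[rows]x[cols]> whose elements are `elemBits` wide.
std::optional<TileKind> getTileKind(unsigned elemBits, unsigned rows,
                                    unsigned cols, bool allDimsScalable) {
  // Both dimensions must be scalable.
  if (!allDimsScalable)
    return std::nullopt;

  // The element width selects the tile kind (floats share these widths).
  TileKind kind;
  switch (elemBits) {
  case 8:   kind = TileKind::ZAB; break;
  case 16:  kind = TileKind::ZAH; break;
  case 32:  kind = TileKind::ZAS; break;
  case 64:  kind = TileKind::ZAD; break;
  case 128: kind = TileKind::ZAQ; break;
  default:  return std::nullopt; // Not a valid SME tile element width.
  }

  // A tile is [N]x[N] with N = minimum streaming vector length / width.
  unsigned minNumElts = kMinSVLBits / elemBits;
  if (rows != minNumElts || cols != minNumElts)
    return std::nullopt;
  return kind;
}
```

For example, `getTileKind(32, 4, 4, true)` corresponds to `vector<[4]x[4]xi32>` and yields `TileKind::ZAS`, while `vector<[8]x[8]xi32>` is rejected because 128 / 32 = 4.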

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..7b6b2e77dcebfe0 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
-  MLIRArmSMEUtils
   MLIRFuncDialect
   MLIRLLVMCommonConversion
   MLIRVectorDialect

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e0462a6dc124131..597846e31e218ef 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,10 +6,9 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass allocates SME tiles at the 'func.func' op level for
-// 'arm_sme.get_tile_id' ops. It does this using a 16-bit tile mask that has a
-// bit for each 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile
-// granule.
+// This pass allocates SME tiles at the 'func.func' op level for ArmSME
+// operations. It does this using a 16-bit tile mask that has a bit for each
+// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
 //
 // The 128-bit tiles overlap with other element tiles as follows (see section
 // B2.3.2 of SME spec [1]):
@@ -34,8 +33,8 @@
 //   ZA7.D   ZA7.Q, ZA15.Q
 //
 // The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initalized during the first 'arm_sme.get_tile_id' rewrite and
-// updated on each subsequent rewrite.
+// that is initialized during the first tile allocation within a function and
+// updated on each subsequent allocation.
 //
 // [1] https://developer.arm.com/documentation/ddi0616/aa
 //
@@ -43,8 +42,10 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 #define DEBUG_TYPE "allocate-arm-sme-tiles"
 
@@ -107,73 +108,154 @@ enum class TileMask : unsigned {
 };
 
 /// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
-  static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
-  static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
-                                                   TileMask::kZA1H};
-  static const SmallVector<TileMask> ZA_S_MASKS = {
-      TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
-  static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+                                            TileMask::kZA2S, TileMask::kZA3S};
+  static constexpr std::array ZA_D_MASKS = {
       TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
       TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
-  static const SmallVector<TileMask> ZA_Q_MASKS = {
+  static constexpr std::array ZA_Q_MASKS = {
       TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
       TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
       TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
       TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
-  switch (cast<IntegerType>(type).getWidth()) {
-  default:
-    llvm_unreachable("unexpected type!");
-  case 8:
+  switch (type) {
+  case ArmSMETileType::ZAB:
     return ZA_B_MASKS;
-  case 16:
+  case ArmSMETileType::ZAH:
     return ZA_H_MASKS;
-  case 32:
+  case ArmSMETileType::ZAS:
     return ZA_S_MASKS;
-  case 64:
+  case ArmSMETileType::ZAD:
     return ZA_D_MASKS;
-  case 128:
+  case ArmSMETileType::ZAQ:
     return ZA_Q_MASKS;
   }
 }
 
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
-                             unsigned &tileID) {
-  auto masks = getMasks(tileIDOp.getType());
-  for (const auto &it : llvm::enumerate(masks)) {
-    const auto tileMask = it.value();
+/// Allocates and returns a tile ID. Returns an error if there are no tiles
+/// left.
+static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
+                                          TileMask &tilesInUse) {
+  auto masks = getMasks(tileType);
+  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
     if ((tilesInUse & tileMask) == TileMask::kNone) {
       tilesInUse |= tileMask;
-      tileID = it.index();
-      return success();
+      return tileId;
     }
   }
-  return tileIDOp.emitError("ran out of SME virtual tiles!");
+  return failure();
 }
 
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+/// Collects transitive uses of a root value through control flow. This can
+/// handle basic SCF constructs, along with control flow (br and cond_br).
+/// Simple loops work at the SCF level, while more complex control flow can be
+/// dealt with after lowering to CF. This is used to implement basic tile
+/// allocation.
+static void findDependantOps(Value rootValue,
+                             SetVector<Operation *> &dependantOps) {
+  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
+    for (auto [idx, value] : llvm::enumerate(inputValues)) {
+      if (value == rootValue)
+        findDependantOps(exitValues[idx], dependantOps);
+    }
+  };
+  for (Operation *user : rootValue.getUsers()) {
+    if (dependantOps.contains(user))
+      continue;
+    dependantOps.insert(user);
+    TypeSwitch<Operation *>(user)
+        .Case<cf::BranchOp>([&](auto branchOp) {
+          // (CF) Follow branch.
+          traverseCorrespondingValues(branchOp.getDestOperands(),
+                                      branchOp.getDest()->getArguments());
+        })
+        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
+          // (CF) Follow true branch.
+          traverseCorrespondingValues(
+              condBranchOp.getTrueOperands(),
+              condBranchOp.getTrueDest()->getArguments());
+          // (CF) Follow false branch.
+          traverseCorrespondingValues(
+              condBranchOp.getFalseOperands(),
+              condBranchOp.getFalseDest()->getArguments());
+        })
+        .Case<LoopLikeOpInterface>([&](auto loopOp) {
+          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
+          traverseCorrespondingValues(loopOp.getInits(),
+                                      loopOp.getRegionIterArgs());
+        })
+        .Case<scf::YieldOp>([&](auto yieldOp) {
+          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
+          auto parent = user->getParentOp();
+          traverseCorrespondingValues(user->getOperands(),
+                                      parent->getResults());
+        })
+        .Default([&](auto) {
+          // Otherwise, assume users of _any_ result are dependant.
+          for (Value result : user->getResults())
+            findDependantOps(result, dependantOps);
+        });
+  }
+}
+
+struct AssignTileIDsPattern
+    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
                                 PatternRewriter &rewriter) const override {
-    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
-    TileMask tilesInUse;
-    if (auto tilesInUseAttr =
-            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+    if (tileOp.getTileId())
+      return failure();
+
+    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+    if (!tileType)
+      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
+    TileMask tilesInUse = TileMask::kNone;
+    if (auto tilesInUseAttr = llvm::dyn_cast_or_null<IntegerAttr>(
+            func->getDiscardableAttr(kTilesInUseAttr)))
       tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
-    else
-      tilesInUse = TileMask::kNone;
 
-    unsigned tileID;
-    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
-      return failure();
+    auto tileId = allocateTileId(*tileType, tilesInUse);
+    if (failed(tileId))
+      return tileOp.emitError("ran out of SME virtual tiles!");
+
+    func->setDiscardableAttr(kTilesInUseAttr,
+                             rewriter.getI32IntegerAttr((unsigned)tilesInUse));
 
-    funcOp->setAttr(kTilesInUseAttr,
-                    rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+    // Find all the ops that (transitively) depend on this tile.
+    SetVector<Operation *> dependantOps;
+    findDependantOps(tileOp->getResult(0), dependantOps);
+
+    // Set all operations dependent on `tileOp` to use the same tile ID.
+    // This is a naive tile allocation scheme, but works for common cases. For
+    // example, as this only allocates tile IDs to existing ops, it can't solve
+    // cases like this (%tileA and %tileB come from different root operations):
+    //
+    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
+    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
+    // } else {
+    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
+    // }
+    //
+    // This case would require allocating a new tile for the result of the
+    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
+    // on the %some_cond).
+    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
+    tileOp.setTileId(tileIDAttr);
+    for (auto *op : dependantOps) {
+      if (auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+        auto currentTileId = tileOp.getTileId();
+        if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
+          return tileOp.emitOpError(
+              "already assigned different SME virtual tile!");
+        tileOp.setTileId(tileIDAttr);
+      }
+    }
 
-    auto tileType = tileIDOp.getType();
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-        tileIDOp, tileType, rewriter.getIntegerAttr(tileType, tileID));
     return success();
   }
 };
@@ -182,13 +264,15 @@ struct TileAllocationPass
     : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
-    ConversionTarget target(getContext());
-    patterns.add<GetTileIDConversion>(patterns.getContext());
-    target.addLegalOp<arith::ConstantOp>();
-    target.addIllegalOp<GetTileID>();
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    patterns.add<AssignTileIDsPattern>(patterns.getContext());
+    GreedyRewriteConfig config;
+    // Setting useTopDownTraversal ensures tiles are allocated in program
+    // order.
+    config.useTopDownTraversal = true;
+    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
+            getOperation(), std::move(patterns), config))) {
       signalPassFailure();
+    }
   }
 };
 } // namespace
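
The TileAllocation.cpp header comment describes a 16-bit mask with one bit per 128-bit ZA granule. Together with the first-fit search in `allocateTileId`, this can be modelled in a few lines of standalone C++. This is a minimal sketch under stated assumptions. `TileKind`, `numTiles` and `tileMask` are hypothetical names, and bit k of the mask is assumed to stand for ZAk.Q. The masks are computed from the overlap rule implied by the comment's table rather than copied from the pass's `TileMask` enum: tile ZAn of a kind with T tiles covers every ZAk.Q with k % T == n, so ZA7.D covers ZA7.Q and ZA15.Q.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for ArmSMETileType; each enumerator's value is the
// number of tiles of that kind in ZA (1 x B, 2 x H, 4 x S, 8 x D, 16 x Q).
enum class TileKind : unsigned { ZAB = 1, ZAH = 2, ZAS = 4, ZAD = 8, ZAQ = 16 };

constexpr unsigned numTiles(TileKind kind) {
  return static_cast<unsigned>(kind);
}

// Mask of the 128-bit granules covered by tile `id` of `kind`, with bit k
// standing for ZAk.Q (assumed encoding). ZAn covers every ZAk.Q with
// k % numTiles(kind) == n.
std::uint16_t tileMask(TileKind kind, unsigned id) {
  assert(id < numTiles(kind) && "tile ID out of range");
  std::uint16_t mask = 0;
  for (unsigned k = id; k < 16; k += numTiles(kind))
    mask |= static_cast<std::uint16_t>(1u << k);
  return mask;
}

// First-fit allocation in the spirit of allocateTileId: return the lowest
// tile ID whose granules are all free and mark them as used, or
// std::nullopt if every tile of this kind overlaps one already in use.
std::optional<unsigned> allocateTileId(TileKind kind,
                                       std::uint16_t &tilesInUse) {
  for (unsigned id = 0; id < numTiles(kind); ++id) {
    std::uint16_t mask = tileMask(kind, id);
    if ((tilesInUse & mask) == 0) {
      tilesInUse |= mask;
      return id;
    }
  }
  return std::nullopt; // The pass emits "ran out of SME virtual tiles!" here.
}
```

Allocating a ZA.S tile and then a ZA.D tile yields ZA0.S and then ZA1.D, because ZA0.D would overlap ZA0.S on ZA0.Q and ZA8.Q. After that, no ZA.B or ZA.H tile is free, which is the case where the pass reports "ran out of SME virtual tiles!". Because the search is first-fit, the order in which ops are visited determines the IDs handed out, which is why the pass sets `useTopDownTraversal` so that tiles are allocated in program order.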

diff  --git a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
deleted file mode 100644
index da8517aaf80a9fa..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Utils/CMakeLists.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-add_mlir_dialect_library(MLIRArmSMEUtils
-  Utils.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Utils
-
-  LINK_LIBS PUBLIC
-  MLIRArmSMEDialect
-  MLIRDialect
-  MLIRIR
-  )
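
The propagation step in `AssignTileIDsPattern` relies on `findDependantOps` (in the TileAllocation.cpp diff) to collect every op that transitively consumes a tile. For `cf.br`, `cf.cond_br`, loop inits and `scf.yield`, it follows only the value the tile is forwarded into; for every other op it follows all results. The sketch below models that traversal on a hypothetical toy IR. `ToyOp`, `ToyIR` and the integer value IDs are illustrative, not MLIR types. A non-empty `forwardTo` stands in for the branch/loop/yield cases of the `TypeSwitch`, and everything else takes the `Default` path.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Hypothetical toy IR: values are integer IDs and ops are stored by index.
struct ToyOp {
  std::vector<int> operands;
  std::vector<int> results;
  // Empty for ordinary ops. For forwarding ops (branches, loop inits,
  // yields) it has the same length as `operands`: forwardTo[i] is the value
  // operand i flows into (block argument, iter_arg or parent result), or -1.
  std::vector<int> forwardTo;
};

struct ToyIR {
  std::vector<ToyOp> ops;
  std::map<int, std::vector<int>> users; // value ID -> indices of user ops

  void addOp(ToyOp op) {
    int index = static_cast<int>(ops.size());
    for (int value : op.operands)
      users[value].push_back(index);
    ops.push_back(std::move(op));
  }
};

// Collect (into `dependantOps`) every op that transitively uses `rootValue`.
void findDependantOps(const ToyIR &ir, int rootValue,
                      std::set<int> &dependantOps) {
  auto it = ir.users.find(rootValue);
  if (it == ir.users.end())
    return;
  for (int opIndex : it->second) {
    // Visit each user op at most once.
    if (!dependantOps.insert(opIndex).second)
      continue;
    const ToyOp &op = ir.ops[opIndex];
    if (!op.forwardTo.empty()) {
      // Forwarding op: follow only where `rootValue` itself flows.
      for (std::size_t i = 0; i < op.operands.size(); ++i)
        if (op.operands[i] == rootValue && op.forwardTo[i] >= 0)
          findDependantOps(ir, op.forwardTo[i], dependantOps);
    } else {
      // Default: users of any result are dependant.
      for (int result : op.results)
        findDependantOps(ir, result, dependantOps);
    }
  }
}
```

Consider a loop that threads a tile through an `iter_arg`, traversed from the tile's defining result. The traversal reaches the loop, the in-loop update, the yield and the use of the loop result, but not unrelated ops. The defining op itself is not in the set; this matches the pass, which assigns the root op's ID separately via `tileOp.setTileId`.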

diff  --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 38a2332ffd5e0a3..fc28645a7acf7c0 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -6,8 +6,7 @@
 
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor(
 // CHECK-SAME:                                   %[[SRC:.*]]: memref<?x?xi32>) {
-// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -16,7 +15,7 @@
 // CHECK-NEXT:    scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
 // CHECK-NEXT:      %[[PTRUE_S:.*]] = arith.constant dense<true> : vector<[4]xi1>
 // CHECK-NEXT:      %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
-// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+// CHECK-NEXT:      arm_sme.load_tile_slice %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[PTRUE_S]], %[[TILE]], %[[TILE_SLICE_INDEX]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -60,8 +59,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-LABEL: func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(
 // CHECK-SAME:                                                             %[[SRC:.*]]: memref<?x?xi32>,
 // CHECK-SAME:                                                             %[[PAD:.*]]: i32) {
-// CHECK-DAG:     %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG:     %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK-DAG:     %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
@@ -79,7 +77,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
 // CHECK-NEXT:        %[[OFFSET:.*]] = arith.addi %[[C0]], %[[TILE_SLICE_INDEX]] : index
 // CHECK:             %[[PAD_1D:.*]] = vector.splat %[[PAD]] : vector<[4]xi32>
 // CHECK:             %[[LOAD_SLICE:.*]] = vector.maskedload %[[SRC]]{{\[}}%[[OFFSET]], %[[C0]]], %[[MASK_1D]], %[[PAD_1D]] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]xi32> into vector<[4]xi32>
-// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[CAST_TILE_TO_VECTOR]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
+// CHECK:             arm_sme.move_vector_to_tile_slice %[[LOAD_SLICE]], %[[TILE]], %[[TILE_SLICE_INDEX]] : vector<[4]xi32> into vector<[4]x[4]xi32>
 func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32>, %pad : i32) {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index

diff  --git a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
index 12f7e7333ebc973..b8db105f9c601b4 100644
--- a/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/arith-ops-to-sme.mlir
@@ -95,8 +95,7 @@ func.func @arith_constant_dense_2d_zero_f64() {
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C16:.*]] = arith.constant 16 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i8 to vector<[16]x[16]xi8>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[16]x[16]xi8>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C16]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -115,8 +114,7 @@ func.func @arith_constant_dense_2d_nonzero_i8() {
 // CHECK: %[[C1:.*]] = arith.constant 1 : index
 // CHECK: %[[C2:.*]] = arith.constant 2 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[GET_TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[GET_TILE_ID]] : i64 to vector<[2]x[2]xf64>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[2]x[2]xf64>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C2]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {

diff  --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
deleted file mode 100644
index 65996e81c42d909..000000000000000
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir
+++ /dev/null
@@ -1,51 +0,0 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s
-
-// This test verifies the temporary casts that are emitted when lowering to
-// intrinsics to preserve data flow are correct. Canonicalization will remove
-// these.
-
-// CHECK-LABEL: @arm_sme_zero
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: arm_sme.intr.zero
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_zero(%dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  %tile = arm_sme.zero : vector<[16]x[16]xi8>
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_load
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.ld1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-// CHECK: }
-// CHECK: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
-func.func @arm_sme_tile_load(%dest : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
-  %c0 = arith.constant 0 : index
-  %tile = arm_sme.tile_load %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return %tile : vector<[16]x[16]xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @arm_sme_tile_store(
-// CHECK-SAME:                      %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK: scf.for
-// CHECK:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
-// CHECK:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK:   "arm_sme.intr.st1b.horiz"({{.*}}, {{.*}}, %[[TILE_ID_I32]], {{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_tile_store(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) {
-  %c0 = arith.constant 0 : index
-  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
-  return
-}

diff  --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
index fa62332bc3f5b17..bd88da37bdf966d 100644
--- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
 
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
@@ -9,23 +9,21 @@
 // CHECK-LABEL:   func.func @arm_sme_load_tile_slice_hor_i8(
 // CHECK-SAME:                                              %[[SRC:.*]]: memref<?x?xi8>,
 // CHECK-SAME:                                              %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME:                                              %[[TILE:.*]]: vector<[16]x[16]xi8>,
-// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index) {
+// CHECK-SAME:                                              %[[TILE_SLICE_INDEX:.*]]: index)
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.ld1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -33,9 +31,10 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -43,9 +42,10 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -53,9 +53,10 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -63,9 +64,10 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -73,9 +75,10 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -83,9 +86,10 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -93,9 +97,10 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -103,9 +108,10 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -113,9 +119,10 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile : vector<[16]x[16]xi8>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<[16]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -123,9 +130,10 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xi16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -133,9 +141,10 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -143,9 +152,10 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xi64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -153,9 +163,10 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile : vector<[1]x[1]xi128>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vector<[1]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -163,9 +174,10 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -173,9 +185,10 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vector<[8]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -183,9 +196,10 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile : vector<[4]x[4]xf32>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vector<[4]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -193,9 +207,10 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
 // -----
 
 // CHECK-LABEL: @arm_sme_load_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile : vector<[2]x[2]xf64>, %tile_slice_index : index) {
+// CHECK: "arm_sme.intr.ld1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vector<[2]xi1>, %tile_slice_index : index) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -207,25 +222,23 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
 // -----
 
 // CHECK-LABEL:   func.func @arm_sme_store_tile_slice_hor_i8(
-// CHECK-SAME:                                               %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                                               %[[TILE_SLICE_INDEX:.*]]: index,
 // CHECK-SAME:                                               %[[MASK:.*]]: vector<[16]xi1>,
-// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>) {
+// CHECK-SAME:                                               %[[DEST:.*]]: memref<?x?xi8>)
 // CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[DEST]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK:           %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK:           %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[STRIDE:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK:           %[[OFFSET:.*]] = llvm.mul %[[C0_I64]], %[[STRIDE]]  : i64
 // CHECK:           %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFFSET]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK:           %[[TILE_SLICE_INDEX_I32:.*]] = arith.index_castui %[[TILE_SLICE_INDEX]] : index to i32
-// CHECK:           %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_INDEX_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK:           "arm_sme.intr.st1b.horiz"(%[[MASK]], %[[GEP]], %[[TILE_SLICE_INDEX_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK:           return
 // CHECK:         }
-func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_store_tile_slice_hor_i8(%tile_slice_index : index,  %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -233,9 +246,10 @@ func.func @arm_sme_store_tile_slice_hor_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -243,9 +257,10 @@ func.func @arm_sme_store_tile_slice_hor_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -253,9 +268,10 @@ func.func @arm_sme_store_tile_slice_hor_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -263,9 +279,10 @@ func.func @arm_sme_store_tile_slice_hor_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_i128
-// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -273,9 +290,10 @@ func.func @arm_sme_store_tile_slice_hor_i128(%tile : vector<[1]x[1]xi128>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -283,9 +301,10 @@ func.func @arm_sme_store_tile_slice_hor_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_bf16
-// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -293,9 +312,10 @@ func.func @arm_sme_store_tile_slice_hor_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f32
-// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -303,9 +323,10 @@ func.func @arm_sme_store_tile_slice_hor_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_hor_f64
-// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_hor_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -313,9 +334,10 @@ func.func @arm_sme_store_tile_slice_hor_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i8
-// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
+// CHECK: "arm_sme.intr.st1b.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i8(%tile_slice_index : index, %mask : vector<[16]xi1>, %dest : memref<?x?xi8>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
   return
 }
@@ -323,9 +345,10 @@ func.func @arm_sme_store_tile_slice_ver_i8(%tile : vector<[16]x[16]xi8>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xi16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
   return
 }
@@ -333,9 +356,10 @@ func.func @arm_sme_store_tile_slice_ver_i16(%tile : vector<[8]x[8]xi16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xi32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
   return
 }
@@ -343,9 +367,10 @@ func.func @arm_sme_store_tile_slice_ver_i32(%tile : vector<[4]x[4]xi32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xi64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
   return
 }
@@ -353,9 +378,10 @@ func.func @arm_sme_store_tile_slice_ver_i64(%tile : vector<[2]x[2]xi64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_i128
-// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) : (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
+// CHECK: "arm_sme.intr.st1q.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_i128(%tile_slice_index : index, %mask : vector<[1]xi1>, %dest : memref<?x?xi128>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
   return
 }
@@ -363,9 +389,10 @@ func.func @arm_sme_store_tile_slice_ver_i128(%tile : vector<[1]x[1]xi128>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
   return
 }
@@ -373,9 +400,10 @@ func.func @arm_sme_store_tile_slice_ver_f16(%tile : vector<[8]x[8]xf16>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) : (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
+// CHECK: "arm_sme.intr.st1h.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_bf16(%tile_slice_index : index, %mask : vector<[8]xi1>, %dest : memref<?x?xbf16>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
   return
 }
@@ -383,9 +411,10 @@ func.func @arm_sme_store_tile_slice_ver_bf16(%tile : vector<[8]x[8]xbf16>, %tile
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f32
-// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
+// CHECK: "arm_sme.intr.st1w.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f32(%tile_slice_index : index, %mask : vector<[4]xi1>, %dest : memref<?x?xf32>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
   return
 }
@@ -393,9 +422,10 @@ func.func @arm_sme_store_tile_slice_ver_f32(%tile : vector<[4]x[4]xf32>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_store_tile_slice_ver_f64
-// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) : (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
-func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
+// CHECK: "arm_sme.intr.st1d.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, !llvm.ptr, i32) -> ()
+func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : vector<[2]xi1>, %dest : memref<?x?xf64>) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   arm_sme.store_tile_slice %tile, %tile_slice_index, %mask, %dest[%c0] layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
   return
 }
@@ -407,9 +437,10 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile : vector<[2]x[2]xf64>, %tile_s
 // -----
 
 // CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_hor_i32
-// CHECK: "arm_sme.intr.write.horiz"({{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.horiz"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
   return
 }
@@ -417,9 +448,10 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_vector_to_tile_slice_ver_bf16
-// CHECK: "arm_sme.intr.write.vert"({{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
-func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> () {
+// CHECK: "arm_sme.intr.write.vert"({{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return
 }
@@ -431,8 +463,9 @@ func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i8
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %tile_slice_index : index) -> vector<[16]xi8> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+func.func @arm_sme_move_tile_slice_to_vector_i8(%tile_slice_index : index) -> vector<[16]xi8> {
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
@@ -440,8 +473,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i8(%tile : vector<[16]x[16]xi8>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %tile_slice_index : index) -> vector<[8]xi16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+func.func @arm_sme_move_tile_slice_to_vector_i16(%tile_slice_index : index) -> vector<[8]xi16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
@@ -449,8 +483,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i16(%tile : vector<[8]x[8]xi16>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
-func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %tile_slice_index : index) -> vector<[4]xi32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+func.func @arm_sme_move_tile_slice_to_vector_i32(%tile_slice_index : index) -> vector<[4]xi32> {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
@@ -458,8 +493,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i32(%tile : vector<[4]x[4]xi32>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %tile_slice_index : index) -> vector<[2]xi64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+func.func @arm_sme_move_tile_slice_to_vector_i64(%tile_slice_index : index) -> vector<[2]xi64> {
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
@@ -467,8 +503,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i64(%tile : vector<[2]x[2]xi64>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_i128
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
@@ -476,8 +513,9 @@ func.func @arm_sme_move_tile_slice_to_vector_i128(%tile : vector<[1]x[1]xi128>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %tile_slice_index : index) -> vector<[8]xf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+func.func @arm_sme_move_tile_slice_to_vector_f16(%tile_slice_index : index) -> vector<[8]xf16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
@@ -485,8 +523,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f16(%tile : vector<[8]x[8]xf16>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_bf16
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>, %tile_slice_index : index) -> vector<[8]xbf16> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile_slice_index : index) -> vector<[8]xbf16> {
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
@@ -494,8 +533,9 @@ func.func @arm_sme_move_tile_slice_to_vector_bf16(%tile : vector<[8]x[8]xbf16>,
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f32
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %tile_slice_index : index) -> vector<[4]xf32> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+func.func @arm_sme_move_tile_slice_to_vector_f32(%tile_slice_index : index) -> vector<[4]xf32> {
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
@@ -503,8 +543,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f32(%tile : vector<[4]x[4]xf32>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_f64
-// CHECK: "arm_sme.intr.read.horiz"({{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %tile_slice_index : index) -> vector<[2]xf64> {
+// CHECK: "arm_sme.intr.read.horiz"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+func.func @arm_sme_move_tile_slice_to_vector_f64(%tile_slice_index : index) -> vector<[2]xf64> {
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
@@ -512,8 +553,9 @@ func.func @arm_sme_move_tile_slice_to_vector_f64(%tile : vector<[2]x[2]xf64>, %t
 // -----
 
 // CHECK-LABEL: @arm_sme_move_tile_slice_to_vector_ver_i128
-// CHECK: "arm_sme.intr.read.vert"({{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile : vector<[1]x[1]xi128>, %tile_slice_index : index) -> vector<[1]xi128> {
+// CHECK: "arm_sme.intr.read.vert"({{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+func.func @arm_sme_move_tile_slice_to_vector_ver_i128(%tile_slice_index : index) -> vector<[1]xi128> {
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = arm_sme.move_tile_slice_to_vector %tile[%tile_slice_index] layout<vertical> : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }

diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index 06bbd3050fdecec..b7ba3f728c705a3 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,27 +1,25 @@
 // RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-// -----
-
-// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
-// CHECK-SAME: %[[TILE_ID:.*]]: i8
-func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_id_0 : i8) -> i8 {
-  // CHECK-NOT: arm_sme.cast_tile_to_vector
-  // CHECK-NOT: arm_sme.cast_vector_to_tile
-  // CHECK-NEXT: return %[[TILE_ID]] : i8
-  %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
-  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
-  return %tile_id_1 : i8
-}
+// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
+// once it becomes unused, after lowering to control flow.
 
 // -----
 
-// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
-// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
-func.func @cast_tile_to_vector__cast_vector_to_tile(%tile_0 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
-  // CHECK-NOT: arm_sme.cast_vector_to_tile
-  // CHECK-NOT: arm_sme.cast_tile_to_vector
-  // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
-  %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[16]x[16]xi8> to i8
-  %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
-  return %tile_1 : vector<[16]x[16]xi8>
+// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
+// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-NOT: vector<[4]x[4]xf32>
+func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+  %c10 = arith.constant 10 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+  cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
+^bb1(%1: index, %2: vector<[4]x[4]xf32>):  // 2 preds: ^bb0, ^bb2
+  %3 = arith.cmpi slt, %1, %c10 : index
+  cf.cond_br %3, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  %4 = arith.addi %1, %c1 : index
+  cf.br ^bb1(%4, %tile : index, vector<[4]x[4]xf32>)
+^bb3:  // pred: ^bb1
+  return
 }

diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
index 734bd9f15c8dea1..74e7293eaeca5fc 100644
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ b/mlir/test/Dialect/ArmSME/cse.mlir
@@ -1,16 +1,30 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
 
-// This test is checking that CSE does not remove 'arm_sme.get_tile_id' ops as
+// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
 // duplicates.
-// CHECK-LABEL: @get_tile_id
-// CHECK: %[[TILE_ID_0:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE_ID_1:.*]] = arm_sme.get_tile_id : i32
-// CHECK: "prevent.dce"(%[[TILE_ID_0]]) : (i32) -> ()
-// CHECK: "prevent.dce"(%[[TILE_ID_1]]) : (i32) -> ()
-func.func @get_tile_id() {
-  %tile_id_1 = arm_sme.get_tile_id : i32
-  %tile_id_2 = arm_sme.get_tile_id : i32
-  "prevent.dce"(%tile_id_1) : (i32) -> ()
-  "prevent.dce"(%tile_id_2) : (i32) -> ()
+
+// CHECK-LABEL: @zero_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @zero_tile() {
+  %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
+  %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
+  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
+  return
+}
+
+// CHECK-LABEL: @get_tile
+// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
+// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
+// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
+func.func @get_tile() {
+  %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
+  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
   return
 }

diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
index d35eb7f7bccc75e..85b95a8b6cf12b7 100644
--- a/mlir/test/Dialect/ArmSME/invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -1,89 +1,49 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
 //===----------------------------------------------------------------------===//
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
-  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
-  return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_rank_1() -> vector<[16]xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+  %0 = arm_sme.get_tile : vector<[16]xi8>
   return %0 : vector<[16]xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+func.func @arm_sme_get_tile__bad_vector_type_i4() -> vector<[16]x[16]xi4> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x[16]xi4>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+  %0 = arm_sme.get_tile : vector<[16]x[16]xi4>
   return %0 : vector<[16]x[16]xi4>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_0() -> vector<16x[16]xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<16x[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+  %0 = arm_sme.get_tile : vector<16x[16]xi8>
   return %0 : vector<16x[16]xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+func.func @arm_sme_get_tile__bad_vector_type_non_scalable_dim_1() -> vector<[16]x16xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[16]x16xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+  %0 = arm_sme.get_tile : vector<[16]x16xi8>
   return %0 : vector<[16]x16xi8>
 }
 
 // -----
 
-func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+func.func @arm_sme_get_tile__bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
   // expected-error@+1 {{op result #0 must be a vector type that fits into a SME tile, but got 'vector<[4]x[16]xi8>'}}
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+  %0 = arm_sme.get_tile : vector<[4]x[16]xi8>
   return %0 : vector<[4]x[16]xi8>
 }
 
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
-  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
-  return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
-  // expected-error@+1 {{op operand #0 must be a vector type that fits into a SME tile, but got 'vector<[16]xi8>'}}
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
-  return %0 : i8
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id__bad_type() -> i1 {
-  // expected-error@+1 {{op result #0 must be an identifier of a virtual tile (of a size) within the ZA storage}}
-  %0 = arm_sme.get_tile_id : i1
-  return %0 : i1
-}
-
 //===----------------------------------------------------------------------===//
 // arm_sme.move_vector_to_tile_slice
 //===----------------------------------------------------------------------===//

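The `invalid.mlir` cases above reject rank-1 vectors, `i4` elements, non-scalable dimensions, and mismatched shapes such as `vector<[4]x[16]xi8>`. The Python function below is a hypothetical model of that "fits into a SME tile" check, written only to illustrate which types the tests accept and reject. `fits_in_sme_tile`, `ELEMENT_BITS`, and the tuple encoding of vector types are assumptions for this sketch and are not part of MLIR. It assumes a legal tile type is 2-D, scalable in both dimensions, uses one of the listed element types, and has a base size of `128 / bitwidth` in each dimension.

```python
# Hypothetical model (not MLIR's implementation) of the "vector type that fits
# into a SME tile" constraint exercised by invalid.mlir. A vector type is
# encoded as (shape, scalable_dims, element_type); for example
# vector<[4]x[4]xi32> becomes ((4, 4), (True, True), "i32").

# Element types assumed to be legal in an SME tile, with their bit widths.
ELEMENT_BITS = {
    "i8": 8, "i16": 16, "i32": 32, "i64": 64, "i128": 128,
    "f16": 16, "bf16": 16, "f32": 32, "f64": 64,
}


def fits_in_sme_tile(shape, scalable_dims, element_type):
    """Return True if the encoded vector type would fit into an SME tile."""
    # Unsupported element types (e.g. i4) are rejected.
    if element_type not in ELEMENT_BITS:
        return False
    # The vector must be 2-D with both dimensions scalable.
    if len(shape) != 2 or len(scalable_dims) != 2 or not all(scalable_dims):
        return False
    # Each dimension holds vscale x (128 / bitwidth) elements, so the base
    # size of both dimensions must be 128 / bitwidth.
    base = 128 // ELEMENT_BITS[element_type]
    return tuple(shape) == (base, base)
```

For instance, `fits_in_sme_tile((4, 16), (True, True), "i8")` returns `False`, matching `@arm_sme_get_tile__bad_shape`, while `((4, 4), (True, True), "i32")` is accepted.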
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index 1dbcc32e9259fc5..58ff7ef4d8340ec 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1,197 +1,78 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// arm_sme.cast_tile_to_vector
+// arm_sme.get_tile
 //===----------------------------------------------------------------------===//
 
-func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
-  return %0 : vector<[16]x[16]xi8>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i16(%tile_id : i16) -> vector<[8]x[8]xi16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
-  return %0 : vector<[8]x[8]xi16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i32(%tile_id : i32) -> vector<[4]x[4]xi32> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
-  return %0 : vector<[4]x[4]xi32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i64(%tile_id : i64) -> vector<[2]x[2]xi64> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xi64>
-  return %0 : vector<[2]x[2]xi64>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_i128(%tile_id : i128) -> vector<[1]x[1]xi128> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i128 to vector<[1]x[1]xi128>
-  return %0 : vector<[1]x[1]xi128>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f16(%tile_id : i16) -> vector<[8]x[8]xf16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xf16>
-  return %0 : vector<[8]x[8]xf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_bf16(%tile_id : i16) -> vector<[8]x[8]xbf16> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xbf16>
-  return %0 : vector<[8]x[8]xbf16>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f32(%tile_id : i32) -> vector<[4]x[4]xf32> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xf32>
-  return %0 : vector<[4]x[4]xf32>
-}
-
-// -----
-
-func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64> {
-  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
-  %0 = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xf64>
-  return %0 : vector<[2]x[2]xf64>
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.cast_vector_to_tile
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8
-  return %0 : i8
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i16(%vector : vector<[8]x[8]xi16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xi16> to i16
-  return %0 : i16
-}
 
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i32(%vector : vector<[4]x[4]xi32>) -> i32 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xi32> to i32
-  return %0 : i32
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i64(%vector : vector<[2]x[2]xi64>) -> i64 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xi64> to i64
-  return %0 : i64
-}
-
-// -----
-
-func.func @arm_sme_cast_vector_to_tile_i128(%vector : vector<[1]x[1]xi128>) -> i128 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i128
-  return %0 : i128
+func.func @arm_sme_get_tile_i8() {
+  // CHECK: arm_sme.get_tile : vector<[16]x[16]xi8>
+  %0 = arm_sme.get_tile : vector<[16]x[16]xi8>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f16(%vector : vector<[8]x[8]xf16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xf16> to i16
-  return %0 : i16
+func.func @arm_sme_get_tile_i16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xi16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xi16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_bf16(%vector : vector<[8]x[8]xbf16>) -> i16 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xbf16> to i16
-  return %0 : i16
+func.func @arm_sme_get_tile_i32() {
+  // CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
+  %0 = arm_sme.get_tile : vector<[4]x[4]xi32>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f32(%vector : vector<[4]x[4]xf32>) -> i32 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xf32> to i32
-  return %0 : i32
+func.func @arm_sme_get_tile_i64() {
+  // CHECK: arm_sme.get_tile : vector<[2]x[2]xi64>
+  %0 = arm_sme.get_tile : vector<[2]x[2]xi64>
+  return
 }
 
 // -----
 
-func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64 {
-  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
-  %0 = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xf64> to i64
-  return %0 : i64
-}
-
-//===----------------------------------------------------------------------===//
-// arm_sme.get_tile_id
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_get_tile_id_i8() -> i8 {
-  // CHECK: arm_sme.get_tile_id : i8
-  %0 = arm_sme.get_tile_id : i8
-  return %0 : i8
+func.func @arm_sme_get_tile_i128() {
+  // CHECK: arm_sme.get_tile : vector<[1]x[1]xi128>
+  %0 = arm_sme.get_tile : vector<[1]x[1]xi128>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i16() -> i16 {
-  // CHECK: arm_sme.get_tile_id : i16
-  %0 = arm_sme.get_tile_id : i16
-  return %0 : i16
+func.func @arm_sme_get_tile_f16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xf16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xf16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i32() -> i32 {
-  // CHECK: arm_sme.get_tile_id : i32
-  %0 = arm_sme.get_tile_id : i32
-  return %0 : i32
+func.func @arm_sme_get_tile_bf16() {
+  // CHECK: arm_sme.get_tile : vector<[8]x[8]xbf16>
+  %0 = arm_sme.get_tile : vector<[8]x[8]xbf16>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i64() -> i64 {
-  // CHECK: arm_sme.get_tile_id : i64
-  %0 = arm_sme.get_tile_id : i64
-  return %0 : i64
+func.func @arm_sme_get_tile_f32() {
+  // CHECK: arm_sme.get_tile : vector<[4]x[4]xf32>
+  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+  return
 }
 
 // -----
 
-func.func @arm_sme_get_tile_id_i128() -> i128 {
-  // CHECK: arm_sme.get_tile_id : i128
-  %0 = arm_sme.get_tile_id : i128
-  return %0 : i128
+func.func @arm_sme_get_tile_f64() {
+  // CHECK: arm_sme.get_tile : vector<[2]x[2]xf64>
+  %0 = arm_sme.get_tile : vector<[2]x[2]xf64>
+  return
 }
 
 //===----------------------------------------------------------------------===//

diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
new file mode 100644
index 000000000000000..39d9ab6491e3b49
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %cond: i1) {
+  %c0 = arith.constant 0 : index
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // Select between tileA and tileB. This is currently unsupported as it would
+  // require inserting tile move operations during tile allocation.
+  %tile = scf.if %cond -> vector<[4]x[4]xi32> {
+    scf.yield %tileA : vector<[4]x[4]xi32>
+  } else {
+    scf.yield %tileB : vector<[4]x[4]xi32>
+  }
+  // expected-error@+1 {{op already assigned different SME virtual tile!}}
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}

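The `tile-allocation.mlir` tests that follow check which `tile_id` each `arm_sme.get_tile` receives and the resulting `arm_sme.tiles_in_use` mask. A simplified model reproduces the values those tests expect, assuming each tile is tracked as the set of 128-bit `ZA<k>.Q` tiles it overlaps (bit `15 - k` of a 16-bit mask, so `ZA0.H` gives `0xAAAA == 43690`) and that the allocator greedily takes the lowest free ID. `NUM_TILES`, `tile_mask`, and `allocate` are hypothetical names for this sketch, not the pass's actual C++ API.

```python
# Hypothetical sketch (not the MLIR pass) of greedy SME tile allocation.
# Each tile is modelled by the ZA.Q tiles it overlaps, as a 16-bit mask in
# which ZA<k>.Q is bit (15 - k).

# Number of distinct tiles available for each element bit width.
NUM_TILES = {8: 1, 16: 2, 32: 4, 64: 8, 128: 16}


def tile_mask(bitwidth, tile_id):
    """Mask of the ZA.Q tiles overlapped by tile `tile_id` of this bit width."""
    num_tiles = NUM_TILES[bitwidth]
    mask = 0
    # ZA<t>.<T> covers every ZA<k>.Q with k % num_tiles == t.
    for k in range(tile_id, 16, num_tiles):
        mask |= 1 << (15 - k)
    return mask


def allocate(bitwidths):
    """Assign the lowest free tile ID to each requested tile, in order.

    Returns (tile_ids, tiles_in_use). Raises RuntimeError when no tile of
    the requested size is free, mirroring "ran out of SME virtual tiles!".
    """
    in_use = 0
    ids = []
    for bw in bitwidths:
        for tile_id in range(NUM_TILES[bw]):
            mask = tile_mask(bw, tile_id)
            if mask & in_use == 0:
                in_use |= mask
                ids.append(tile_id)
                break
        else:
            raise RuntimeError("ran out of SME virtual tiles!")
    return ids, in_use
```

For example, `allocate([16, 32, 64, 128])` models `@mixed_tiles` and yields IDs `[0, 1, 3, 7]` with mask `65534`, leaving only `ZA15.Q` free, while `allocate([8, 128])` raises, like `@za_b_overlapping_za_q`.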
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
index a481516d4c15f93..1f895e4984ba844 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -6,17 +6,17 @@
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
 func.func @mixed_tiles() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.D ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // ZA15.Q is still free.
   return
 }
@@ -26,28 +26,26 @@ func.func @mixed_tiles() {
 // CHECK-LABEL: za_b
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_b() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_b = arm_sme.get_tile_id : i8
+  // CHECK-NEXT: tile_id = 0
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @za_b__out_of_tiles() {
-  %za0_b = arm_sme.get_tile_id : i8
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i8
+  %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   return
 }
 
 // -----
 
 func.func @za_b_overlapping_za_q() {
-  %za0_b = arm_sme.get_tile_id : i8
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -56,8 +54,8 @@ func.func @za_b_overlapping_za_q() {
 // CHECK-LABEL: za0_h
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
 func.func @za0_h() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
@@ -66,21 +64,23 @@ func.func @za0_h() {
 // CHECK-LABEL: za_h
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
-  // CHECK-NEXT: arith.constant 1
-  %za1_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  // CHECK-NEXT: tile_id = 1
+  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_h__out_of_tiles
 func.func @za_h__out_of_tiles() {
-  %za0_h = arm_sme.get_tile_id : i16
-  %za1_h = arm_sme.get_tile_id : i16
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  // CHECK-NEXT: tile_id = 1
+  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i16
+  %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   return
 }
 
@@ -90,14 +90,14 @@ func.func @za_h__out_of_tiles() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_s() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 3
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -107,38 +107,37 @@ func.func @za_h_overlapping_za_s() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_d() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: tile_id = 0
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // ZA1.Q, ZA9.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 1
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA5.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 5
-  %za5_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 5
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_h_overlapping_za_q() {
-  %za0_h = arm_sme.get_tile_id : i16
-  %za0_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za8_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -147,8 +146,8 @@ func.func @za_h_overlapping_za_q() {
 // CHECK-LABEL: za0_s
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
 func.func @za0_s() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -157,27 +156,26 @@ func.func @za0_s() {
 // CHECK-LABEL: za_s
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 2
-  %za2_s = arm_sme.get_tile_id : i32
-  // CHECK-NEXT: arith.constant 3
-  %za3_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 2
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 3
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
 // -----
 
 func.func @za_s__out_of_tiles() {
-  %za0_s = arm_sme.get_tile_id : i32
-  %za1_s = arm_sme.get_tile_id : i32
-  %za2_s = arm_sme.get_tile_id : i32
-  %za3_s = arm_sme.get_tile_id : i32
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i32
+  %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   return
 }
 
@@ -187,42 +185,41 @@ func.func @za_s__out_of_tiles() {
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s_overlapping_za_d() {
   // ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
-  // CHECK-NEXT: arith.constant 0
-  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 0
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
-  // CHECK-NEXT: arith.constant 1
-  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 1
+  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
-  // CHECK-NEXT: arith.constant 2
-  %za2_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: tile_id = 2
+  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // ZA3.Q, ZA11.Q
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // ZA7.Q, ZA15.Q
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_s_overlapping_za_q() {
-  %za0_s = arm_sme.get_tile_id : i32
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -231,8 +228,8 @@ func.func @za_s_overlapping_za_q() {
 // CHECK-LABEL: za0_d
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
 func.func @za0_d() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 0
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
@@ -241,63 +238,61 @@ func.func @za0_d() {
 // CHECK-LABEL: za_d
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_d() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 1
-  %za1_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 2
-  %za2_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 3
-  %za3_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 4
-  %za4_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 5
-  %za5_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 6
-  %za6_d = arm_sme.get_tile_id : i64
-  // CHECK-NEXT: arith.constant 7
-  %za7_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: tile_id = 0
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 1
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 2
+  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 3
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 4
+  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 5
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 6
+  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 7
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_d__out_of_tiles() {
-  %za0_d = arm_sme.get_tile_id : i64
-  %za1_d = arm_sme.get_tile_id : i64
-  %za2_d = arm_sme.get_tile_id : i64
-  %za3_d = arm_sme.get_tile_id : i64
-  %za4_d = arm_sme.get_tile_id : i64
-  %za5_d = arm_sme.get_tile_id : i64
-  %za6_d = arm_sme.get_tile_id : i64
-  %za7_d = arm_sme.get_tile_id : i64
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i64
+  %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   return
 }
 
 // -----
 
 func.func @za_d_overlapping_za_q() {
-  %za0_d = arm_sme.get_tile_id : i64
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -306,8 +301,8 @@ func.func @za_d_overlapping_za_q() {
 // CHECK-LABEL: za0_q
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
 func.func @za0_q() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 0
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
@@ -316,62 +311,61 @@ func.func @za0_q() {
 // CHECK-LABEL: za_q
 // CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_q() {
-  // CHECK-NEXT: arith.constant 0
-  %za0_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 1
-  %za1_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 2
-  %za2_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 3
-  %za3_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 4
-  %za4_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 5
-  %za5_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 6
-  %za6_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 7
-  %za7_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 8
-  %za8_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 9
-  %za9_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 10
-  %za10_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 11
-  %za11_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 12
-  %za12_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 13
-  %za13_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 14
-  %za14_q = arm_sme.get_tile_id : i128
-  // CHECK-NEXT: arith.constant 15
-  %za15_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: tile_id = 0
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 1
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 2
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 4
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 6
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 8
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 10
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 12
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 14
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }
 
 // -----
 
 func.func @za_q__out_of_tiles() {
-  %za0_q = arm_sme.get_tile_id : i128
-  %za1_q = arm_sme.get_tile_id : i128
-  %za2_q = arm_sme.get_tile_id : i128
-  %za3_q = arm_sme.get_tile_id : i128
-  %za4_q = arm_sme.get_tile_id : i128
-  %za5_q = arm_sme.get_tile_id : i128
-  %za6_q = arm_sme.get_tile_id : i128
-  %za7_q = arm_sme.get_tile_id : i128
-  %za8_q = arm_sme.get_tile_id : i128
-  %za9_q = arm_sme.get_tile_id : i128
-  %za10_q = arm_sme.get_tile_id : i128
-  %za11_q = arm_sme.get_tile_id : i128
-  %za12_q = arm_sme.get_tile_id : i128
-  %za13_q = arm_sme.get_tile_id : i128
-  %za14_q = arm_sme.get_tile_id : i128
-  %za15_q = arm_sme.get_tile_id : i128
-  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // expected-error@+1 {{ran out of SME virtual tiles!}}
-  %next_tile = arm_sme.get_tile_id : i128
+  %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   return
 }

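The `arm_sme.tiles_in_use` values and `tile_id`s checked in the test above can be reproduced with a small model of the allocator. The Python sketch below is an illustration only, not the C++ code behind `-allocate-arm-sme-tiles`: the bit order (bit 15 stands for ZA0.Q) and the greedy first-fit policy are inferred from the CHECK lines and comments in this test, and `tile_mask`/`allocate` are hypothetical names.

```python
# Hypothetical model (not the pass's source) of the 16-bit
# `arm_sme.tiles_in_use` mask. Assumption: bit (15 - q) stands for the
# 128-bit tile ZAq.Q, and ZAn.<kind> covers every k-th ZA.Q tile from ZAn.Q.

NUM_TILES = {"B": 1, "H": 2, "S": 4, "D": 8, "Q": 16}


def tile_mask(kind: str, n: int) -> int:
    """16-bit mask of the ZAq.Q tiles overlapped by virtual tile ZAn.<kind>."""
    k = NUM_TILES[kind]
    if not 0 <= n < k:
        raise ValueError(f"ZA{n}.{kind} does not exist")
    mask = 0
    for q in range(16):
        if q % k == n:
            mask |= 1 << (15 - q)
    return mask


def allocate(kind: str, tiles_in_use: int) -> tuple[int, int]:
    """Greedy first-fit: return (tile_id, updated tiles_in_use)."""
    for n in range(NUM_TILES[kind]):
        mask = tile_mask(kind, n)
        if mask & tiles_in_use == 0:  # no overlap with any tile already taken
            return n, tiles_in_use | mask
    raise RuntimeError("ran out of SME virtual tiles!")
```

For example, allocating one H tile followed by two S tiles yields IDs 0, 1 and 3 and a fully used mask (65535), matching `@za_h_overlapping_za_s`; a third H allocation after two raises, mirroring the `ran out of SME virtual tiles!` diagnostic.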
diff  --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 2378f4234aef1ef..04412e4db1c5f3a 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,7 +1,4 @@
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm \
-// RUN:           -allocate-arm-sme-tiles -canonicalize      \
-// RUN:           -allow-unregistered-dialect                \
-// RUN: | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
 
 // This test verifies the tile mask operand of the zero intrinsic zeroes
 // the correct tiles. Both integer and floating-point datatypes are checked.
@@ -10,13 +7,8 @@
 
 // CHECK-LABEL: zero_za_b
 func.func @zero_za_b() {
-  // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8
-  // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
   %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
-  "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -24,20 +16,10 @@ func.func @zero_za_b() {
 
 // CHECK-LABEL: zero_za_h
 func.func @zero_za_h() {
-  // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16
-  // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
   %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
-  "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
   %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
-  "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -45,32 +27,14 @@ func.func @zero_za_h() {
 
 // CHECK-LABEL: zero_za_s
 func.func @zero_za_s() {
-  // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32
-  // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
   %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
   %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
   %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
   %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
-  "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -78,55 +42,21 @@ func.func @zero_za_s() {
 
 // CHECK-LABEL: zero_za_d
 func.func @zero_za_d() {
-  // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64
-  // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64
-
-  // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32
-  // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32
-
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
   %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
   %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:      "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
   %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
   %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
   %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
   %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
   %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
-  "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
-  // CHECK:     "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> ()
-  // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64>
+  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
   %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
-  "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
   return
 }

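Because the tile ID is now known at compile time, the `tile_mask` of `arm_sme.intr.zero` is emitted as a constant attribute rather than an SSA value. The Python sketch below reproduces the constants checked in tile-zero-masks.mlir; `zero_mask` is a hypothetical helper, not the lowering's source, and it assumes that bit d of the 8-bit mask selects the 64-bit tile ZAd.D and that ZAn of a type with k tiles covers every ZAd.D with d mod k == n.

```python
# Hypothetical model (not MLIR's lowering code) of the `tile_mask` values
# checked in tile-zero-masks.mlir. Assumption: bit d selects tile ZAd.D.

NUM_TILES = {"B": 1, "H": 2, "S": 4, "D": 8}


def zero_mask(kind: str, tile_id: int) -> int:
    """Mask of the ZAd.D tiles overlapped by virtual tile ZA<tile_id>.<kind>."""
    k = NUM_TILES[kind]
    if not 0 <= tile_id < k:
        raise ValueError(f"ZA{tile_id}.{kind} does not exist")
    # ZA0.<kind> covers every k-th ZA.D tile starting at ZA0.D ...
    base = sum(1 << d for d in range(8) if d % k == 0)
    # ... and ZAn.<kind> is the same pattern shifted up by n.
    return base << tile_id
```

For example, `zero_mask("H", 1)` gives 170 and `zero_mask("S", 3)` gives 136, matching the CHECK lines for `@zero_za_h` and `@zero_za_s`.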
diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 77ac071ef67de9a..d16d250b70eb3f6 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s
 
 //===----------------------------------------------------------------------===//
 // vector.transfer_write
@@ -10,13 +10,9 @@
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
-// CHECK-DAG:  %[[C255:.*]] = arith.constant 255 : i32
 // CHECK-DAG:  %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1>
 // CHECK-DAG:  %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-DAG:  %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32
-// CHECK-DAG:  "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> ()
+// CHECK-DAG:  "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
 // CHECK-DAG:  %[[VSCALE:.*]] = vector.vscale
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
@@ -27,8 +23,7 @@
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant dense<0> : vector<[16]x[16]xi8>
@@ -49,8 +44,6 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-LABEL: @vector_load_i8_with_offset(
 // CHECK-SAME:                              %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[C123:.*]] = arith.constant 123 : index
@@ -68,10 +61,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   %c123 = arith.constant 123 : index
@@ -84,8 +75,6 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-LABEL: @vector_load_i8_from_rank_1_memref(
 // CHECK-SAME:                                     %[[ARG0:.*]]: memref<?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?xi8> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-// CHECK-DAG:  %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8
-// CHECK-DAG:  %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG:  %[[MIN_SVL_B:.*]] = arith.constant 16 : index
@@ -98,10 +87,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref<?x?xi8>) -> vector<[16]x[16
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[TILE_SLICE_IDX_I64]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.ld1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
-// CHECK-NEXT: return  %[[CAST_TILE_TO_VECTOR]] : vector<[16]x[16]xi8>
 func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16]x[16]xi8> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0] : memref<?xi8>, vector<[16]x[16]xi8>
@@ -113,12 +100,10 @@ func.func @vector_load_i8_from_rank_1_memref(%arg0 : memref<?xi8>) -> vector<[16
 
 // CHECK-LABEL: @vector_load_i16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xi16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
@@ -129,13 +114,10 @@ func.func @vector_load_i16(%arg0 : memref<?x?xi16>) -> vector<[8]x[8]xi16> {
 
 // CHECK-LABEL: @vector_load_i32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT:   arith.extui %[[TILE_ID]]
-// CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
@@ -146,12 +128,10 @@ func.func @vector_load_i32(%arg0 : memref<?x?xi32>) -> vector<[4]x[4]xi32> {
 
 // CHECK-LABEL: @vector_load_i64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xi64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
@@ -162,12 +142,10 @@ func.func @vector_load_i64(%arg0 : memref<?x?xi64>) -> vector<[2]x[2]xi64> {
 
 // CHECK-LABEL: @vector_load_f16(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
@@ -178,12 +156,10 @@ func.func @vector_load_f16(%arg0 : memref<?x?xf16>) -> vector<[8]x[8]xf16> {
 
 // CHECK-LABEL: @vector_load_bf16(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xbf16>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i16
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i16 to vector<[8]x[8]xbf16>
 // CHECK-DAG: %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[TILE_ID]] : i16 to i32
 // CHECK:       arm_sme.intr.ld1h.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
@@ -194,13 +170,10 @@ func.func @vector_load_bf16(%arg0 : memref<?x?xbf16>) -> vector<[8]x[8]xbf16> {
 
 // CHECK-LABEL: @vector_load_f32(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xf32>
 // CHECK-DAG: %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK-NOT:   arith.extui %[[TILE_ID]]
-// CHECK-NOT:   arith.trunci %[[TILE_ID]]
 // CHECK:       arm_sme.intr.ld1w.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
@@ -211,12 +184,10 @@ func.func @vector_load_f32(%arg0 : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
 
 // CHECK-LABEL: @vector_load_f64(
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i64
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i64 to vector<[2]x[2]xf64>
 // CHECK-DAG: %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i64 to i32
 // CHECK:       arm_sme.intr.ld1d.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
@@ -227,10 +198,8 @@ func.func @vector_load_f64(%arg0 : memref<?x?xf64>) -> vector<[2]x[2]xf64> {
 
 // CHECK-LABEL: @vector_load_i128(
 // CHECK-SAME:                    %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i128
-// CHECK-DAG: %[[CAST_TILE_TO_VECTOR:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i128 to vector<[1]x[1]xi128>
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[TILE_ID]] : i128 to i32
 // CHECK:       arm_sme.intr.ld1q.horiz
+// CHECK-SAME:  tile_id = 0
 func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
   %c0 = arith.constant 0 : index
   %tile = vector.load %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
@@ -244,7 +213,6 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // -----
 
 // CHECK-LABEL: @vector_store_i8(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[16]x[16]xi8>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi8>)
 // CHECK-DAG:  %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
@@ -256,19 +224,18 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
 // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index
 // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] {
 // CHECK:        %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64
-// CHECK-NEXT:   %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8
 // CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
 // CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[TILE_SLICE_I64]], %[[STRIDE0]]  : i64
 // CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_I64]]  : i64
 // CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
 // CHECK-NEXT:   %[[TILE_SLICE_I32:.*]] = arith.index_castui %[[TILE_SLICE]] : index to i32
-// CHECK-NEXT:   %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i8 to i32
-// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_ID_I32]], %[[TILE_SLICE_I32]]) : (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+// CHECK-NEXT:   "arm_sme.intr.st1b.horiz"(%[[PTRUE_ALL]], %[[GEP]], %[[TILE_SLICE_I32]]) <{tile_id = 0 : i32}> : (vector<[16]xi1>, !llvm.ptr, i32) -> ()
 // CHECK-NEXT: }
 // CHECK-NEXT: return
-func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>) {
+func.func @vector_store_i8(%arg0 : memref<?x?xi8>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
@@ -276,15 +243,14 @@ func.func @vector_store_i8(%tile : vector<[16]x[16]xi8>, %arg0 : memref<?x?xi8>)
 // -----
 
 // CHECK-LABEL: @vector_store_i16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xi16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xi16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i16(%arg0 : memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
   return
 }
@@ -292,16 +258,14 @@ func.func @vector_store_i16(%tile : vector<[8]x[8]xi16>, %arg0 : memref<?x?xi16>
 // -----
 
 // CHECK-LABEL: @vector_store_i32(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi32>)
 // CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
-func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i32(%arg0 : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
@@ -309,15 +273,14 @@ func.func @vector_store_i32(%tile : vector<[4]x[4]xi32>, %arg0 : memref<?x?xi32>
 // -----
 
 // CHECK-LABEL: @vector_store_i64(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xi64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xi64>)
 // CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xi64> to i64
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
-func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i64(%arg0 : memref<?x?xi64>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
   return
 }
@@ -325,15 +288,14 @@ func.func @vector_store_i64(%tile : vector<[2]x[2]xi64>, %arg0 : memref<?x?xi64>
 // -----
 
 // CHECK-LABEL: @vector_store_f16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xf16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f16(%arg0 : memref<?x?xf16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
   return
 }
@@ -341,31 +303,28 @@ func.func @vector_store_f16(%tile : vector<[8]x[8]xf16>, %arg0 : memref<?x?xf16>
 // -----
 
 // CHECK-LABEL: @vector_store_bf16(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[8]x[8]xbf16>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xbf16>)
 // CHECK:     %[[MIN_SVL_H:.*]] = arith.constant 8 : index
 // CHECK:     %[[SVL_H:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_H]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[8]x[8]xbf16> to i16
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
 // CHECK:       arm_sme.intr.st1h.horiz
-func.func @vector_store_bf16(%tile : vector<[8]x[8]xbf16>, %arg0 : memref<?x?xbf16>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_bf16(%arg0 : memref<?x?xbf16>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
   return
 }
 // -----
 
 // CHECK-LABEL: @vector_store_f32(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[4]x[4]xf32>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf32>)
 // CHECK:     %[[MIN_SVL_S:.*]] = arith.constant 4 : index
 // CHECK:     %[[SVL_S:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_S]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xf32> to i32
-// CHECK-NOT:   arith.extui %[[CAST_VECTOR_TO_TILE]]
-// CHECK-NOT:   arith.trunci %[[CAST_VECTOR_TO_TILE]]
 // CHECK:       arm_sme.intr.st1w.horiz
-func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f32(%arg0 : memref<?x?xf32>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
   return
 }
@@ -373,15 +332,14 @@ func.func @vector_store_f32(%tile : vector<[4]x[4]xf32>, %arg0 : memref<?x?xf32>
 // -----
 
 // CHECK-LABEL: @vector_store_f64(
-// CHECK-SAME:                   %[[TILE:.*]]: vector<[2]x[2]xf64>,
 // CHECK-SAME:                   %[[ARG0:.*]]: memref<?x?xf64>)
 // CHECK:     %[[MIN_SVL_D:.*]] = arith.constant 2 : index
 // CHECK:     %[[SVL_D:.*]] = arith.muli %{{.*}}, %[[MIN_SVL_D]] : index
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[2]x[2]xf64> to i64
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i64 to i32
 // CHECK:       arm_sme.intr.st1d.horiz
-func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_f64(%arg0 : memref<?x?xf64>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
   return
 }
@@ -389,13 +347,12 @@ func.func @vector_store_f64(%tile : vector<[2]x[2]xf64>, %arg0 : memref<?x?xf64>
 // -----
 
 // CHECK-LABEL: @vector_store_i128(
-// CHECK-SAME:                     %[[TILE:.*]]: vector<[1]x[1]xi128>,
 // CHECK-SAME:                     %[[ARG0:.*]]: memref<?x?xi128>)
-// CHECK:       %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[1]x[1]xi128> to i128
-// CHECK:       %[[TILE_ID_I32:.*]] = arith.trunci %[[CAST_VECTOR_TO_TILE]] : i128 to i32
 // CHECK:       arm_sme.intr.st1q.horiz
-func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi128>) {
+// CHECK-SAME:  tile_id = 0
+func.func @vector_store_i128(%arg0 : memref<?x?xi128>) {
   %c0 = arith.constant 0 : index
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
   return
 }
@@ -407,12 +364,11 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>)
-func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>)
+func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>) {
   // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1>
-  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16
-  // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32
-  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  // CHECK: "arm_sme.intr.mopa"(%[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
 }
@@ -420,8 +376,9 @@ func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_bf16
-func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) {
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16>
   "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> ()
 }
@@ -429,10 +386,9 @@ func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f32
-func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) {
-  // CHECK-NOT: arith.extui
-  // CHECK-NOT: arith.trunci
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -440,9 +396,9 @@ func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_add_f64
-func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
-  // CHECK: arith.trunci {{.*}} : i64 to i32
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -451,8 +407,8 @@ func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]
 
 // CHECK-LABEL: @vector_outerproduct_no_accumulator
 func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
-  // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> ()
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
   %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
 }
@@ -460,12 +416,12 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f32
-// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
-func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
+func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %dim0 : index, %dim1 : index) {
   // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1>
   // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1>
-  // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32
-  // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  // CHECK: "arm_sme.intr.mopa"(%[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) <{tile_id = 0 : i32}> : (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>)
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%result) : (vector<[4]x[4]xf32>) -> ()
@@ -474,11 +430,12 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>,
-func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>,
+func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xf16>
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xf16>) -> ()
@@ -487,11 +444,12 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_bf16
-// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>,
-func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>
+func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[8]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>)
+  %acc = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16>
   "prevent.dce"(%result) : (vector<[8]x[8]xbf16>) -> ()
@@ -500,11 +458,12 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto
 // -----
 
 // CHECK-LABEL: @vector_outerproduct_masked_f64
-// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>,
-func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) {
+// CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>,
+func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %dim0 : index, %dim1 : index) {
   // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
   // CHECK: vector.create_mask {{.*}} : vector<[2]xi1>
-  // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  // CHECK: "arm_sme.intr.mopa"({{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>)
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1>
   %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
@@ -520,7 +479,8 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f
 
 // -----
 
-func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) {
+func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>) {
+  %acc = arm_sme.get_tile : vector<[16]x[16]xi8>
   // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8>
@@ -529,7 +489,8 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) {
+func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) {
+  %acc = arm_sme.get_tile : vector<[2]x[2]xf64>
   // expected-error@+1 {{unsupported kind}}
   %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> ()
@@ -537,8 +498,9 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v
 
 // -----
 
-func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) {
+func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %mask : vector<[4]x[4]xi1>) {
   // CHECK: vector.outerproduct
+  %acc = arm_sme.get_tile : vector<[4]x[4]xf32>
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
@@ -550,14 +512,13 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i32(
-// CHECK-SAME:                       %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                       %[[SLICE:.*]]: vector<[4]xi32>,
 // CHECK-SAME:                       %[[INDEX:.*]]: index)
-func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
+func.func @vector_insert_slice_i32(%slice: vector<[4]xi32>, %row: index) -> vector<[4]x[4]xi32>{
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-  // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
-  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
@@ -565,8 +526,9 @@ func.func @vector_insert_slice_i32(%tile: vector<[4]x[4]xi32>, %slice: vector<[4
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i8
-func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_slice_i8(%slice: vector<[16]xi8>, %row: index) -> vector<[16]x[16]xi8> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[16]xi8> into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
@@ -574,8 +536,9 @@ func.func @vector_insert_slice_i8(%tile: vector<[16]x[16]xi8>, %slice: vector<[1
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i16
-func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_slice_i16(%slice: vector<[8]xi16>, %row: index) -> vector<[8]x[8]xi16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xi16> into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
@@ -583,8 +546,9 @@ func.func @vector_insert_slice_i16(%tile: vector<[8]x[8]xi16>, %slice: vector<[8
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i64
-func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_slice_i64(%slice: vector<[2]xi64>, %row: index) -> vector<[2]x[2]xi64> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xi64> into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
@@ -592,8 +556,9 @@ func.func @vector_insert_slice_i64(%tile: vector<[2]x[2]xi64>, %slice: vector<[2
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_i128
-func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_slice_i128(%slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[1]xi128> into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
@@ -601,8 +566,9 @@ func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f16
-func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_slice_f16(%slice: vector<[8]xf16>, %row: index) -> vector<[8]x[8]xf16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xf16> into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
@@ -610,8 +576,9 @@ func.func @vector_insert_slice_f16(%tile: vector<[8]x[8]xf16>, %slice: vector<[8
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_bf16
-func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_slice_bf16(%slice: vector<[8]xbf16>, %row: index) -> vector<[8]x[8]xbf16> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[8]xbf16> into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
@@ -619,8 +586,9 @@ func.func @vector_insert_slice_bf16(%tile: vector<[8]x[8]xbf16>, %slice: vector<
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f32
-func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_slice_f32(%slice: vector<[4]xf32>, %row: index) -> vector<[4]x[4]xf32> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xf32> into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
@@ -628,8 +596,9 @@ func.func @vector_insert_slice_f32(%tile: vector<[4]x[4]xf32>, %slice: vector<[4
 // -----
 
 // CHECK-LABEL: @vector_insert_slice_f64
-func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_slice_f64(%slice: vector<[2]xf64>, %row: index) -> vector<[2]x[2]xf64> {
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %new_tile = vector.insert %slice, %tile[%row] : vector<[2]xf64> into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
@@ -637,19 +606,18 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i32(
-// CHECK-SAME:                         %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                         %[[EL:.*]]: i32,
 // CHECK-SAME:                         %[[ROW:.*]]: index,
 // CHECK-SAME:                         %[[COL:.*]]: index)
-func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
-  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
-  // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+func.func @vector_insert_element_i32(%el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> {
+  // CHECK-DAG:  %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-DAG:  %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-DAG:  %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32>
   // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32
-  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32>
   return %new_tile : vector<[4]x[4]xi32>
 }
@@ -657,9 +625,10 @@ func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i8
-func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+func.func @vector_insert_element_i8(%el: i8, %row: index, %col: index) -> vector<[16]x[16]xi8> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %new_tile = vector.insert %el, %tile[%row, %col] : i8 into vector<[16]x[16]xi8>
   return %new_tile : vector<[16]x[16]xi8>
 }
@@ -667,9 +636,10 @@ func.func @vector_insert_element_i8(%tile: vector<[16]x[16]xi8>, %el: i8, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i16
-func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+func.func @vector_insert_element_i16(%el: i16, %row: index, %col: index) -> vector<[8]x[8]xi16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %new_tile = vector.insert %el, %tile[%row, %col] : i16 into vector<[8]x[8]xi16>
   return %new_tile : vector<[8]x[8]xi16>
 }
@@ -677,9 +647,10 @@ func.func @vector_insert_element_i16(%tile: vector<[8]x[8]xi16>, %el: i16, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i64
-func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+func.func @vector_insert_element_i64(%el: i64, %row: index, %col: index) -> vector<[2]x[2]xi64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %new_tile = vector.insert %el, %tile[%row, %col] : i64 into vector<[2]x[2]xi64>
   return %new_tile : vector<[2]x[2]xi64>
 }
@@ -687,9 +658,10 @@ func.func @vector_insert_element_i64(%tile: vector<[2]x[2]xi64>, %el: i64, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_i128
-func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+func.func @vector_insert_element_i128(%el: i128, %row: index, %col: index) -> vector<[1]x[1]xi128> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %new_tile = vector.insert %el, %tile[%row, %col] : i128 into vector<[1]x[1]xi128>
   return %new_tile : vector<[1]x[1]xi128>
 }
@@ -697,9 +669,10 @@ func.func @vector_insert_element_i128(%tile: vector<[1]x[1]xi128>, %el: i128, %r
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f16
-func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+func.func @vector_insert_element_f16(%el: f16, %row: index, %col: index) -> vector<[8]x[8]xf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %new_tile = vector.insert %el, %tile[%row, %col] : f16 into vector<[8]x[8]xf16>
   return %new_tile : vector<[8]x[8]xf16>
 }
@@ -707,9 +680,10 @@ func.func @vector_insert_element_f16(%tile: vector<[8]x[8]xf16>, %el: f16, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_bf16
-func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+func.func @vector_insert_element_bf16(%el: bf16, %row: index, %col: index) -> vector<[8]x[8]xbf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %new_tile = vector.insert %el, %tile[%row, %col] : bf16 into vector<[8]x[8]xbf16>
   return %new_tile : vector<[8]x[8]xbf16>
 }
@@ -717,9 +691,10 @@ func.func @vector_insert_element_bf16(%tile: vector<[8]x[8]xbf16>, %el: bf16, %r
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f32
-func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+func.func @vector_insert_element_f32(%el: f32, %row: index, %col: index) -> vector<[4]x[4]xf32> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %new_tile = vector.insert %el, %tile[%row, %col] : f32 into vector<[4]x[4]xf32>
   return %new_tile : vector<[4]x[4]xf32>
 }
@@ -727,9 +702,10 @@ func.func @vector_insert_element_f32(%tile: vector<[4]x[4]xf32>, %el: f32, %row:
 // -----
 
 // CHECK-LABEL: @vector_insert_element_f64
-func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
-  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+func.func @vector_insert_element_f64(%el: f64, %row: index, %col: index) -> vector<[2]x[2]xf64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+  // CHECK: "arm_sme.intr.write.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %new_tile = vector.insert %el, %tile[%row, %col] : f64 into vector<[2]x[2]xf64>
   return %new_tile : vector<[2]x[2]xf64>
 }
@@ -741,14 +717,13 @@ func.func @vector_insert_element_f64(%tile: vector<[2]x[2]xf64>, %el: f64, %row:
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i32(
-// CHECK-SAME:                        %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                        %[[INDEX:.*]]: index)
-func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) -> vector<[4]xi32> {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_slice_i32(%row: index) -> vector<[4]xi32> {
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_SLICE_INDEX]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32>
   return %slice : vector<[4]xi32>
 }
@@ -756,8 +731,9 @@ func.func @vector_extract_slice_i32(%tile: vector<[4]x[4]xi32>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i8
-func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) -> vector<[16]xi8> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_slice_i8(%row: index) -> vector<[16]xi8> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %slice = vector.extract %tile[%row] : vector<[16]xi8> from vector<[16]x[16]xi8>
   return %slice : vector<[16]xi8>
 }
@@ -765,8 +741,9 @@ func.func @vector_extract_slice_i8(%tile: vector<[16]x[16]xi8>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i16
-func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) -> vector<[8]xi16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_slice_i16(%row: index) -> vector<[8]xi16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %slice = vector.extract %tile[%row] : vector<[8]xi16> from vector<[8]x[8]xi16>
   return %slice : vector<[8]xi16>
 }
@@ -774,8 +751,9 @@ func.func @vector_extract_slice_i16(%tile: vector<[8]x[8]xi16>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i64
-func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) -> vector<[2]xi64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_slice_i64(%row: index) -> vector<[2]xi64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %slice = vector.extract %tile[%row] : vector<[2]xi64> from vector<[2]x[2]xi64>
   return %slice : vector<[2]xi64>
 }
@@ -783,8 +761,9 @@ func.func @vector_extract_slice_i64(%tile: vector<[2]x[2]xi64>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_i128
-func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -> vector<[1]xi128> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_slice_i128(%row: index) -> vector<[1]xi128> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %slice = vector.extract %tile[%row] : vector<[1]xi128> from vector<[1]x[1]xi128>
   return %slice : vector<[1]xi128>
 }
@@ -792,8 +771,9 @@ func.func @vector_extract_slice_i128(%tile: vector<[1]x[1]xi128>, %row: index) -
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f16
-func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) -> vector<[8]xf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_slice_f16(%row: index) -> vector<[8]xf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %slice = vector.extract %tile[%row] : vector<[8]xf16> from vector<[8]x[8]xf16>
   return %slice : vector<[8]xf16>
 }
@@ -801,8 +781,9 @@ func.func @vector_extract_slice_f16(%tile: vector<[8]x[8]xf16>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_bf16
-func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -> vector<[8]xbf16> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_slice_bf16(%row: index) -> vector<[8]xbf16> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %slice = vector.extract %tile[%row] : vector<[8]xbf16> from vector<[8]x[8]xbf16>
   return %slice : vector<[8]xbf16>
 }
@@ -810,8 +791,9 @@ func.func @vector_extract_slice_bf16(%tile: vector<[8]x[8]xbf16>, %row: index) -
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f32
-func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) -> vector<[4]xf32> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_slice_f32(%row: index) -> vector<[4]xf32> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %slice = vector.extract %tile[%row] : vector<[4]xf32> from vector<[4]x[4]xf32>
   return %slice : vector<[4]xf32>
 }
@@ -819,8 +801,9 @@ func.func @vector_extract_slice_f32(%tile: vector<[4]x[4]xf32>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_slice_f64
-func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> vector<[2]xf64> {
-  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_slice_f64(%row: index) -> vector<[2]xf64> {
+  // CHECK: %{{.*}} = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %slice = vector.extract %tile[%row] : vector<[2]xf64> from vector<[2]x[2]xf64>
   return %slice : vector<[2]xf64>
 }
@@ -828,16 +811,15 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) ->
 // -----
 
 // CHECK-LABEL: @vector_extract_element(
-// CHECK-SAME:                          %[[TILE:.*]]: vector<[4]x[4]xi32>,
 // CHECK-SAME:                          %[[ROW:.*]]: index,
 // CHECK-SAME:                          %[[COL:.*]]: index)
-func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 {
-  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+func.func @vector_extract_element(%row: index, %col: index) -> i32 {
   // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
-  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
   // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32
-  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[ROW_I32]]) <{tile_id = 0 : i32}> : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32>
   return %el : i32
 }
@@ -845,9 +827,10 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col:
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i8
-func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+func.func @vector_extract_element_i8(%row: index, %col: index) -> i8 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8>
+  %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8>
   return %el : i8
 }
@@ -855,9 +838,10 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i16
-func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+func.func @vector_extract_element_i16(%row: index, %col: index) -> i16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16>
   return %el : i16
 }
@@ -865,9 +849,10 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i64
-func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+func.func @vector_extract_element_i64(%row: index, %col: index) -> i64 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64>
   return %el : i64
 }
@@ -875,9 +860,10 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_i128
-func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+func.func @vector_extract_element_i128(%row: index, %col: index) -> i128 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128>
+  %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128>
   return %el : i128
 }
@@ -885,9 +871,10 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index,
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f16
-func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+func.func @vector_extract_element_f16(%row: index, %col: index) -> f16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16>
   return %el : f16
 }
@@ -895,9 +882,10 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_bf16
-func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+func.func @vector_extract_element_bf16(%row: index, %col: index) -> bf16 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16>
+  %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16>
   return %el : bf16
 }
@@ -905,9 +893,10 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index,
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f32
-func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+func.func @vector_extract_element_f32(%row: index, %col: index) -> f32 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32>
   return %el : f32
 }
@@ -915,9 +904,10 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %
 // -----
 
 // CHECK-LABEL: @vector_extract_element_f64
-func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 {
-  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
+  // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}) <{tile_id = 0 : i32}> : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64>
+  %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }

diff  --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index ae3b260da83a2c1..2491b2e2468cdaf 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -452,8 +452,7 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
 // CHECK: %[[C4:.*]] = arith.constant 4 : index
 // CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[SRC_1D:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK: %[[TILE:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK: %[[VSCALE:.*]] = vector.vscale
 // CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
 // CHECK: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
@@ -500,8 +499,7 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
 // CHECK-LABEL:   func.func @splat_vec2d_from_i32(
 // CHECK-SAME:      %[[SRC:.*]]: i32) {
 // CHECK:   %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
-// CHECK:   %[[TILE_ID:.*]] = arm_sme.get_tile_id : i32
-// CHECK:   arm_sme.cast_tile_to_vector %[[TILE_ID]] : i32 to vector<[4]x[4]xi32>
+// CHECK:   arm_sme.get_tile : vector<[4]x[4]xi32>
 // CHECK:   %[[VSCALE:.*]] = vector.vscale
 // CHECK:   %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
 // CHECK:   scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
index 18b95cf2fdf843c..81546de6a3466b7 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir
@@ -4,9 +4,9 @@
 // RUN:   -lower-vector-mask \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=entry -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
index f189fd97d66cde7..dcd780b23161da4 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir
@@ -2,11 +2,11 @@
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
-// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \

diff  --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
index 413cb06b0ae1894..db5b098770402c8 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@@ -2,11 +2,11 @@
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -canonicalize \
 // RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
 // RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
 // RUN:   -convert-arm-sme-to-llvm \
 // RUN:   -convert-vector-to-llvm=enable-arm-sve \
-// RUN:   -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
index 0c186cc373a3b32..936163d1cd30d99 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 442a70cacd66508..8c73c24d695cfb2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 74b51dcc9b4df3a..965337c60b9ffdd 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = test_outerproduct_no_accumulator_2x2xf64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles  \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm -o %t
+// DEFINE:   -test-lower-to-llvm -o %t
 // DEFINE: %{run} = %mcr_aarch64_cmd %t \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme-f64f64 \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
index 82f38b4dbfa9d1f..4ca61a089bdf52d 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
index 3b218aefcd415ff..14dca2d4d7082aa 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
index e2cbe735fa4ff06..2751c2d136485e2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index 6e33a421bf799a2..b45f24f6c8fddaf 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// RUN:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // RUN:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:  -march=aarch64 -mattr=+sve,+sme \
 // RUN:  -e entry -entry-point-result=i32 \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
index 961bb274d1e3352..5d2d0a73992f1e0 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir
@@ -1,9 +1,9 @@
 // DEFINE: %{entry_point} = za0_d_f64
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
 // DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 25ef1799e63adb1..af5fe236e5bd577 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,8 +1,7 @@
 // DEFINE: %{entry_point} = entry
 // DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
-// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
-// DEFINE:   -convert-arm-sme-to-llvm \
-// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-arm-sme-to-llvm -test-lower-to-llvm
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=i32 \

diff  --git a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
index b3202b26f8e1e3d..7c9976bed912734 100644
--- a/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme-invalid.mlir
@@ -4,10 +4,9 @@
 llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
                                                 %nxv4i1 : vector<[4]xi1>,
                                                 %nxv16i8 : vector<[16]xi8>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {predicate, vector} have same shape}}
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv16i8) :
-      (i32, i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -16,10 +15,9 @@ llvm.func @arm_sme_vector_to_tile_invalid_types(%tileslice : i32,
 llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv16i8 : vector<[16]xi8>
 ) -> vector<[3]xf32> {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {vector, predicate, res} have same shape}}
-  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tile, %tileslice) :
-      (vector<[16]xi8>, vector<[4]xi1>, i32, i32) -> vector<[3]xf32>
+  %res = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[16]xi8>, vector<[4]xi1>, i32) -> vector<[3]xf32>
   llvm.return %res : vector<[3]xf32>
 }
 
@@ -28,9 +26,8 @@ llvm.func @arm_sme_tile_slice_to_vector_invalid_shapes(
 llvm.func @arm_sme_tile_slice_to_vector_invalid_element_types(
   %tileslice : i32, %nxv4i1 : vector<[4]xi1>, %nxv4f32 : vector<[4]xf32>
 ) -> vector<[3]xi32> {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // expected-error @+1 {{failed to verify that all of {vector, res} have same element type}}
-  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice) :
-      (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}> :
+      (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   llvm.return %res : vector<[4]xi32>
 }

diff  --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir
index 767d89a75eec326..edc1f749130440c 100644
--- a/mlir/test/Target/LLVMIR/arm-sme.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sme.mlir
@@ -2,9 +2,8 @@
 
 // CHECK-LABEL: @arm_sme_zero
 llvm.func @arm_sme_zero() {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.zero(i32 0)
-  "arm_sme.intr.zero"(%c0) : (i32) -> ()
+  "arm_sme.intr.zero"() <{tile_mask = 0 : i32}> : () -> ()
   llvm.return
 }
 
@@ -18,19 +17,18 @@ llvm.func @arm_sme_fmopa(%nxv2f64 : vector<[2]xf64>,
                          %nxv2i1  : vector<[2]xi1>,
                          %nxv4i1  : vector<[4]xi1>,
                          %nxv8i1  : vector<[8]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64
-  "arm_sme.intr.mopa"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
-    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.mopa"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+    (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32
-  "arm_sme.intr.mopa"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
-    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.mopa"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+    (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8f16
-  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mopa.wide.nxv8bf16
-  "arm_sme.intr.mopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.mopa.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
   llvm.return
 }
 
@@ -41,31 +39,30 @@ llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>,
                          %nxv16i8 : vector<[16]xi8>,
                          %nxv8i1  : vector<[8]xi1>,
                          %nxv16i1 : vector<[16]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv8i16
-  "arm_sme.intr.smopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.smopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv8i16
-  "arm_sme.intr.umopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.umopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv8i16
-  "arm_sme.intr.sumopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.sumopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv8i16
-  "arm_sme.intr.usmopa.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.usmopa.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.smopa.wide.nxv16i8
-  "arm_sme.intr.smopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.smopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umopa.wide.nxv16i8
-  "arm_sme.intr.umopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.umopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumopa.wide.nxv16i8
-  "arm_sme.intr.sumopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.sumopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8
-  "arm_sme.intr.usmopa.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.usmopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -79,19 +76,18 @@ llvm.func @arm_sme_fmops(%nxv2f64 : vector<[2]xf64>,
                          %nxv2i1  : vector<[2]xi1>,
                          %nxv4i1  : vector<[4]xi1>,
                          %nxv8i1  : vector<[8]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.mops.nxv2f64
-  "arm_sme.intr.mops"(%c0, %nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) :
-    (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.mops"(%nxv2i1, %nxv2i1, %nxv2f64, %nxv2f64) <{tile_id = 0 : i32}> :
+    (vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.nxv4f32
-  "arm_sme.intr.mops"(%c0, %nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) :
-    (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.mops"(%nxv4i1, %nxv4i1, %nxv4f32, %nxv4f32) <{tile_id = 0 : i32}> :
+    (vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8f16
-  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8f16, %nxv8f16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.mops.wide.nxv8bf16
-  "arm_sme.intr.mops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.mops.wide"(%nxv8i1, %nxv8i1, %nxv8bf16, %nxv8bf16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> ()
   llvm.return
 }
 
@@ -102,31 +98,30 @@ llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>,
                          %nxv16i8 : vector<[16]xi8>,
                          %nxv8i1  : vector<[8]xi1>,
                          %nxv16i1 : vector<[16]xi1>) {
-  %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv8i16
-  "arm_sme.intr.smops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.smops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv8i16
-  "arm_sme.intr.umops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.umops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv8i16
-  "arm_sme.intr.sumops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.sumops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv8i16
-  "arm_sme.intr.usmops.wide"(%c0, %nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) :
-    (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.usmops.wide"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> :
+    (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.smops.wide.nxv16i8
-  "arm_sme.intr.smops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.smops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.umops.wide.nxv16i8
-  "arm_sme.intr.umops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.umops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.sumops.wide.nxv16i8
-  "arm_sme.intr.sumops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.sumops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8
-  "arm_sme.intr.usmops.wide"(%c0, %nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) :
-    (i32, vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.usmops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> :
+    (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> ()
   llvm.return
 }
 
@@ -141,35 +136,35 @@ llvm.func @arm_sme_load(%nxv1i1  : vector<[1]xi1>,
                         %ptr    : !llvm.ptr) {
   %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.ld1q.horiz
-  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1d.horiz
-  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1w.horiz
-  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1h.horiz
-  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1b.horiz
-  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1q.vert
-  "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1d.vert
-  "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1w.vert
-  "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1h.vert
-  "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.ld1b.vert
-  "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.ld1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   llvm.return
 }
 
@@ -184,35 +179,35 @@ llvm.func @arm_sme_store(%nxv1i1  : vector<[1]xi1>,
                          %ptr    : !llvm.ptr) {
   %c0 = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.st1q.horiz
-  "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1q.horiz"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1d.horiz
-  "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1d.horiz"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1w.horiz
-  "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1w.horiz"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1h.horiz
-  "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1h.horiz"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1b.horiz
-  "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1b.horiz"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1q.vert
-  "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0, %c0) :
-              (vector<[1]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1q.vert"(%nxv1i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[1]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1d.vert
-  "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0, %c0) :
-              (vector<[2]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1d.vert"(%nxv2i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[2]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1w.vert
-  "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0, %c0) :
-              (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1w.vert"(%nxv4i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[4]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1h.vert
-  "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0, %c0) :
-              (vector<[8]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1h.vert"(%nxv8i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[8]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.st1b.vert
-  "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0, %c0) :
-              (vector<[16]xi1>, !llvm.ptr, i32, i32) -> ()
+  "arm_sme.intr.st1b.vert"(%nxv16i1, %ptr, %c0) <{tile_id = 0 : i32}> :
+              (vector<[16]xi1>, !llvm.ptr, i32) -> ()
   // CHECK: call void @llvm.aarch64.sme.str
   "arm_sme.intr.str"(%c0, %ptr, %c0) : (i32, !llvm.ptr, i32) -> ()
   llvm.return
@@ -236,34 +231,33 @@ llvm.func @arm_sme_vector_to_tile_horiz(%tileslice : i32,
                                         %nxv8bf16 : vector<[8]xbf16>,
                                         %nxv4f32 : vector<[4]xf32>,
                                         %nxv2f64 : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv16i8
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
-      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8i16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4i32
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2i64
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv1i128
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
-      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+      (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8f16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv8bf16
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv4f32
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.horiz.nxv2f64
-  "arm_sme.intr.write.horiz"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.write.horiz"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
 
@@ -285,34 +279,33 @@ llvm.func @arm_sme_vector_to_tile_vert(%tileslice : i32,
                                        %nxv8bf16 : vector<[8]xbf16>,
                                        %nxv4f32 : vector<[4]xf32>,
                                        %nxv2f64 : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv16i8
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv16i1, %nxv16i8) :
-      (i32, i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv16i1, %nxv16i8) <{tile_id = 0 : i32}> :
+      (i32, vector<[16]xi1>, vector<[16]xi8>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8i16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8i16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8i16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xi16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4i32
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4i32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4i32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2i64
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2i64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2i64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xi64>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv1i128
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv1i1, %nxv1i128) :
-      (i32, i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv1i1, %nxv1i128) <{tile_id = 0 : i32}> :
+      (i32, vector<[1]xi1>, vector<[1]xi128>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8f16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8f16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8f16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv8bf16
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv8i1, %nxv8bf16) :
-      (i32, i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv8i1, %nxv8bf16) <{tile_id = 0 : i32}> :
+      (i32, vector<[8]xi1>, vector<[8]xbf16>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv4f32
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv4i1, %nxv4f32) :
-      (i32, i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv4i1, %nxv4f32) <{tile_id = 0 : i32}> :
+      (i32, vector<[4]xi1>, vector<[4]xf32>) -> ()
   // CHECK: call void @llvm.aarch64.sme.write.vert.nxv2f64
-  "arm_sme.intr.write.vert"(%tile, %tileslice, %nxv2i1, %nxv2f64) :
-      (i32, i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
+  "arm_sme.intr.write.vert"(%tileslice, %nxv2i1, %nxv2f64) <{tile_id = 0 : i32}> :
+      (i32, vector<[2]xi1>, vector<[2]xf64>) -> ()
   llvm.return
 }
 
@@ -334,34 +327,33 @@ llvm.func @arm_sme_tile_slice_to_vector_horiz(%tileslice : i32,
                                               %nxv8bf16  : vector<[8]xbf16>,
                                               %nxv4f32   : vector<[4]xf32>,
                                               %nxv2f64   : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.horiz.nxv16i8
-  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tile, %tileslice)
-    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  %res0 = "arm_sme.intr.read.horiz"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.horiz.nxv8i16
-  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  %res1 = "arm_sme.intr.read.horiz"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.horiz.nxv4i32
-  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res2 = "arm_sme.intr.read.horiz"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.horiz.nxv2i64
-  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  %res3 = "arm_sme.intr.read.horiz"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.horiz.nxv1i128
-  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tile, %tileslice)
-    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  %res4 = "arm_sme.intr.read.horiz"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.horiz.nxv8f16
-  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  %res5 = "arm_sme.intr.read.horiz"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.horiz.nxv8bf16
-  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  %res6 = "arm_sme.intr.read.horiz"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.horiz.nxv4f32
-  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  %res7 = "arm_sme.intr.read.horiz"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.horiz.nxv2f64
-  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  %res8 = "arm_sme.intr.read.horiz"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   llvm.return
 }
 
@@ -382,33 +374,32 @@ llvm.func @arm_sme_tile_slice_to_vector_vert(%tileslice : i32,
                                               %nxv8bf16 : vector<[8]xbf16>,
                                               %nxv4f32  : vector<[4]xf32>,
                                               %nxv2f64  : vector<[2]xf64>) {
-  %tile = llvm.mlir.constant(0 : index) : i32
   // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sme.read.vert.nxv16i8
-  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tile, %tileslice)
-    : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8>
+  %res0 = "arm_sme.intr.read.vert"(%nxv16i8, %nxv16i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[16]xi8>, vector<[16]xi1>, i32) -> vector<[16]xi8>
   // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sme.read.vert.nxv8i16
-  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16>
+  %res1 = "arm_sme.intr.read.vert"(%nxv8i16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xi16>, vector<[8]xi1>, i32) -> vector<[8]xi16>
   // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sme.read.vert.nxv4i32
-  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %res2 = "arm_sme.intr.read.vert"(%nxv4i32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xi32>, vector<[4]xi1>, i32) -> vector<[4]xi32>
   // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sme.read.vert.nxv2i64
-  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64>
+  %res3 = "arm_sme.intr.read.vert"(%nxv2i64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xi64>, vector<[2]xi1>, i32) -> vector<[2]xi64>
   // CHECK: call <vscale x 1 x i128> @llvm.aarch64.sme.read.vert.nxv1i128
-  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tile, %tileslice)
-    : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128>
+  %res4 = "arm_sme.intr.read.vert"(%nxv1i128, %nxv1i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[1]xi128>, vector<[1]xi1>, i32) -> vector<[1]xi128>
   // CHECK: call <vscale x 8 x half> @llvm.aarch64.sme.read.vert.nxv8f16
-  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16>
+  %res5 = "arm_sme.intr.read.vert"(%nxv8f16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xf16>, vector<[8]xi1>, i32) -> vector<[8]xf16>
   // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sme.read.vert.nxv8bf16
-  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tile, %tileslice)
-    : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16>
+  %res6 = "arm_sme.intr.read.vert"(%nxv8bf16, %nxv8i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[8]xbf16>, vector<[8]xi1>, i32) -> vector<[8]xbf16>
   // CHECK: call <vscale x 4 x float> @llvm.aarch64.sme.read.vert.nxv4f32
-  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tile, %tileslice)
-    : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+  %res7 = "arm_sme.intr.read.vert"(%nxv4f32, %nxv4i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[4]xf32>, vector<[4]xi1>, i32) -> vector<[4]xf32>
   // CHECK: call <vscale x 2 x double> @llvm.aarch64.sme.read.vert.nxv2f64
-  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tile, %tileslice)
-    : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64>
+  %res8 = "arm_sme.intr.read.vert"(%nxv2f64, %nxv2i1, %tileslice) <{tile_id = 0 : i32}>
+    : (vector<[2]xf64>, vector<[2]xi1>, i32) -> vector<[2]xf64>
   llvm.return
 }
