[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

Benjamin Maxwell llvmlistbot at llvm.org
Wed May 8 03:39:24 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/90448

From 4799a3aee89f9d396e16afc3980073a0e49496c6 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 26 Apr 2024 15:09:20 +0000
Subject: [PATCH 1/9] [mlir][ArmSME] Use liveness information in the tile
 allocator

This patch rewrites the ArmSME tile allocator to use liveness
information to make better tile allocation decisions and improve the
correctness of the ArmSME dialect. The algorithm used here is a linear
scan over live ranges, where live ranges are assigned to tiles as they
appear in the program (chronologically). Live ranges release their
assigned tile ID when the current program point is past their end.
This is a greedy algorithm, chosen mainly to keep the implementation
relatively straightforward and because it seems to be sufficient for
most kernels (e.g. matmuls) that use ArmSME. The general steps are
roughly those from https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf,
though a few simplifications and assumptions have been made for
our use case.
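
As a rough illustration of the greedy linear scan described above, here is
a minimal standalone sketch. It is not the code in TileAllocation.cpp:
`LiveRange` and `allocateTiles` are hypothetical names, live ranges are
assumed to be simple half-open intervals without holes, and a range that
finds no free tile is simply given an in-memory tile ID (16 and up,
mirroring `kInMemoryTileIdBase`) instead of going through any spill
heuristic.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical, simplified live range: a half-open interval [start, end) of
// program points. Real live ranges may contain holes; this sketch ignores that.
struct LiveRange {
  unsigned start;
  unsigned end;
  int tileId = -1; // Filled in by allocateTiles().
};

// Mirrors kInMemoryTileIdBase from the patch: tile IDs >= 16 denote tiles that
// live in memory rather than in ZA.
constexpr int kInMemoryTileIdBase = 16;

// Greedy linear scan: visit ranges in order of their start point, release the
// tiles of ranges that have already ended, then take the lowest free tile ID.
void allocateTiles(std::vector<LiveRange> &ranges, unsigned numTiles) {
  assert(numTiles <= static_cast<unsigned>(kInMemoryTileIdBase));

  std::vector<LiveRange *> order;
  for (LiveRange &range : ranges)
    order.push_back(&range);
  std::stable_sort(order.begin(), order.end(),
                   [](const LiveRange *a, const LiveRange *b) {
                     return a->start < b->start;
                   });

  std::vector<bool> inUse(numTiles, false);
  std::vector<LiveRange *> active; // Ranges currently holding a real tile.
  int nextInMemoryId = kInMemoryTileIdBase;

  for (LiveRange *range : order) {
    // Release tiles held by ranges that end at or before this start point.
    for (auto it = active.begin(); it != active.end();) {
      if ((*it)->end <= range->start) {
        inUse[(*it)->tileId] = false;
        it = active.erase(it);
      } else {
        ++it;
      }
    }

    // Assign the first free tile, or fall back to an in-memory tile ID.
    auto freeTile = std::find(inUse.begin(), inUse.end(), false);
    if (freeTile == inUse.end()) {
      range->tileId = nextInMemoryId++;
      continue;
    }
    *freeTile = true;
    range->tileId = static_cast<int>(freeTile - inUse.begin());
    active.push_back(range);
  }
}
```

Because the scan never revisits an earlier decision, the result depends
only on the order in which live ranges start, which is what keeps the
approach simple.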

Hopefully, the only changes needed for a user of the ArmSME dialect are
that:

- `-allocate-arm-sme-tiles` will no longer be a standalone pass
  - `-test-arm-sme-tile-allocation` is only for unit tests
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf`
  - SME tile allocation is now part of the LLVM conversion

By integrating this into the `ArmSME -> LLVM` conversion we can allow
high-level (value-based) ArmSME operations to be side-effect-free, as we
can guarantee nothing will rearrange ArmSME operations before we emit
intrinsics (which could invalidate the tile allocation).

The hope is for ArmSME operations to have no hidden state/side effects
and to allow easily lowering dialects such as `vector` and `arith` to SME,
without making assumptions about how the input IR looks, as the
semantics of the operations will be the same. That is, there are no (new)
side effects and the IR follows the rules of SSA (a value will never change).

The aim is correctness, so we have a base for working on optimizations.
---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h    |   4 +-
 mlir/include/mlir/Conversion/Passes.td        |   7 +-
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |   6 +-
 .../Dialect/ArmSME/IR/ArmSMEOpInterfaces.h    |  28 +
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 141 ++--
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |   3 -
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  17 +-
 .../Dialect/ArmSME/Transforms/Transforms.h    |   9 +
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  20 +
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  69 +-
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  29 +-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   6 +
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  46 ++
 .../ArmSME/Transforms/TileAllocation.cpp      | 644 +++++++++++++-----
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         |   3 +-
 .../ArmSMEToLLVM/tile-spills-and-fills.mlir   |  15 +-
 .../Conversion/ArmSMEToLLVM/unsupported.mlir  |   3 +-
 .../Dialect/ArmSME/basic-tile-allocation.mlir |   2 +-
 mlir/test/Dialect/ArmSME/canonicalize.mlir    |  10 +-
 mlir/test/Dialect/ArmSME/cse.mlir             |  30 -
 .../ArmSME/tile-allocation-invalid.mlir       |  12 +-
 .../ArmSME/tile-allocation-liveness.mlir      | 196 ++++--
 mlir/test/Dialect/ArmSME/tile-zero-masks.mlir |   2 +-
 .../Linalg/CPU/ArmSME/use-too-many-tiles.mlir |   7 +-
 .../ArmSME/Emulated/test-setArmSVLBits.mlir   |   3 +-
 mlir/test/lib/Dialect/ArmSME/CMakeLists.txt   |   1 +
 .../lib/Dialect/ArmSME/TestLowerToArmSME.cpp  |  19 +-
 27 files changed, 893 insertions(+), 439 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
 delete mode 100644 mlir/test/Dialect/ArmSME/cse.mlir

diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab49998..403f811a2569a 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
 void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab9..e6d678dc1b12b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
 // ArmSMEToLLVM
 //===----------------------------------------------------------------------===//
 
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
   let summary = "Lower the operations from the ArmSME dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
     "arm_sme::ArmSMEDialect",
     "LLVM::LLVMDialect"
   ];
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a7..dac54712c7f47 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 0000000000000..f31062d8c25ed
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,28 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSME in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+static constexpr unsigned kInMemoryTileIdBase = 16;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d2..9178655f010c9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
 
 def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   let description = [{
-    An interface for operations that use or allocate Arm SME tiles. These
-    operations need to be assigned a tile ID, an i32 attribute, which specifies
-    which virtual tile within the ZA storage to use. The number of tiles
-    available depends on the type of the tile. This is summarized below:
+    An interface for operations that use Arm SME tiles. These operations need to
+    be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+    within the ZA storage to use. The number of tiles available depends on the
+    type of the tile. This is summarized below:
 
     | Tile Vector Types                                                       | Possible Tile IDs   |
     |-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
     | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
     | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
     | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
-
-    Operations that allocate a new tile (such as arm_sme.get_tile), are used as
-    the roots for tile allocation, with all operations that (transitively)
-    depend on a root being assigned the same tile ID.
   }];
   let methods = [
     InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
         return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
       }]
     >,
-    InterfaceMethod<
-      [{
-        The type of tile this operation allocates. Returns none (std::nullopt)
-        if this operation does not allocate a tile.
-      }],
-      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
-      /*methodName=*/"getAllocatedTileType",
-      /*arguments=*/(ins),
-      /*methodBody=*/[{}],
-      /*defaultImpl=*/ [{
-        // This operation does not allocate a tile.
-        return std::nullopt;
-      }]
-    >,
     InterfaceMethod<
       "Returns the VectorType of the tile used by this operation.",
       /*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   ];
 
   let extraSharedClassDeclaration = [{
-    // A helper to create a new operation and propagate this operations tile ID.
-    template<typename T, typename... Args>
-    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
-      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
-      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
-        tileOp.setTileId($_op.getTileId());
-      return op;
-    }
-
-    // A helper to replace this operation and forward its tile ID (if present).
-    template<typename T, typename... Args>
-    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
-      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
-      rewriter.replaceOp($_op, newOp);
-      return newOp;
-    }
-
     bool isInMemoryTile() {
       auto tileId = getTileId();
       return tileId && tileId.getInt() >= kInMemoryTileIdBase;
     }
   }];
 
-  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+  let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
-  let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates an undefined value of SME virtual tile type";
   let description = [{
-    Allocates a new SME "virtual tile" within a function. The contents of the
-    tile returned from this operation are undefined.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are undefined.
 
     Example 1:
 
     ```mlir
-    // Allocate an 8-bit element "virtual tile"
+    // Create an 8-bit element "virtual tile" value:
     %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
     Example 2:
 
     ```mlir
-    // Allocate two 16-bit element "virtual tiles"
+    // Create two 16-bit element "virtual tiles" values:
     %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
     Example 3:
     ```mlir
-    // Allocate an 128-bit element "virtual tile"
+    // Create an 128-bit element "virtual tile" value:
     %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
   }];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
     VectorType getTileType() {
       return ::llvm::cast<VectorType>(getTile().getType());
     }
-
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getTileType());
-    }
-  }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
-  let summary = "SME tile placeholder";
-  let description = [{
-    A placeholder to preserve dataflow while lowering to SME intrinsics (which
-    do not take or return SME virtual tile values). This operation is intended
-    to be DCE'd once all ArmSME operations have been lowered.
-
-    This operation is not intended to be used outside of the ArmSME -> LLVM
-    conversion.
   }];
-  let results = (outs SMETile:$tile);
-  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
-  let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates a zero-initialized value of SME virtual tile type";
   let results = (outs SMETile:$res);
   let description = [{
-    Initialise ZA with 0. This operation is convenient wrapper for the SME
-    `zero` intrinsic and instruction.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are zero-initialized.
 
     Example 1: Zero an 8-bit element ZA tile.
 
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+  Pure,
+  ArmSMETileOpInterface,
+  AllTypesMatch<["tile", "result"]>
+]> {
+  let summary = "Copies an SME tile value";
+  let arguments = (ins SMETile:$tile);
+  let results = (outs SMETile:$result);
+  let description = [{
+    Copies an SME "virtual tile" value to a new SSA value. This operation is
+    primarily intended to be used to normalize the IR prior to tile allocation.
+
+    Example:
+
+    ```mlir
+    %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+  let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
 def TileLoadOp : ArmSME_Op<"tile_load", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
 }
 
 def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
                                list<Type> allowedResultVectorTypes,
                                int numOuterProducts> :
   ArmSME_Op<mnemonic, [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874e..156744ba57e7b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
 
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
 std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e8926..b9d74fec6756e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,16 +124,21 @@ def EnableArmStreaming
   let dependentDialects = ["func::FuncDialect"];
 }
 
-def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
-  let summary = "Allocate SME tiles";
+def TestTileAllocation
+    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+  let summary = "Tests SME tile allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
-    that implement the `ArmSMETileOpInterface`. An error will be emitted when
-    there's no tiles left.
+    that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+    to be used for testing, tile allocation is done as part of the ArmSME to
+    LLVM conversion.
   }];
-  let constructor = "mlir::arm_sme::createTileAllocationPass()";
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
   let dependentDialects = ["func::FuncDialect"];
 }
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e6999..a25b844f01eaa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_H
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
 namespace mlir {
 
 class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
 class RewritePatternSet;
 
 namespace arm_sme {
+
 void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+                               bool dumpRanges = false);
+
 } // namespace arm_sme
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f..9ea1c5a5d63fe 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include <optional>
 
 namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+inline bool isValidSMETileVectorType(Type type) {
+  auto vType = dyn_cast<VectorType>(type);
+  return vType && isValidSMETileVectorType(vType);
+}
+
 /// Returns the type of SME tile this vector type corresponds to, or none if the
 /// vector type does not fit within an SME tile.
 std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,19 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
 /// Creates a vector type for the SME tile of `elementType`.
 VectorType getSMETileTypeForElement(Type elementType);
 
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function);
+
+/// Returns true if `tileOp` can be cloned to resolve conflicts.
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `tileOp` produces a tile result.
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns the tile `OpOperand` for this `tileOp` (or null).
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 1ba1b88fc1234..488f747af050f 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -245,6 +246,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
     if (!tileOp.isInMemoryTile())
       return failure();
 
+    tileOp->emitWarning(
+        "failed to allocate SME virtual tile to operation, all tile "
+        "operations will go through memory, expect degraded performance");
+
     // Step 1. Create an alloca for the tile at the top of the function (if one
     // does not already exist).
     auto loc = tileOp.getLoc();
@@ -391,20 +396,6 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
   (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
 }
 
-struct GetTileConversion
-    : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
-                                          RequiresSpillsAndFills::No> {
-  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
-        getTile, getTile.getTileType());
-    return success();
-  }
-};
-
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
@@ -436,7 +427,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
     // The base mask is just the mask to zero the first tile (of a size).
     // These masks are derived from:
     // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
-    arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
+    arm_sme::ArmSMETileType tileType =
+        *arm_sme::getSMETileType(zero.getTileType());
     auto baseMaskForSize = [&] {
       switch (tileType) {
       case arm_sme::ArmSMETileType::ZAB:
@@ -488,8 +480,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
         loc, rewriter.getI32IntegerAttr(zeroMask));
 
     // Create a placeholder op to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
-        zero, zero.getVectorType());
+    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
 
     return success();
   }
@@ -746,10 +737,12 @@ struct OuterProductOpConversion
     auto loc = outerProductOp.getLoc();
 
     Value acc = outerProductOp.getAcc();
-    if (!acc)
+    if (!acc) {
       // Initalize accumulator with zero.
-      acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, loc, resultVectorType);
+      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+      zero.setTileId(tileId);
+      acc = zero;
+    }
 
     Value lhsMask = outerProductOp.getLhsMask();
     Value rhsMask = outerProductOp.getRhsMask();
@@ -791,25 +784,27 @@ struct OuterProductWideningOpConversion
     if (!tileId)
       return failure();
 
+    auto loc = op.getLoc();
     Value acc = op.getAcc();
-    if (!acc)
+    if (!acc) {
       // Initalize accumulator with zero.
-      acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, op.getLoc(), op.getResultType());
+      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+      zero.setTileId(tileId);
+      acc = zero;
+    }
 
     Value lhsMask = op.getLhsMask();
     Value rhsMask = op.getRhsMask();
     if (!lhsMask || !rhsMask) {
       auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
       Value allActiveMask = rewriter.create<arith::ConstantOp>(
-          op.getLoc(), DenseElementsAttr::get(predTy, true));
+          loc, DenseElementsAttr::get(predTy, true));
       lhsMask = allActiveMask;
       rhsMask = allActiveMask;
     }
 
-    rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
-                                                rhsMask, adaptor.getLhs(),
-                                                adaptor.getRhs());
+    rewriter.create<OuterProductWideningIntrOp>(
+        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
 
     // The outerproduct intrinsics have no result, replace
     // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -865,15 +860,22 @@ namespace {
 
 struct ConvertArmSMEToLLVMPass
     : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+    this->dumpTileLiveRanges = dumpTileLiveRanges;
+  }
   void runOnOperation() override {
+    auto function = getOperation();
+
+    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+      return signalPassFailure();
+
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
     configureArmSMEToLLVMConversionLegality(target);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialConversion(function, target, std::move(patterns))))
       signalPassFailure();
   }
 };
@@ -883,7 +885,7 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
+      arm_sme::GetTileOp, arm_sme::CopyTileOp, arm_sme::aarch64_sme_zero,
       arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
       arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
       arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
@@ -955,9 +957,10 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                        arm_sme::aarch64_sme_usmopa_wide>,
       OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                        arm_sme::aarch64_sme_usmops_wide>,
-      ZeroOpConversion, GetTileConversion>(patterns, converter);
+      ZeroOpConversion>(patterns, converter);
 }
 
-std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
-  return std::make_unique<ConvertArmSMEToLLVMPass>();
+std::unique_ptr<Pass>
+mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
 }
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 16b61c282749c..9f55932c33af6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -196,12 +196,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
       // Initialize tile with zero to satisfy padding. Inactive cols will be
       // zeroed anyway since the loads use zeroing predication. For inactive
       // rows however, no load will occur so these need to be zeroed.
-      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, loc, tileType);
+      initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
     } else {
-      // Allocate a new SME tile.
-      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-          rewriter, loc, tileType);
+      initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
     }
 
     // Create a loop to load the active tile slices from memory.
@@ -212,10 +209,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
             Value currentTile) -> Value {
           // Create 'arm_sme.load_tile_slice' to load tile slice from memory
           // into tile.
-          return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-              rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
-              currentTile, memrefIndices, tileSliceIndex,
-              tileLoadOp.getLayout());
+          return rewriter.create<arm_sme::LoadTileSliceOp>(
+              loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
+              memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
         });
 
     if (failed(forOp))
@@ -292,9 +288,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), numCols);
 
-    // Allocate a new SME tile.
-    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-        rewriter, loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -339,10 +333,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    auto moveSlice =
-        tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
-            rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
-            tileSliceIndex, tileLoadOp.getLayout());
+    auto moveSlice = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
+        tileLoadOp.getLayout());
     rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -386,8 +379,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
         tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
         tileStoreOp.getMask(),
         [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
-          tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
-              rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+          rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+              tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
               predicate, tileStoreOp.getBase(), memrefIndices,
               tileStoreOp.getLayout());
         });
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 29fa9085a0a96..cb3a665844872 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -20,6 +20,12 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
+namespace mlir::arm_sme::detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *op) {
+  return verifyOperationHasValidTileId(op);
+}
+} // namespace mlir::arm_sme::detail
+
 //===----------------------------------------------------------------------===//
 // Tablegen Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6a9e022182226..00d764bf5caff 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -116,4 +116,50 @@ VectorType getSMETileTypeForElement(Type elementType) {
   return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
 }
 
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function) {
+  SmallVector<Operation *> worklist;
+  function->walk([&](Operation *op) {
+    auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
+    if (armSMEOp && isOpTriviallyDead(armSMEOp))
+      worklist.push_back(armSMEOp);
+  });
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    for (Value value : op->getOperands()) {
+      if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
+        worklist.push_back(armSMEOp);
+    }
+    rewriter.eraseOp(op);
+  }
+}
+
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
+  return tileOp && tileOp->getNumResults() == 1 &&
+         tileOp->getNumOperands() == 0 && isPure(tileOp);
+}
+
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
+  for (Value result : tileOp->getResults()) {
+    if (arm_sme::isValidSMETileVectorType(result.getType()))
+      return true;
+  }
+  return false;
+}
+
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
+  auto isTileOperandType = [](OpOperand &operand) {
+    return arm_sme::isValidSMETileVectorType(operand.get().getType());
+  };
+  OpOperand *tileOperand =
+      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
+  assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
+         "expected at most one tile operand");
+  if (tileOperand == tileOp->getOpOperands().end())
+    return nullptr;
+  return tileOperand;
+}
+
 } // namespace mlir::arm_sme
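A note on `eraseTriviallyDeadTileOps` above: it uses a worklist so that erasing one dead tile op can expose the ops defining its operands as newly dead. The following is only a minimal sketch of that pattern, not the patch's code. `ToyOp`, `isTriviallyDead`, and `eraseTriviallyDeadOps` are hypothetical stand-ins (no MLIR types are used), and the explicit `numUsers` counter is an assumption made to keep the example self-contained; in MLIR the use lists are updated by `rewriter.eraseOp`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an MLIR operation: it records which ops define
// its operands and how many remaining uses its result has.
struct ToyOp {
  std::vector<std::size_t> operands; // indices of the defining ops
  int numUsers = 0;                  // remaining uses of this op's result
  bool hasSideEffects = false;
  bool erased = false;
};

// Plays the role of isOpTriviallyDead(): unused and free of side effects.
bool isTriviallyDead(const ToyOp &op) {
  return !op.erased && op.numUsers == 0 && !op.hasSideEffects;
}

// Erases dead ops, then revisits the definitions of their operands, which may
// have just lost their last user. Returns the number of ops erased.
std::size_t eraseTriviallyDeadOps(std::vector<ToyOp> &ops) {
  std::vector<std::size_t> worklist;
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (isTriviallyDead(ops[i]))
      worklist.push_back(i);

  std::size_t numErased = 0;
  while (!worklist.empty()) {
    std::size_t index = worklist.back();
    worklist.pop_back();
    // An op can be queued more than once, or queued while it still has other
    // users, so re-check before erasing.
    if (!isTriviallyDead(ops[index]))
      continue;
    for (std::size_t def : ops[index].operands) {
      --ops[def].numUsers;
      worklist.push_back(def);
    }
    ops[index].erased = true;
    ++numErased;
  }
  return numErased;
}
```

The re-check after popping mirrors the second `isOpTriviallyDead` test in the patch: an operand's defining op is queued unconditionally, and only erased once it is found to have no remaining users.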
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4acb2a8fb7b53..e3cf52078a24e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass allocates SME tiles at the 'func.func' op level for ArmSME
+// This transform allocates SME tiles at the 'func.func' op level for ArmSME
 // operations. It does this using a 16-bit tile mask that has a bit for each
 // 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
 //
@@ -32,39 +32,31 @@
 //   ZA6.D   ZA6.Q, ZA14.Q
 //   ZA7.D   ZA7.Q, ZA15.Q
 //
-// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initalized during the first tile allocation within a function and
-// updated on each subsequent allocation.
-//
 // [1] https://developer.arm.com/documentation/ddi0616/aa
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/Liveness.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/IntervalMap.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
 
-#define DEBUG_TYPE "allocate-arm-sme-tiles"
-
-namespace mlir {
-namespace arm_sme {
-#define GEN_PASS_DEF_TILEALLOCATION
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_TESTTILEALLOCATION
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
 
 using namespace mlir;
 using namespace mlir::arm_sme;
 
 namespace {
 
-static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
-static constexpr StringLiteral
-    kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
-
 enum class TileMask : unsigned {
   // clang-format off
   kZA0B  = 0xffff, // 1111 1111 1111 1111
@@ -137,172 +129,510 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
-                                          TileMask &tilesInUse) {
-  auto masks = getMasks(tileType);
-  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
-    if ((tilesInUse & tileMask) == TileMask::kNone) {
-      tilesInUse |= tileMask;
-      return tileId;
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
     }
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) != TileMask::kNone &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
+  }
+
+  /// Allocates an in-memory tile ID.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
   }
-  return failure();
-}
 
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
-                             SetVector<Operation *> &dependantOps) {
-  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
-    for (auto [idx, value] : llvm::enumerate(inputValues)) {
-      if (value == rootValue)
-        findDependantOps(exitValues[idx], dependantOps);
+private:
+  TileMask tilesInUse = TileMask::kNone;
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+// Add new intermediate blocks for the true and false destinations of a
+// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
+// branches.
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
     }
+  });
+
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
   };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    condBranch.getFalseDestOperandsMutable().clear();
+    condBranch.getTrueDestOperandsMutable().clear();
+    condBranch.setSuccessor(newTrueBranch, 0);
+    condBranch.setSuccessor(newFalseBranch, 1);
+  }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+void insertCopiesAtBranches(IRRewriter &rewriter,
+                            FunctionOpInterface function) {
+  splitCondBranches(rewriter, function);
+  for (Block &block : function.getBlocks()) {
+    Operation *terminator = block.getTerminator();
+    if (!isa<cf::BranchOp>(terminator))
       continue;
-    dependantOps.insert(user);
-    TypeSwitch<Operation *>(user)
-        .Case<cf::BranchOp>([&](auto branchOp) {
-          // (CF) Follow branch.
-          traverseCorrespondingValues(branchOp.getDestOperands(),
-                                      branchOp.getDest()->getArguments());
-        })
-        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
-          // (CF) Follow true branch.
-          traverseCorrespondingValues(
-              condBranchOp.getTrueOperands(),
-              condBranchOp.getTrueDest()->getArguments());
-          // (CF) Follow false branch.
-          traverseCorrespondingValues(
-              condBranchOp.getFalseOperands(),
-              condBranchOp.getFalseDest()->getArguments());
-        })
-        .Case<LoopLikeOpInterface>([&](auto loopOp) {
-          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
-          traverseCorrespondingValues(loopOp.getInits(),
-                                      loopOp.getRegionIterArgs());
-        })
-        .Case<scf::YieldOp>([&](auto yieldOp) {
-          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
-          auto parent = user->getParentOp();
-          traverseCorrespondingValues(user->getOperands(),
-                                      parent->getResults());
+    rewriter.setInsertionPoint(terminator);
+    for (OpOperand &operand : terminator->getOpOperands()) {
+      if (isValidSMETileVectorType(operand.get().getType())) {
+        auto copy =
+            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+        operand.assign(copy);
+      }
+    }
+  }
+}
+
+/// A range where a tile value is live. The range may contain holes.
+struct LiveRange {
+  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange`. Aborts if the ranges overlap.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  unsigned start() const { return ranges->start(); }
+  unsigned end() const { return ranges->stop(); }
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  std::unique_ptr<RangeSet> ranges;
+  SetVector<Value> values;
+  std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  unsigned index = 0;
+  SetVector<Block *> blocks =
+      getTopologicallySortedBlocks(function.getFunctionBody());
+  DenseMap<Operation *, unsigned> operationToIndexMap;
+  for (Block *block : blocks) {
+    index++; // We want block args to have their own number.
+    for (Operation &op : block->getOperations()) {
+      // This is only correct if all ArmSME ops have been converted to CF.
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        if (&op != nestedOp.getOperation()) {
+          assert(false &&
+                 "ArmSME tile allocation does not support nested regions");
+        }
+      });
+#endif
+      operationToIndexMap.try_emplace(&op, index++);
+    }
+  }
+  return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  DenseMap<Value, LiveRange> liveRanges;
+  auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
+                              LivenessBlockInfo const &livenessInfo,
+                              bool liveAtBlockEntry = false) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    unsigned start =
+        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+    unsigned end = operationToIndexMap.at(lastUseInBlock);
+    it->second.insert(value, start, end);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
+    // Handle block arguments:
+    for (Value argument : block.getArguments())
+      updateLiveRanges(argument, &block.front(), *livenessInfo,
+                       /*liveAtBlockEntry=*/true);
+    // Handle live-ins:
+    for (Value liveIn : livenessInfo->in())
+      updateLiveRanges(liveIn, &block.front(), *livenessInfo,
+                       /*liveAtBlockEntry=*/true);
+    // Handle new definitions:
+    for (Operation &op : block) {
+      for (Value result : op.getResults())
+        updateLiveRanges(result, &op, *livenessInfo);
+    }
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+                                        function_ref<void(Value)> callback) {
+  Block *block = blockArg.getOwner();
+  unsigned argNumber = blockArg.getArgNumber();
+  for (Block *pred : block->getPredecessors()) {
+    TypeSwitch<Operation *>(pred->getTerminator())
+        .Case<cf::BranchOp>([&](auto branch) {
+          Value predecessorOperand = branch.getDestOperands()[argNumber];
+          callback(predecessorOperand);
         })
-        .Default([&](auto) {
-          // Otherwise, assume users of _any_ result are dependant.
-          for (Value result : user->getResults())
-            findDependantOps(result, dependantOps);
+        .Case<cf::CondBranchOp>([&](auto condBranch) {
+          if (condBranch.getFalseDest() == block) {
+            Value predecessorOperand =
+                condBranch.getFalseDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
+          if (condBranch.getTrueDest() == block) {
+            Value predecessorOperand =
+                condBranch.getTrueDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
         });
   }
 }
-struct AssignTileIDsPattern
-    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
-                                PatternRewriter &rewriter) const override {
-    if (tileOp.getTileId())
-      return failure();
-
-    auto func = tileOp->getParentOfType<FunctionOpInterface>();
-    auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
-      if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
-              func->getDiscardableAttr(name)))
-        return unsigned(attr.getInt());
-      return defaultVal;
+
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+  DenseMap<Value, LiveRange *> liveRanges;
+  for (auto &[value, liveRange] : initialLiveRanges) {
+    liveRanges.insert({value, &liveRange});
+  }
+
+  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+    LiveRange *aLiveRange = liveRanges.at(a);
+    LiveRange *bLiveRange = liveRanges.at(b);
+    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+      aLiveRange->unionWith(*bLiveRange);
+      for (Value value : bLiveRange->values)
+        liveRanges[value] = aLiveRange;
+    }
+  };
+
+  // Merge the live ranges of new definitions with their tile operands.
+  auto unifyDefinitionsWithOperands = [&](Value value) {
+    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+    if (!armSMEOp)
+      return;
+    for (auto operand : armSMEOp->getOperands()) {
+      if (isValidSMETileVectorType(operand.getType()))
+        mergeValuesIfNonOverlapping(value, operand);
+    }
+  };
+
+  // Merge the live ranges of block arguments with their predecessors.
+  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+    auto blockArg = dyn_cast<BlockArgument>(value);
+    if (!blockArg)
+      return;
+    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+    });
+  };
+
+  auto applyRule = [&](auto rule) {
+    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+  };
+
+  // Unify as many live ranges as we can. This prevents unnecessary moves.
+  applyRule(unifyBlockArgumentsWithPredecessors);
+  applyRule(unifyDefinitionsWithOperands);
+
+  // Remove duplicate live range entries.
+  SetVector<LiveRange *> uniqueLiveRanges;
+  for (auto [_, liveRange] : liveRanges) {
+    if (!liveRange->empty())
+      uniqueLiveRanges.insert(liveRange);
+  }
+
+  // Sort the new live ranges by starting point (ready for tile allocation).
+  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+            [](LiveRange *a, LiveRange *b) { return *a < *b; });
+  return std::move(coalescedLiveRanges);
+}
+
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+  TileAllocator tileAllocator;
+  SetVector<LiveRange *> allocatedRanges;
+
+  auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
+    unsigned memoryTileId = tileAllocator.allocateInMemoryTileId();
+    auto spillActiveRange = [&](LiveRange *range) {
+      unsigned tileId = *range->tileId;
+      range->tileId = memoryTileId;
+      allocatedRanges.remove(range);
+      return tileId;
     };
-    auto setDiscardableIntAttr = [&](StringRef name, auto value) {
-      rewriter.modifyOpInPlace(tileOp, [&] {
-        func->setDiscardableAttr(name,
-                                 rewriter.getI32IntegerAttr((unsigned)value));
-      });
+
+    auto isTrivialSpill = [](LiveRange *allocatedRange) {
+      return allocatedRange->values.size() == 1 &&
+             isTriviallyCloneableTileOp(
+                 allocatedRange->values[0]
+                     .getDefiningOp<ArmSMETileOpInterface>());
     };
 
-    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
-    if (!tileType)
-      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
-    TileMask tilesInUse =
-        static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
-    auto tileId = allocateTileId(*tileType, tilesInUse);
-    bool tileIsInMemory = failed(tileId);
-    if (tileIsInMemory) {
-      // If we could not find a real tile ID, use an in-memory tile ID (ID >=
-      // 16). A later pass will insert the necessary spills and reloads.
-      tileId =
-          getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
-      tileOp->emitWarning(
-          "failed to allocate SME virtual tile to operation, all tile "
-          "operations will go through memory, expect degraded performance");
-    }
+    // Heuristic: Spill trivially copyable operations (usually free).
+    if (isTrivialSpill(newRange))
+      return memoryTileId;
+    auto trivialSpill = llvm::find_if(allocatedRanges, isTrivialSpill);
+    if (trivialSpill != allocatedRanges.end())
+      return spillActiveRange(*trivialSpill);
+
+    // Heuristic: Spill the live range that ends last.
+    LiveRange *lastActiveLiveRange = *std::max_element(
+        allocatedRanges.begin(), allocatedRanges.end(),
+        [](LiveRange *a, LiveRange *b) { return a->end() < b->end(); });
+    if (lastActiveLiveRange->end() >= newRange->end())
+      return spillActiveRange(lastActiveLiveRange);
+
+    return memoryTileId;
+  };
 
-    // Set all operations dependent on `tileOp` to use the same tile ID.
-    // This is a naive tile allocation scheme, but works for common cases. For
-    // example, as this only allocates tile IDs to existing ops, it can't solve
-    // cases like this (%tileA and %tileB come from different root operations):
-    //
-    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
-    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
-    // } else {
-    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
-    // }
-    //
-    // This case would require allocating a new tile for the result of the
-    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
-    // on the %some_cond).
-    // Find all the ops that (transitively) depend on this tile.
-    SetVector<Operation *> dependantOps;
-    findDependantOps(tileOp->getResult(0), dependantOps);
-    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        auto currentTileId = dependantTileOp.getTileId();
-        if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
-          return dependantTileOp.emitOpError(
-              "already assigned different SME virtual tile!");
+  for (LiveRange *newRange : liveRanges) {
+    // Release tiles from live ranges that have ended.
+    allocatedRanges.remove_if([&](LiveRange *allocatedRange) {
+      if (allocatedRange->end() <= newRange->start()) {
+        tileAllocator.releaseTileId(allocatedRange->getTileType(),
+                                    *allocatedRange->tileId);
+        return true;
       }
-    }
+      return false;
+    });
 
-    // Rewrite IR.
-    if (!tileIsInMemory)
-      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+    // Allocate a tile ID to `newRange`.
+    auto tileId = tileAllocator.allocateTileId(newRange->getTileType());
+    if (succeeded(tileId))
+      newRange->tileId = *tileId;
     else
-      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
-    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        rewriter.modifyOpInPlace(
-            dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
+      newRange->tileId = chooseSpillUsingHeuristics(newRange);
+
+    // Insert the live range into the allocated ranges.
+    if (newRange->tileId < kInMemoryTileIdBase)
+      allocatedRanges.insert(newRange);
+  }
+}
+
+/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
+LogicalResult assignTileIdsAndResolveTrivialConflicts(
+    IRRewriter &rewriter, FunctionOpInterface function,
+    ArrayRef<LiveRange *> allocatedLiveRanges) {
+  for (LiveRange const *liveRange : allocatedLiveRanges) {
+    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
+    auto isAllocatedToSameTile = [&](Value value) {
+      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+          tileOp && tileOp.getTileId() == tileIdAttr)
+        return true;
+      return liveRange->values.contains(value);
+    };
+    for (Value value : liveRange->values) {
+      for (Operation *user : value.getUsers()) {
+        if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+          // Ensure ArmSME ops that don't produce a value still get a tile ID.
+          if (!hasTileResult(tileOp))
+            tileOp.setTileId(tileIdAttr);
+        }
+      }
+      auto copyOp = value.getDefiningOp<CopyTileOp>();
+      if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
+        // Fold redundant copies.
+        rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
+      } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
+        tileOp.setTileId(tileIdAttr);
+        // Rectify operand tile IDs with result tile IDs.
+        OpOperand *tileOperand = getTileOpOperand(tileOp);
+        if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
+          continue;
+        auto operandTileOp =
+            tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+        if (!isTriviallyCloneableTileOp(operandTileOp))
+          return tileOp.emitOpError("failed to rectify tile operand with tile "
+                                    "result (move required)");
+        // Cloning prevents a move/spill (though may require recomputation).
+        rewriter.setInsertionPoint(tileOp);
+        auto clonedOp = operandTileOp.clone();
+        clonedOp.setTileId(tileOp.getTileId());
+        rewriter.insert(clonedOp);
+        if (copyOp)
+          rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
+        else
+          tileOperand->assign(clonedOp->getResult(0));
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        // Validate block arguments.
+        bool tileMismatch = false;
+        forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+          if (tileMismatch)
+            return;
+          if (!isAllocatedToSameTile(predecessorTile)) {
+            blockArg.getOwner()->getParentOp()->emitOpError(
+                "block argument not allocated to the same tile as "
+                "predecessors");
+            tileMismatch = true;
+          }
+        });
+        if (tileMismatch)
+          return failure();
       }
     }
+  }
+  return success();
+}
 
-    return success();
+/// Prints live ranges alongside operation names for debugging.
+void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                    ArrayRef<LiveRange const *> liveRanges,
+                    FunctionOpInterface function) {
+  llvm::errs() << "SME Tile Liveness: @" << function.getName()
+               << "\nKey:\nS - Start\nE - End\n| - Live\n";
+  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
+    llvm::errs() << "^bb" << blockIdx << ":\n";
+    for (Operation &op : block.getOperations()) {
+      unsigned operationIndex = operationToIndexMap.at(&op);
+      for (LiveRange const *range : liveRanges) {
+        char liveness = ' ';
+        for (auto it = range->ranges->begin(); it != range->ranges->end();
+             ++it) {
+          if (it.start() == operationIndex)
+            liveness = (liveness == 'E' ? '|' : 'S');
+          else if (it.stop() == operationIndex)
+            liveness = (liveness == 'S' ? '|' : 'E');
+          else if (operationIndex >= it.start() && operationIndex < it.stop())
+            liveness = '|';
+        }
+        llvm::errs() << liveness;
+      }
+      llvm::errs() << ' ' << op.getName() << '\n';
+    }
   }
-};
+  llvm::errs() << "==========\n";
+}
 
-struct TileAllocationPass
-    : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+struct TestTileAllocationPass
+    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
+  using TestTileAllocationBase::TestTileAllocationBase;
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<AssignTileIDsPattern>(patterns.getContext());
-    GreedyRewriteConfig config;
-    // Setting useTopDownTraversal ensures tiles are allocated in program
-    // order.
-    config.useTopDownTraversal = true;
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(patterns), config))) {
+    if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
       signalPassFailure();
-    }
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
-  return std::make_unique<TileAllocationPass>();
+LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
+                                              bool dumpRanges) {
+  LiveRange::Allocator liveRangeAllocator;
+  IRRewriter rewriter(function.getContext());
+
+  // 1. Insert copy operations at branch operations.
+  insertCopiesAtBranches(rewriter, function);
+
+  // 2. Gather live ranges for each ArmSME tile within the function.
+  Liveness liveness(function);
+  auto operationToIndexMap = generateOperationNumbering(function);
+  auto initialLiveRanges = gatherTileLiveRanges(
+      operationToIndexMap, liveRangeAllocator, liveness, function);
+  if (initialLiveRanges.empty())
+    return success();
+
+  if (dumpRanges) {
+    // Wrangle initial live ranges into a form suitable for printing.
+    auto nonEmpty = llvm::make_filter_range(
+        llvm::make_second_range(initialLiveRanges),
+        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
+    auto initialRanges = llvm::to_vector(llvm::map_range(
+        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
+    std::sort(initialRanges.begin(), initialRanges.end(),
+              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
+    llvm::errs() << "\n========== Initial Live Ranges:\n";
+    dumpLiveRanges(operationToIndexMap, initialRanges, function);
+  }
+
+  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
+  // for tile allocation. E.g. Unify the result of an operation with its
+  // operands.
+  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
+
+  if (dumpRanges) {
+    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
+    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
+  }
+
+  // 4. Allocate tile IDs to live ranges.
+  allocateTilesToLiveRanges(coalescedLiveRanges);
+
+  // 5. Assign the tile IDs back to the ArmSME operations.
+  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
+                                                     coalescedLiveRanges))) {
+    return failure();
+  }
+
+  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+  // users). This prevents the LLVM conversion needlessly inserting spills.
+  eraseTriviallyDeadTileOps(rewriter, function);
+  return success();
 }
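
The marker selection in dumpLiveRanges above can be tried outside of MLIR. The following is a minimal standalone sketch, not code from the patch: `Intervals`, `livenessMarker` and `livenessRow` are hypothetical names, and a plain vector of (start, stop) pairs stands in for the interval container behind `range->ranges`. Only the S/E/| selection is meant to mirror the patch.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for one live range: a list of (start, stop) pairs
// over operation indices. The patch iterates an interval container instead.
using Intervals = std::vector<std::pair<unsigned, unsigned>>;

// Picks the marker printed for one live range at one operation index,
// following the same branch order as the loop in dumpLiveRanges.
char livenessMarker(const Intervals &intervals, unsigned operationIndex) {
  char liveness = ' ';
  for (const auto &[start, stop] : intervals) {
    assert(start <= stop && "interval must be well-formed");
    if (start == operationIndex)
      // A start that coincides with a previous end means the range stays live.
      liveness = (liveness == 'E' ? '|' : 'S');
    else if (stop == operationIndex)
      // An end that coincides with a previous start also stays live.
      liveness = (liveness == 'S' ? '|' : 'E');
    else if (operationIndex >= start && operationIndex < stop)
      liveness = '|';
  }
  return liveness;
}

// Builds the column prefix for one operation: one marker per live range.
std::string livenessRow(const std::vector<Intervals> &ranges,
                        unsigned operationIndex) {
  std::string row;
  for (const Intervals &intervals : ranges)
    row += livenessMarker(intervals, operationIndex);
  return row;
}
```

With two ranges covering operations 0 to 4 and 1 to 3, the rows for indices 0 through 4 have the same shape as the @constant_with_multiple_users CHECK-LIVE-RANGE lines in tile-allocation-liveness.mlir: "S ", "|S", "||", "|E", "E ".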
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index f48046a8d7995..2922940460219 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
-
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file -verify-diagnostics | FileCheck %s
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 7a9e6b4215754..803f15b4b7d05 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \
 // RUN: FileCheck %s  --check-prefix=AFTER-TILE-ALLOC
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \
 // RUN:   -split-input-file -verify-diagnostics | \
 // RUN: FileCheck %s  --check-prefix=AFTER-LLVM-LOWERING
 
@@ -56,6 +56,9 @@ func.func @use_too_many_tiles() {
   %1 = arm_sme.zero : vector<[4]x[4]xi32>
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %2 = arm_sme.zero : vector<[8]x[8]xi16>
+  "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 // AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
@@ -109,18 +112,16 @@ func.func @use_too_many_tiles() {
 /// Note: In this example an entire tile swap is inserted before/after the
 /// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
 /// single tile slice (and can omit the initial load, like in the previous example).
-func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
-  %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %mask = vector.constant_mask [4] : vector<[4]xi1>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
   return %loadSlice : vector<[4]x[4]xf32>
 }
 // AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
-//      AFTER-TILE-ALLOC: arm_sme.get_tile
-// AFTER-TILE-ALLOC-SAME:   tile_id = 0
 //      AFTER-TILE-ALLOC: arm_sme.load_tile_slice
 // AFTER-TILE-ALLOC-SAME:   tile_id = 16
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 15767ff1dec3f..8ca71e6bf783f 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" \
+// RUN: -split-input-file -allow-unregistered-dialect -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index e144bac970a7d..eb64a7b6aac58 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
 
 // -----
 
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index b7ba3f728c705..d9e3d66e370ef 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,18 +1,16 @@
 // RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
-// once it becomes unused, after lowering to control flow.
+// This tests that dead tile values are removed from control flow.
 
 // -----
 
-// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
-// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
 // CHECK-NOT: vector<[4]x[4]xf32>
-func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
   %c10 = arith.constant 10 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
 ^bb1(%1: index, %2: vector<[4]x[4]xf32>):  // 2 preds: ^bb0, ^bb2
   %3 = arith.cmpi slt, %1, %c10 : index
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
deleted file mode 100644
index 74e7293eaeca5..0000000000000
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
-// duplicates.
-
-// CHECK-LABEL: @zero_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @zero_tile() {
-  %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
-  %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
-  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
-
-// CHECK-LABEL: @get_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @get_tile() {
-  %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
-  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
-  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 39d9ab6491e3b..06be7bd974707 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,19 +1,17 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics
 
 // -----
 
-func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %cond: i1) {
+// Select between tileA and tileB. This is currently unsupported as it would
+// require inserting (runtime) tile moves.
+func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %tileA : vector<[4]x[4]xi32>, %tileB : vector<[4]x[4]xi32>, %cond: i1) {
   %c0 = arith.constant 0 : index
-  %tileA = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %tileB = arm_sme.get_tile : vector<[4]x[4]xi32>
-  // Select between tileA and tileB. This is currently unsupported as it would
-  // require inserting tile move operations during tile allocation.
+  // expected-error at +1 {{op failed to rectify tile operand with tile result (move required)}}
   %tile = scf.if %cond -> vector<[4]x[4]xi32> {
     scf.yield %tileA : vector<[4]x[4]xi32>
   } else {
     scf.yield %tileB : vector<[4]x[4]xi32>
   }
-  // expected-error at +1 {{op already assigned different SME virtual tile!}}
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 2dedcb2fbc24e..dd2ee55d2afc8 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -1,18 +1,24 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK-BAD
-
-// This file tests some aspects of liveness issues in the SME tile allocator.
-// These tests were designed with a new liveness-based tile allocator in mind
-// (where the names of test cases make more sense), with the current tile
-// allocator these tests all give incorrect results (which is documented by
-// `CHECK-BAD`).
-
-// Incorrect result! The second `move_vector_to_tile_slice` overwrites the first (which is still live).
-//
-// CHECK-BAD-LABEL: @constant_with_multiple_users
-// CHECK-BAD: %[[ZERO_TILE:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation=dump-tile-live-ranges -mlir-disable-threading -split-input-file -verify-diagnostics 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-LIVE-RANGE
+
+// This file tests some simple aspects of using liveness in the SME tile allocator.
+
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @constant_with_multiple_users
+//       CHECK-LIVE-RANGE: ^bb0:
+//       CHECK-LIVE-RANGE: S  arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
+
+// CHECK-LABEL: @constant_with_multiple_users(
+// CHECK-SAME:                                %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
 func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
+  // CHECK-NEXT: %[[ZERO_TILE_0:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[ZERO_TILE_1:.*]] = arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_A]], %[[ZERO_TILE_1]], %{{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_B]], %[[ZERO_TILE_0]], %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
   %zero = arm_sme.zero : vector<[4]x[4]xf32>
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
@@ -23,12 +29,16 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
 
 // -----
 
-// (No tile IDs -- the current tile allocator ignores this case)
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @value_with_multiple_users
+//       CHECK-LIVE-RANGE: ^bb0:
+//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
-// CHECK-BAD-LABEL: @value_with_multiple_users
-// CHECK-BAD-NOT: tile_id
 func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
-  // A future allocator should error here (as `%tile` would need to be copied).
+  // expected-error at below {{op failed to rectify tile operand with tile result (move required)}}
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
@@ -38,12 +48,38 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
 
 // -----
 
-// CHECK-BAD-LABEL: @reuse_tiles_after_initial_use
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @reuse_tiles_after_initial_use
+//       CHECK-LIVE-RANGE: ^bb0:
+//  CHECK-LIVE-RANGE-NEXT: S        arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: |S       arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: ||S      arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: |||S     arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//  CHECK-LIVE-RANGE-NEXT: E|||     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:  E||     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:   E|     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:    E     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:     S    arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     |S   arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     ||S  arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     |||S arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//  CHECK-LIVE-RANGE-NEXT:     E||| test.some_use
+//  CHECK-LIVE-RANGE-NEXT:      E|| test.some_use
+//  CHECK-LIVE-RANGE-NEXT:       E| test.some_use
+//  CHECK-LIVE-RANGE-NEXT:        E test.some_use
+
+// CHECK-LABEL: @reuse_tiles_after_initial_use
 func.func @reuse_tiles_after_initial_use() {
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 0 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 2 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 3 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 2 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 3 : i32}
   %tile_a = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_b = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_c = arm_sme.get_tile : vector<[4]x[4]xf32>
@@ -55,19 +91,13 @@ func.func @reuse_tiles_after_initial_use() {
   "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
   "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
   "test.some_use"(%tile_d) : (vector<[4]x[4]xf32>) -> ()
-  // -> Spills after the fourth tile (unnecessary):
-  // CHECK-BAD: arm_sme.zero {tile_id = 16 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 17 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 18 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 19 : i32}
-  // Unnecessary spills:
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // CHECK: arm_sme.zero {tile_id = 0 : i32}
+  // CHECK: arm_sme.zero {tile_id = 1 : i32}
+  // CHECK: arm_sme.zero {tile_id = 2 : i32}
+  // CHECK: arm_sme.zero {tile_id = 3 : i32}
   %tile_1 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_3 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_4 = arm_sme.zero : vector<[4]x[4]xf32>
   "test.dummy"(): () -> ()
   "test.dummy"(): () -> ()
@@ -81,16 +111,27 @@ func.func @reuse_tiles_after_initial_use() {
 
 // -----
 
-// Incorrect result! Both branches should yield the result via the same tile.
-//
-// CHECK-BAD-LABEL: @non_overlapping_branches
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @non_overlapping_branches
+//       CHECK-LIVE-RANGE: ^bb1:
+//  CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: E cf.br
+//  CHECK-LIVE-RANGE-NEXT: ^bb2:
+//  CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: E cf.br
+
+// CHECK-LABEL: @non_overlapping_branches
 func.func @non_overlapping_branches(%cond: i1) {
+  // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
   %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+    // ^bb1:
     %zero = arm_sme.zero : vector<[4]x[4]xf32>
     scf.yield %zero : vector<[4]x[4]xf32>
   } else {
+    // ^bb2:
     %undef = arm_sme.get_tile : vector<[4]x[4]xf32>
     scf.yield %undef : vector<[4]x[4]xf32>
   }
@@ -100,13 +141,15 @@ func.func @non_overlapping_branches(%cond: i1) {
 
 // -----
 
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @constant_loop_init_with_multiple_users
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// <deliberately omitted>
+
+// CHECK-LABEL: @constant_loop_init_with_multiple_users
 func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
+  // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -126,26 +169,46 @@ func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vecto
 
 // -----
 
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @run_out_of_tiles_but_avoid_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @run_out_of_tiles_but_avoid_spill
+//       CHECK-LIVE-RANGE: ^bb2:
+//  CHECK-LIVE-RANGE-NEXT: |S    arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: ||S   arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: |||S  arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+
+// Note in the live ranges (above) there are five tile values, but we only have four tiles.
+
+// CHECK-LABEL: @run_out_of_tiles_but_avoid_spill
 func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
   %init = arm_sme.zero : vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
+  // Live = %init
   scf.for %i = %c0 to %c10 step %c1 {
+    // CHECK: arm_sme.zero {tile_id = 1 : i32}
+    // CHECK: arm_sme.zero {tile_id = 2 : i32}
+    // CHECK: arm_sme.zero {tile_id = 3 : i32}
+    // CHECK: arm_sme.zero {tile_id = 0 : i32}
     %tile_a, %tile_b, %tile_c, %tile_d = scf.for %j = %c0 to %c10 step %c1
       iter_args(%iter_a = %init, %iter_b = %init, %iter_c = %init, %iter_d = %init)
         -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32> , vector<[4]x[4]xf32> , vector<[4]x[4]xf32>) {
+        // ^bb2:
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 2 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 3 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_a = arm_sme.move_vector_to_tile_slice %a, %iter_a, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_b = arm_sme.move_vector_to_tile_slice %b, %iter_b, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_c = arm_sme.move_vector_to_tile_slice %c, %iter_c, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_d = arm_sme.move_vector_to_tile_slice %d, %iter_d, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         scf.yield %new_a, %new_b, %new_c, %new_d : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
     }
+    // Live = %init, %tile_a, %tile_b, %tile_c, %tile_d (out of tiles!)
+    // This should be resolved by duplicating the arm_sme.zero (from folding
+    // arm_sme.copy_tile operations inserted by the tile allocator).
     "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
@@ -156,24 +219,47 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
 
 // -----
 
-// Incorrect result! Everything other than zero assigned to tile 1 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @avoidable_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32}
+// We should be able to avoid spills like this, but logic handling this case is
+// not implemented yet. Note tile ID >= 16 means a spill/in-memory tile.
+
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @avoidable_spill
+//       CHECK-LIVE-RANGE: ^bb2:
+//  CHECK-LIVE-RANGE-NEXT: ||     test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||S    arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |||S   arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: ||||S  arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+//  CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||  E| test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||   E test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||     arith.addi
+//  CHECK-LIVE-RANGE-NEXT: EE     cf.br
+
+// Note in the live ranges (above) there are two constant live-ins (first two ranges),
+// which gives six overlapping live ranges. The allocator currently will spill the
+// first constant (which results in a real spill at its use). However, this could
+// be avoided by using the knowledge that at the first "test.some_use" there are
+// actually only two live ranges (so we can fix this by duplicating the constant).
+
+// CHECK-LABEL: @avoidable_spill
 func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
+  // CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[4]x[4]xf32>
   %zero = arm_sme.zero : vector<[4]x[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   scf.for %i = %c0 to %c10 step %c1 {
+    // So %zero is spilled here (unnecessarily).
+    // The arm_sme.zero op could be moved into the loop to avoid this.
     "test.some_use"(%zero) : (vector<[4]x[4]xf32>) -> ()
     %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_c = arm_sme.move_vector_to_tile_slice %c, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_d = arm_sme.move_vector_to_tile_slice %d, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
+    // %zero is still live here (due to the backedge)
     "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index cac2dcc24d104..ca339be5fb56f 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,canonicalize))" | FileCheck %s
 
 // This test verifies the tile mask operand of the zero intrinsic zeroes
 // the correct tiles. Both integer and floating-point datatypes are checked.
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index 588b44a36c29f..14d9712e971a8 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -13,26 +13,29 @@
 /// performance (hence the warning).
 func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b:  memref<?x?xi16>, %c: memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
 
   // CHECK-LABEL: tile_a:
   // CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
   vector.print str "tile_a:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_a : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_b:
   // CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
   vector.print str "tile_b:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_b : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_c:
   // CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
   vector.print str "tile_c:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_c : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_d:
   // CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
index 1794564a6a724..0648e771b8891 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
@@ -1,5 +1,6 @@
 // DEFINE: %{entry_point} = main
-// DEFINE: %{compile} = mlir-opt %s -convert-arm-sme-to-llvm -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm),test-lower-to-llvm)"
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
index e942c7b8ac058..cdd8afe141421 100644
--- a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_library(MLIRArmSMETestPasses
   MLIRTransforms
   MLIRVectorToArmSME
   MLIRVectorToSCF
+  MLIRSCFToControlFlow
   )
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index 48d4a5859f8a0..d3dabaf200fdc 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -14,10 +14,12 @@
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -34,6 +36,10 @@ struct TestLowerToArmSMEOptions
       llvm::cl::desc("Fuse outer product operations via "
                      "'-arm-sme-outer-product-fusion' pass"),
       llvm::cl::init(true)};
+  PassOptions::Option<bool> dumpTileLiveRanges{
+      *this, "dump-tile-live-ranges",
+      llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
+      llvm::cl::init(false)};
 };
 
 void buildTestLowerToArmSME(OpPassManager &pm,
@@ -65,20 +71,17 @@ void buildTestLowerToArmSME(OpPassManager &pm,
   pm.addPass(createConvertVectorToSCFPass(
       VectorTransferToSCFOptions().enableFullUnroll()));
 
-  // Allocate tiles for ArmSME operations.
-  //
-  // Later passes may create further ArmSME ops that implement the
-  // ArmSMETileOpInterface, but tiles are allocated for root operations,
-  // all of which should now exist.
-  pm.addPass(arm_sme::createTileAllocationPass());
-
   // Enable streaming-mode and ZA.
   pm.addPass(arm_sme::createEnableArmStreamingPass(
       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
       /*onlyIfRequiredByOps=*/true));
 
+  // Convert SCF to CF (required for ArmSME tile allocation).
+  pm.addPass(createConvertSCFToCFPass());
+
   // Convert ArmSME to LLVM.
-  pm.addPass(createConvertArmSMEToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(
+      createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
 
   // Sprinkle some cleanups.
   pm.addPass(createCanonicalizerPass());

>From 035742b41cb07e1b12a7e7347b42425a9b06916c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 11:43:11 +0000
Subject: [PATCH 2/9] Review fixups

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h        |  4 ----
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp                   |  4 ++--
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp  | 10 ++++------
 mlir/test/Dialect/ArmSME/roundtrip.mlir                |  9 +++++++++
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
index f31062d8c25ed..0c26fa69c8587 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -5,10 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file declares the Target dialect for ArmSME in MLIR.
-//
-//===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
 #define MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 00d764bf5caff..6c0ebddb5a2dc 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -153,10 +153,10 @@ OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
   auto isTileOperandType = [](OpOperand &operand) {
     return arm_sme::isValidSMETileVectorType(operand.get().getType());
   };
-  OpOperand *tileOperand =
-      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
   assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
          "expected at most one tile operand");
+  OpOperand *tileOperand =
+      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
   if (tileOperand == tileOp->getOpOperands().end())
     return nullptr;
   return tileOperand;
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e3cf52078a24e..7b3381652e28a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -131,7 +131,7 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
 
 class TileAllocator {
 public:
-  /// Allocates and returns a tile ID.
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
   FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
     auto masks = getMasks(tileType);
     for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
@@ -198,7 +198,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   }
 }
 
-/// Inserts tile copies `cf.br` operations.
+/// Inserts tile copies at `cf.br` operations.
 void insertCopiesAtBranches(IRRewriter &rewriter,
                             FunctionOpInterface function) {
   splitCondBranches(rewriter, function);
@@ -278,10 +278,8 @@ generateOperationNumbering(FunctionOpInterface function) {
       // This is only correct if all ArmSME have been converted to CF.
 #ifndef NDEBUG
       op.walk([&](ArmSMETileOpInterface nestedOp) {
-        if (&op != nestedOp.getOperation()) {
-          assert(false &&
-                 "ArmSME tile allocation does not support nested regions");
-        }
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
       });
 #endif
       operationToIndexMap.try_emplace(&op, index++);
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ab46c7adca596..6095fdc11ead8 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1403,3 +1403,12 @@ func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
   %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
   return %reuslt : vector<[2]x[2]xi64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.copy_tile
+//===----------------------------------------------------------------------===//
+
+func.func @arm_sme_copy_tile(%vec: vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+  %result = arm_sme.copy_tile %vec : vector<[4]x[4]xf32>
+  return %result : vector<[4]x[4]xf32>
+}

>From c3471d53f14fa1a765b777d079ccda9410fa165f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 14:23:57 +0000
Subject: [PATCH 3/9] More review fixups

---
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 6 +++---
 mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir    | 3 +--
 mlir/test/Dialect/ArmSME/canonicalize.mlir            | 4 +---
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 7b3381652e28a..bcc6deeabd16a 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -414,7 +414,7 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
   return std::move(coalescedLiveRanges);
 }
 
-/// Greedily allocate tile IDs to live ranges spill using simple heuristics.
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
 /// Note: This does not attempt to fill holes in live/allocated ranges.
 void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
   TileAllocator tileAllocator;
@@ -629,8 +629,8 @@ LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
     return failure();
   }
 
-  /// 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
-  /// users). This prevents the LLVM conversion needlessly inserting spills.
+  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+  // users). This prevents the LLVM conversion needlessly inserting spills.
   eraseTriviallyDeadTileOps(rewriter, function);
   return success();
 }
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 8ca71e6bf783f..a62ca080ab8d9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" \
-// RUN: -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index d9e3d66e370ef..643dfd4a7cbd9 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,9 +1,7 @@
-// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
 
 // This tests that dead tile values are removed from control flow.
 
-// -----
-
 // CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
 // CHECK-NOT: vector<[4]x[4]xf32>
 func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {

>From cb51a2f0ad1a0c7988774769533b90c208499123 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 17:04:15 +0000
Subject: [PATCH 4/9] More fixups

---
 .../Dialect/ArmSME/IR/ArmSMEOpInterfaces.h    |   3 +
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |   6 +-
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  52 +++---
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         |   2 +-
 .../Dialect/ArmSME/basic-tile-allocation.mlir |   2 +-
 .../ArmSME/tile-allocation-invalid.mlir       |   4 +-
 .../ArmSME/tile-allocation-liveness.mlir      | 160 ++++++++++--------
 7 files changed, 123 insertions(+), 106 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
index 0c26fa69c8587..9153fbb57ea88 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -17,7 +17,10 @@ namespace detail {
 LogicalResult verifyArmSMETileOpInterface(Operation *);
 }
 
+// The first in-memory SME tile ID. This is set to 16 as that is the first tile
+// ID larger than any virtual tile ID supported by the SME ISA.
 static constexpr unsigned kInMemoryTileIdBase = 16;
+
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
 } // namespace mlir::arm_sme
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index b9d74fec6756e..8c129ea623b6f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -126,20 +126,20 @@ def EnableArmStreaming
 
 def TestTileAllocation
     : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
-  let summary = "Tests SME tile allocation";
+  let summary = "Tests SME 'virtual tile' allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
     that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
     to be used for testing, tile allocation is done as part of the ArmSME to
-    LLVM conversion.
+    LLVM conversion (`convert-arm-sme-to-llvm`).
   }];
   let options = [
     Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
            "bool", /*default=*/"false",
            "Dump the live ranges of SME tiles (for debugging)">
   ];
-  let dependentDialects = ["func::FuncDialect"];
+  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
 }
 
 def OuterProductFusion
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 488f747af050f..562c58b73e7e3 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -885,34 +885,38 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::GetTileOp, arm_sme::CopyTileOp, arm_sme::aarch64_sme_zero,
-      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
-      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
-      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
-      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
-      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
-      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
-      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
-      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
-      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
-      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
-      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
-      arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
-      arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
-      arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
-      arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
-      arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
-      arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
-      arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
-      arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
-      arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
+      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
+      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
+      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
+      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
+      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
+      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
+      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
+      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
+      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
+      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
+      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
+      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
+      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
+      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
+      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
+      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
+      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
+      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
+      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
+      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
+      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
+      arm_sme::aarch64_sme_cntsd>();
   target.addLegalDialect<arith::ArithDialect,
                          /* The following are used to lower tile spills/fills */
                          vector::VectorDialect, scf::SCFDialect,
                          memref::MemRefDialect>();
-  target.addLegalOp<UnrealizedConversionCastOp>();
+  // Pseudo operations. These cannot be code-generated but may exist in the
+  // input IR, or be generated during the conversion. They need to be eliminated
+  // before the final conversion to LLVM IR (and likely will be due to DCE).
+  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
+                    UnrealizedConversionCastOp>();
 }
 
 void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 2922940460219..14b1f323da3a2 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file | FileCheck %s
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index eb64a7b6aac58..8b46998d56b04 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
 
 // -----
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 06be7bd974707..b3112264cba9c 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics
-
-// -----
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
 
 // Select between tileA and tileB. This is currently unsupported as it would
 // require inserting (runtime) tile moves.
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index dd2ee55d2afc8..a63312ded1b93 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -3,14 +3,14 @@
 
 // This file tests some simple aspects of using liveness in the SME tile allocator.
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @constant_with_multiple_users
-//       CHECK-LIVE-RANGE: ^bb0:
-//       CHECK-LIVE-RANGE: S  arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
-//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
+//  CHECK-LIVE-RANGE-LABEL: @constant_with_multiple_users
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//        CHECK-LIVE-RANGE: S  arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
 // CHECK-LABEL: @constant_with_multiple_users(
 // CHECK-SAME:                                %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
@@ -29,13 +29,13 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @value_with_multiple_users
-//       CHECK-LIVE-RANGE: ^bb0:
-//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
-//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
+//  CHECK-LIVE-RANGE-LABEL: @value_with_multiple_users
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//   CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
 func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
   // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
@@ -48,31 +48,31 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @reuse_tiles_after_initial_use
-//       CHECK-LIVE-RANGE: ^bb0:
-//  CHECK-LIVE-RANGE-NEXT: S        arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: |S       arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: ||S      arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: |||S     arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
-//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
-//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
-//  CHECK-LIVE-RANGE-NEXT: E|||     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:  E||     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:   E|     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:    E     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:     S    arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     |S   arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     ||S  arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     |||S arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
-//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
-//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
-//  CHECK-LIVE-RANGE-NEXT:     E||| test.some_use
-//  CHECK-LIVE-RANGE-NEXT:      E|| test.some_use
-//  CHECK-LIVE-RANGE-NEXT:       E| test.some_use
-//  CHECK-LIVE-RANGE-NEXT:        E test.some_use
+//  CHECK-LIVE-RANGE-LABEL: @reuse_tiles_after_initial_use
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//   CHECK-LIVE-RANGE-NEXT: S        arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: |S       arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: ||S      arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: |||S     arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//   CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//   CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//   CHECK-LIVE-RANGE-NEXT: E|||     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:  E||     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:   E|     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:    E     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:     S    arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     |S   arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     ||S  arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     |||S arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//   CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//   CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//   CHECK-LIVE-RANGE-NEXT:     E||| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:      E|| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:       E| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:        E test.some_use
 
 // CHECK-LABEL: @reuse_tiles_after_initial_use
 func.func @reuse_tiles_after_initial_use() {
@@ -111,16 +111,16 @@ func.func @reuse_tiles_after_initial_use() {
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @non_overlapping_branches
-//       CHECK-LIVE-RANGE: ^bb1:
-//  CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: E cf.br
-//  CHECK-LIVE-RANGE-NEXT: ^bb2:
-//  CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: E cf.br
+//  CHECK-LIVE-RANGE-LABEL: @non_overlapping_branches
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: E cf.br
+//   CHECK-LIVE-RANGE-NEXT: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: E cf.br
 
 // CHECK-LABEL: @non_overlapping_branches
 func.func @non_overlapping_branches(%cond: i1) {
@@ -141,8 +141,20 @@ func.func @non_overlapping_branches(%cond: i1) {
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// <deliberately omitted>
+// Here %vecA and %vecB are not merged into the same live range (as they are unknown values).
+// This means that %vecA and %vecB are allocated to different tiles (which is not legal).
+func.func @overlapping_branches(%cond: i1, %vecA: vector<[4]x[4]xf32>, %vecB: vector<[4]x[4]xf32>) {
+  // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
+  %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+    scf.yield %vecA : vector<[4]x[4]xf32>
+  } else {
+    scf.yield %vecB : vector<[4]x[4]xf32>
+  }
+  "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
 
 // CHECK-LABEL: @constant_loop_init_with_multiple_users
 func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
@@ -169,14 +181,14 @@ func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vecto
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @run_out_of_tiles_but_avoid_spill
-//       CHECK-LIVE-RANGE: ^bb2:
-//  CHECK-LIVE-RANGE-NEXT: |S    arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: ||S   arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: |||S  arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+//  CHECK-LIVE-RANGE-LABEL: @run_out_of_tiles_but_avoid_spill
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: |S    arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: ||S   arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: |||S  arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
 
 // Note in the live ranges (above) there is five tile values, but we only have four tiles.
 
@@ -222,20 +234,20 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
 // We should be able to avoid spills like this, but logic handling this case is
 // not implemented yet. Note tile ID >= 16 means a spill/in-memory tile.
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @avoidable_spill
-//       CHECK-LIVE-RANGE: ^bb2:
-//  CHECK-LIVE-RANGE-NEXT: ||     test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||S    arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |||S   arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: ||||S  arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
-//  CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||  E| test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||   E test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||     arith.addi
-//  CHECK-LIVE-RANGE-NEXT: EE     cf.br
+//  CHECK-LIVE-RANGE-LABEL: @avoidable_spill
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: ||     test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||S    arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |||S   arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: ||||S  arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+//   CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||  E| test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||   E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||     arith.addi
+//   CHECK-LIVE-RANGE-NEXT: EE     cf.br
 
 // Note in the live ranges (above) there is two constant live-ins (first two ranges),
 // which gives six overlapping live ranges. The allocator currently will spill the

>From c14e7b5c56104624264fa8197f4165df5f558e1f Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 2 May 2024 15:44:52 +0000
Subject: [PATCH 5/9] Add test for tile copies

---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  5 +-
 .../ArmSME/Transforms/TileAllocation.cpp      |  7 +-
 .../ArmSME/tile-allocation-copies.mlir        | 82 +++++++++++++++++++
 3 files changed, 92 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8c129ea623b6f..81730166d3c01 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -137,7 +137,10 @@ def TestTileAllocation
   let options = [
     Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
            "bool", /*default=*/"false",
-           "Dump the live ranges of SME tiles (for debugging)">
+           "Dump the live ranges of SME tiles (for debugging)">,
+    Option<"tileCopiesOnly", "tile-copies-only", "bool", /*default=*/"false",
+           "Only insert tile copies needed for tile allocation "
+           "(but do not allocate any tiles)">
   ];
   let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index bcc6deeabd16a..ce8f341524751 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -575,7 +575,12 @@ struct TestTileAllocationPass
     : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
   using TestTileAllocationBase::TestTileAllocationBase;
   void runOnOperation() override {
-    if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
+    FunctionOpInterface function = getOperation();
+    if (tileCopiesOnly) {
+      IRRewriter rewriter(function);
+      return insertCopiesAtBranches(rewriter, function);
+    }
+    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
new file mode 100644
index 0000000000000..5e43efe93b556
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation=tile-copies-only -split-input-file | FileCheck %s
+
+// This file tests the insertion of copies for SME tile allocation. Copies are
+// inserted at `cf.br` ops (the predecessors to block arguments). Conditional
+// branches are split to prevent conflicts (see cond_branch_with_backedge).
+
+// CHECK-LABEL: func.func @simple_branch(
+//  CHECK-SAME:   %[[TILE:.*]]: vector<[4]x[4]xf32>)
+//       CHECK:   %[[COPY:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+//       CHECK:   cf.br ^bb1(%[[COPY]] : vector<[4]x[4]xf32>)
+//       CHECK: ^bb1(%[[BLOCK_ARG:.*]]: vector<[4]x[4]xf32>):
+
+func.func @simple_branch(%tile : vector<[4]x[4]xf32>) {
+  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+  return
+}
+
+// -----
+
+// Note: The ^POINTLESS_SHIM_FOR_BB2 block is added as the cond_br splitting does
+// not check if it needs to insert a copy or not (there is no harm in the empty
+// block though -- it will fold away later).
+
+// CHECK-LABEL: func.func @cond_branch(
+//  CHECK-SAME:   %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+//       CHECK:   cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[POINTLESS_SHIM_FOR_BB2:[[:alnum:]]+]]
+//       CHECK: ^[[POINTLESS_SHIM_FOR_BB2]]:
+//       CHECK:   cf.br ^[[BB2:.*]]
+//       CHECK: ^[[BB1_COPIES]]:
+//       CHECK:   arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+//       CHECK:   cf.br ^[[BB1:.*]]
+func.func @cond_branch(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+  return
+^bb2:
+  return
+}
+
+// -----
+
+// Reduction of a real-world example that shows why we must split conditional branches.
+
+// CHECK-LABEL: @cond_branch_with_backedge(
+//  CHECK-SAME:    %{{[[:alnum:]]+}}: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+//  CHECK-SAME:    %[[TILEC:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILED:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+//       CHECK: ^bb1(%[[CURRENT_INDEX:.*]]: index, %[[ITER_TILE:.*]]: vector<[4]x[4]xf32>):
+//       CHECK:   %[[CONTINUE_LOOP:.*]] = arith.cmpi
+//       CHECK:   cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
+//       CHECK: ^[[BB3_COPIES]]:
+//  CHECK-NEXT:   arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   cf.br ^[[BB3:[[:alnum:]]+]]
+//       CHECK: ^[[BB3]](%{{.*}}: vector<[4]x[4]xf32>):
+//  CHECK-NEXT:   return
+
+func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector<[4]x[4]xf32>, %tileC: vector<[4]x[4]xf32>, %tileD: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // Live here: %tileA, %tileB, %tileC, %tileD
+  cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+  %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  // %iterTile dies at the `cf.cond_br`, but %tileB, %tileC, %tileD are live out (in the ^bb2 case).
+  // If we inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
+  // However, note that the copies are only needed if we take the ^bb3 path. So, if we add
+  // a new block along that path we can insert the copies without any conflicts.
+  cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+  %nextIndex = arith.addi %currentIndex, %c1 : index
+  cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+  // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+  return
+}
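As a rough illustration of the conditional-branch splitting exercised by the tests above, here is a minimal standalone sketch on a toy CFG. `ToyEdge`, `ToyBlock`, and `splitToyCondBranches` are hypothetical names used only in this note; they are not part of MLIR or of this patch, which rewrites `cf.cond_br` ops through an `IRRewriter`. The sketch assumes "has any operands" as a stand-in for "has tile operands" and does not model block ordering.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical toy model: an edge to a destination block with the operands
// forwarded to that block's arguments.
struct ToyEdge {
  std::size_t dest;
  std::vector<std::string> operands;
};

// Hypothetical toy block: one edge models `cf.br`, two edges model
// `cf.cond_br` (true edge first), zero edges model `return`.
struct ToyBlock {
  std::vector<ToyEdge> edges;
};

// For every two-way branch that forwards operands, route each edge through a
// new intermediate block. The operands move onto the unconditional edge of
// the new block, so copies can later be placed on that path only.
void splitToyCondBranches(std::vector<ToyBlock> &blocks) {
  const std::size_t originalCount = blocks.size();
  for (std::size_t i = 0; i < originalCount; ++i) {
    if (blocks[i].edges.size() != 2)
      continue;
    bool hasOperands = !blocks[i].edges[0].operands.empty() ||
                       !blocks[i].edges[1].operands.empty();
    if (!hasOperands)
      continue;
    for (std::size_t e = 0; e < 2; ++e) {
      assert(blocks[i].edges[e].dest < blocks.size());
      // The shim block jumps to the old destination with the old operands.
      ToyBlock shim;
      shim.edges.push_back(blocks[i].edges[e]);
      blocks.push_back(shim);
      // The conditional edge now targets the shim and carries no operands.
      blocks[i].edges[e] = ToyEdge{blocks.size() - 1, {}};
    }
  }
}
```

As in the `@cond_branch` test, the edge without operands also gets a (pointless) intermediate block in this sketch, since the split is applied to both edges of the branch.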

>From 8d5e8c6627a64e20a513124b5ce1ebfd76fac8ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 2 May 2024 16:41:34 +0000
Subject: [PATCH 6/9] Wrap op modifications in `rewriter.modifyOpInPlace()`

---
 .../ArmSME/Transforms/TileAllocation.cpp      | 28 +++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index ce8f341524751..af8079c8dee5c 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -191,10 +191,12 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
                condBranch.getTrueDestOperands());
     insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
                condBranch.getFalseDestOperands());
-    condBranch.getFalseDestOperandsMutable().clear();
-    condBranch.getTrueDestOperandsMutable().clear();
-    condBranch.setSuccessor(newTrueBranch, 0);
-    condBranch.setSuccessor(newFalseBranch, 1);
+    rewriter.modifyOpInPlace(condBranch, [&] {
+      condBranch.getFalseDestOperandsMutable().clear();
+      condBranch.getTrueDestOperandsMutable().clear();
+      condBranch.setSuccessor(newTrueBranch, 0);
+      condBranch.setSuccessor(newFalseBranch, 1);
+    });
   }
 }
 
@@ -211,7 +213,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
       if (isValidSMETileVectorType(operand.get().getType())) {
         auto copy =
             rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
-        operand.assign(copy);
+        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
       }
     }
   }
@@ -494,7 +496,8 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
           // Ensure ArmSME ops that don't produce a value still get a tile ID.
           if (!hasTileResult(tileOp))
-            tileOp.setTileId(tileIdAttr);
+            rewriter.modifyOpInPlace(tileOp,
+                                     [&] { tileOp.setTileId(tileIdAttr); });
         }
       }
       auto copyOp = value.getDefiningOp<CopyTileOp>();
@@ -502,7 +505,7 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         // Fold redundant copies.
         rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
       } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
-        tileOp.setTileId(tileIdAttr);
+        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
         // Rectify operand tile IDs with result tile IDs.
         OpOperand *tileOperand = getTileOpOperand(tileOp);
         if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
@@ -515,12 +518,15 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         // Cloning prevents a move/spill (though may require recomputation).
         rewriter.setInsertionPoint(tileOp);
         auto clonedOp = operandTileOp.clone();
-        clonedOp.setTileId(tileOp.getTileId());
+        rewriter.modifyOpInPlace(
+            clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
         rewriter.insert(clonedOp);
-        if (copyOp)
+        if (copyOp) {
           rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
-        else
-          tileOperand->assign(clonedOp->getResult(0));
+        } else {
+          rewriter.modifyOpInPlace(
+              tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
+        }
       } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
         // Validate block arguments.
         bool tileMismatch = false;

>From 70bd8c9625e4ca9fba9f446f165b08265f1586e7 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 3 May 2024 15:58:44 +0000
Subject: [PATCH 7/9] More docs + naming

---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  6 +-
 .../ArmSME/Transforms/TileAllocation.cpp      | 62 ++++++++++++++-----
 2 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 562c58b73e7e3..3dbc8e9916df6 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -406,11 +406,11 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
 ///  AFTER:
 ///  ```mlir
 ///     "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
-///     %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
+///     %v = arm_sme.get_tile : vector<[4]x[4]xi32>
 ///  ```
 ///
-///  The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
-///  once all ArmSME ops have been converted to LLVM intrinsics.
+///  The 'arm_sme.get_tile' (which models the return) will fold away once all
+///  ArmSME ops have been converted to LLVM intrinsics.
 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index af8079c8dee5c..bc41386f5903e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -164,9 +164,23 @@ class TileAllocator {
   unsigned nextInMemoryTileId = kInMemoryTileIdBase;
 };
 
-// Add new intermediate blocks for the true and false destinations of a
-// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
-// branches.
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+///  ^bb1_copy:
+///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ^bb2_copy:
+///    cf.br ^bb2
+///  ```
 void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   SmallVector<cf::CondBranchOp> worklist;
   function.walk([&](cf::CondBranchOp condBranch) {
@@ -200,7 +214,18 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   }
 }
 
-/// Inserts tile copies at `cf.br` operations.
+/// Splits conditional branches (see `splitCondBranches`), then inserts tile
+/// copies at `cf.br` operations.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+///  ```
 void insertCopiesAtBranches(IRRewriter &rewriter,
                             FunctionOpInterface function) {
   splitCondBranches(rewriter, function);
@@ -219,7 +244,9 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
   }
 }
 
-/// A range where a tile value is live. The range may contain holes.
+/// A live range for a (collection of) tile values. A live range is built up of
+/// intervals [start, end) which represent parts of the program where the value
+/// needs to be live (i.e. in an SME virtual tile).
 struct LiveRange {
   using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                      llvm::IntervalMapHalfOpenInfo<unsigned>>;
@@ -296,33 +323,38 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                      LiveRange::Allocator &liveRangeAllocator,
                      Liveness &liveness, FunctionOpInterface function) {
   DenseMap<Value, LiveRange> liveRanges;
-  auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
-                              LivenessBlockInfo const &livenessInfo,
-                              bool liveAtBlockEntry = false) {
+  /// Defines or updates a live range for an SME tile value. Live-ins may update
+  /// an existing live range (rather than define a new one).
+  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
+                                          LivenessBlockInfo const &livenessInfo,
+                                          bool liveAtBlockEntry = false) {
     if (!isValidSMETileVectorType(value.getType()))
       return;
-    auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+    // Find or create a live range for `value`.
+    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
+    LiveRange &valueLiveRange = it->second;
     auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
     unsigned start =
         operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
     unsigned end = operationToIndexMap.at(lastUseInBlock);
-    it->second.insert(value, start, end);
+    valueLiveRange.insert(value, start, end);
   };
 
   for (Block &block : function.getBlocks()) {
     LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
     // Handle block arguments:
     for (Value argument : block.getArguments())
-      updateLiveRanges(argument, &block.front(), *livenessInfo,
-                       /*liveAtBlockEntry=*/true);
+      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
     // Handle live-ins:
     for (Value liveIn : livenessInfo->in())
-      updateLiveRanges(liveIn, &block.front(), *livenessInfo,
-                       /*liveAtBlockEntry=*/true);
+      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
     // Handle new definitions:
     for (Operation &op : block) {
       for (Value result : op.getResults())
-        updateLiveRanges(result, &op, *livenessInfo);
+        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
     }
   }
 

>From c6aa47d4cdfdf21cff4943ceca99c477d518a94e Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 May 2024 10:47:47 +0000
Subject: [PATCH 8/9] Add CF allocation test for cond_br with backedge

---
 .../ArmSME/tile-allocation-liveness.mlir      | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index a63312ded1b93..b46c0cfc9a274 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -279,3 +279,75 @@ func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<
   }
   return
 }
+
+// -----
+
+// This test is a follow up to the test of the same name in `tile-allocation-copies.mlir`.
+// This shows the live ranges (which are why we need to split the conditional branch).
+
+//  CHECK-LIVE-RANGE-LABEL: @cond_branch_with_backedge
+//        CHECK-LIVE-RANGE: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT:  ||| |           arith.cmpi
+//   CHECK-LIVE-RANGE-NEXT:  EEE E           cf.cond_br
+//
+//   CHECK-LIVE-RANGE-NEXT: ^[[BB3_COPIES:[[:alnum:]]+]]:
+//   CHECK-LIVE-RANGE-NEXT:  ||| ES          arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT:  E||  |S         arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT:   E|  ||S        arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT:    E  |||S       arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT:       EEEE       cf.br
+//
+// It is important to note that the first three live ranges in ^bb1 do not end
+// at the `cf.cond_br`; they are live-out via the backedge bb1 -> bb2 -> bb1.
+// This means that if we placed the `arm_sme.copy_tile` ops before the
+// `cf.cond_br` then those live ranges would not end at the copies, resulting in
+// unwanted overlapping live ranges (and hence tile spills).
+//
+// With the conditional branch split and the copies placed in the BB3_COPIES
+// block, the first three live ranges end at the copy operations (as the
+// BB3_COPIES block is on the path out of the loop and has no backedge). This
+// means there are no overlaps and the live ranges all merge, as shown below.
+//
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT: |||| arith.cmpi
+//   CHECK-LIVE-RANGE-NEXT: EEEE cf.cond_br
+//
+//   CHECK-LIVE-RANGE-NEXT: ^[[BB3_COPIES]]:
+//   CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: EEEE cf.br
+
+// CHECK-LABEL: @cond_branch_with_backedge
+// CHECK-NOT: tile_id = 16
+// CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 2 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 3 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-NOT: tile_id = 16
+func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileD = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // Live here: %tileA, %tileB, %tileC, %tileD
+  cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+  %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+  %nextIndex = arith.addi %currentIndex, %c1 : index
+  cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+  // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+  return
+}
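To relate the live ranges in this test to the tile IDs it checks, here is a small standalone model of a greedy linear scan. It is a sketch under simplifying assumptions, not the allocator in TileAllocation.cpp: each live range is a single half-open interval [start, end) with no holes, there is only one tile type, and `ToyLiveRange`, `allocateToyTiles`, and `kToyInMemoryTileIdBase` are hypothetical names. The base value 16 follows the note in the tests that a tile ID >= 16 means a spill/in-memory tile.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical toy live range: one half-open interval [start, end) over
// operation indices. `tileId` is filled in by allocateToyTiles().
struct ToyLiveRange {
  unsigned start;
  unsigned end;
  unsigned tileId;
};

// Assumption for this sketch: IDs at or above this base denote in-memory
// (spilled) tiles rather than real hardware tiles.
constexpr unsigned kToyInMemoryTileIdBase = 16;

// Greedy linear scan: visit live ranges in order of their start point and
// give each the lowest-numbered tile whose previous owner has already ended.
// If every tile is still occupied, hand out an in-memory tile ID instead.
void allocateToyTiles(std::vector<ToyLiveRange> &ranges, unsigned numTiles) {
  assert(numTiles <= kToyInMemoryTileIdBase);
  std::vector<ToyLiveRange *> order;
  for (ToyLiveRange &range : ranges) {
    assert(range.start < range.end);
    order.push_back(&range);
  }
  std::stable_sort(order.begin(), order.end(),
                   [](const ToyLiveRange *a, const ToyLiveRange *b) {
                     return a->start < b->start;
                   });

  // owner[t] is the live range most recently assigned to tile t (if any).
  std::vector<const ToyLiveRange *> owner(numTiles, nullptr);
  unsigned nextInMemoryTileId = kToyInMemoryTileIdBase;
  for (ToyLiveRange *range : order) {
    bool assigned = false;
    for (unsigned tile = 0; tile < numTiles; ++tile) {
      // Intervals are half-open, so a range ending at index N does not
      // conflict with a range starting at index N.
      if (!owner[tile] || owner[tile]->end <= range->start) {
        owner[tile] = range;
        range->tileId = tile;
        assigned = true;
        break;
      }
    }
    if (!assigned)
      range->tileId = nextInMemoryTileId++;
  }
}
```

With four tiles, four ranges [0, 5) followed by a fifth range [5, 9) all fit in tiles 0-3, because the fifth reuses tile 0. If the first four ranges instead extend to [0, 9), as they would if they stayed live across the copies, the fifth range is given ID 16.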

>From fc96a7a69491844dc51476174021dfe18697f5d3 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 May 2024 11:06:40 +0000
Subject: [PATCH 9/9] Fix comment

---
 mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
index 5e43efe93b556..da7f4843f75fe 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -66,8 +66,8 @@ func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector
 ^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
   %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
   // Live here: %iterTile, %tileB, %tileC, %tileD
-  // %iterTile dies at the `cf.cond_br`, but %tileB, %tileC, %tileD are live out (in the ^bb2 case).
-  // If we inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
+  // %iterTile, %tileB, %tileC, %tileD are live out (in the ^bb2 case). If we
+  // inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
   // However, note that the copies are only needed if we take the ^bb3 path. So, if we add
   // a new block along that path we can insert the copies without any conflicts.
   cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)


