[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

Benjamin Maxwell llvmlistbot at llvm.org
Thu May 9 09:10:59 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/90448

>From 7494e9e74834932249efa77dcab82df40241e36d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 26 Apr 2024 15:09:20 +0000
Subject: [PATCH 01/13] [mlir][ArmSME] Use liveness information in the tile
 allocator

This patch rewrites the ArmSME tile allocator to use liveness
information to make better tile allocation decisions and improve the
correctness of the ArmSME dialect. The algorithm used here is a linear
scan over live ranges, where live ranges are assigned to tiles as they
appear in the program (chronologically). Live ranges release their
assigned tile ID once the current program point is past their end.
This is a greedy algorithm, chosen mainly to keep the implementation
relatively straightforward, and because it seems to be sufficient for
most kernels (e.g. matmuls) that use ArmSME. The general steps are
roughly based on https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf,
though a few simplifications and assumptions have been made for our
use case.
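As a rough illustration of the linear scan described above, here is a
minimal, self-contained C++ sketch. It is not the code in
`TileAllocation.cpp`. It makes several simplifying assumptions. Live
ranges are inclusive integer intervals of program points. All ranges
compete for one pool of `numTiles` identical tiles, ignoring how
differently sized tiles alias. The names `LiveRange` and
`allocateTiles` are hypothetical. The fallback of handing out IDs from
16 upwards when no tile is free is also an assumption, modelled on
`kInMemoryTileIdBase` and the in-memory tile warning in this patch.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical, simplified live range: the inclusive interval of program
// points [start, end] over which one tile value is live.
struct LiveRange {
  int start;
  int end;
  int tileId = -1; // -1 means "not yet assigned".
};

// Modelled on kInMemoryTileIdBase: IDs >= 16 denote tiles kept in memory.
constexpr int kInMemoryTileIdBase = 16;

// Greedy linear scan: visit ranges in order of their start point and give each
// the lowest free tile. A range releases its tile once the scan has moved past
// its end. If no tile is free, the range is assigned an in-memory ID instead.
void allocateTiles(std::vector<LiveRange> &ranges, int numTiles) {
  assert(numTiles > 0 && numTiles <= kInMemoryTileIdBase);

  std::vector<LiveRange *> order;
  for (LiveRange &range : ranges)
    order.push_back(&range);
  // stable_sort keeps program order for ranges that start at the same point.
  std::stable_sort(order.begin(), order.end(),
                   [](const LiveRange *a, const LiveRange *b) {
                     return a->start < b->start;
                   });

  std::vector<bool> tileInUse(numTiles, false);
  std::vector<LiveRange *> active; // Ranges currently holding a real tile.
  int nextInMemoryId = kInMemoryTileIdBase;

  for (LiveRange *current : order) {
    // Release the tiles of ranges that ended before `current` starts.
    for (auto it = active.begin(); it != active.end();) {
      if ((*it)->end < current->start) {
        tileInUse[(*it)->tileId] = false;
        it = active.erase(it);
      } else {
        ++it;
      }
    }

    auto freeTile = std::find(tileInUse.begin(), tileInUse.end(), false);
    if (freeTile == tileInUse.end()) {
      // Out of tiles: fall back to a (hypothetical) in-memory tile ID.
      current->tileId = nextInMemoryId++;
      continue;
    }
    current->tileId = static_cast<int>(freeTile - tileInUse.begin());
    *freeTile = true;
    active.push_back(current);
  }
}
```

Each range is assigned exactly once, in start order, so an early
choice is never revisited. That is the trade-off a greedy linear scan
makes for simplicity. For example, with two tiles and the ranges
[0, 4], [1, 6] and [5, 8], the third range can reuse the first range's
tile, because the first range has ended by point 5.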

Hopefully, the only changes needed for users of the ArmSME dialect
are:

- `-allocate-arm-sme-tiles` is no longer a standalone pass
  - `-test-arm-sme-tile-allocation` is only for unit tests
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf`
  - SME tile allocation is now part of the LLVM conversion

By integrating this into the `ArmSME -> LLVM` conversion, we allow
high-level (value-based) ArmSME operations to be side-effect-free, as we
can guarantee that nothing will rearrange ArmSME operations before we
emit intrinsics (which could invalidate the tile allocation).

The hope is for ArmSME operations to have no hidden state or side
effects. This makes it easy to lower dialects such as `vector` and
`arith` to SME without making assumptions about how the input IR looks,
as the semantics of the operations will be the same. That is: no (new)
side effects, and the IR follows the rules of SSA (a value never
changes).

The aim is correctness, so we have a base for working on optimizations.
---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h    |   4 +-
 mlir/include/mlir/Conversion/Passes.td        |   7 +-
 mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h  |   6 +-
 .../Dialect/ArmSME/IR/ArmSMEOpInterfaces.h    |  28 +
 .../mlir/Dialect/ArmSME/IR/ArmSMEOps.td       | 141 ++--
 .../mlir/Dialect/ArmSME/Transforms/Passes.h   |   3 -
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  17 +-
 .../Dialect/ArmSME/Transforms/Transforms.h    |   9 +
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  20 +
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  69 +-
 .../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp    |  29 +-
 mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp         |   6 +
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |  46 ++
 .../ArmSME/Transforms/TileAllocation.cpp      | 644 +++++++++++++-----
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         |   3 +-
 .../ArmSMEToLLVM/tile-spills-and-fills.mlir   |  17 +-
 .../Conversion/ArmSMEToLLVM/unsupported.mlir  |   3 +-
 .../Dialect/ArmSME/basic-tile-allocation.mlir |   2 +-
 mlir/test/Dialect/ArmSME/canonicalize.mlir    |  10 +-
 mlir/test/Dialect/ArmSME/cse.mlir             |  30 -
 .../ArmSME/tile-allocation-invalid.mlir       |  12 +-
 .../ArmSME/tile-allocation-liveness.mlir      | 196 ++++--
 mlir/test/Dialect/ArmSME/tile-zero-masks.mlir |   2 +-
 .../Linalg/CPU/ArmSME/use-too-many-tiles.mlir |   7 +-
 .../ArmSME/Emulated/test-setArmSVLBits.mlir   |   3 +-
 mlir/test/lib/Dialect/ArmSME/CMakeLists.txt   |   1 +
 .../lib/Dialect/ArmSME/TestLowerToArmSME.cpp  |  19 +-
 27 files changed, 894 insertions(+), 440 deletions(-)
 create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
 delete mode 100644 mlir/test/Dialect/ArmSME/cse.mlir

diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab4999..403f811a2569 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
 void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab..e6d678dc1b12 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
 // ArmSMEToLLVM
 //===----------------------------------------------------------------------===//
 
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
   let summary = "Lower the operations from the ArmSME dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
     "arm_sme::ArmSMEDialect",
     "LLVM::LLVMDialect"
   ];
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a..dac54712c7f4 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 000000000000..f31062d8c25e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,28 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the op interfaces for the ArmSME dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+static constexpr unsigned kInMemoryTileIdBase = 16;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d..9178655f010c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
 
 def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   let description = [{
-    An interface for operations that use or allocate Arm SME tiles. These
-    operations need to be assigned a tile ID, an i32 attribute, which specifies
-    which virtual tile within the ZA storage to use. The number of tiles
-    available depends on the type of the tile. This is summarized below:
+    An interface for operations that use Arm SME tiles. These operations need to
+    be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+    within the ZA storage to use. The number of tiles available depends on the
+    type of the tile. This is summarized below:
 
     | Tile Vector Types                                                       | Possible Tile IDs   |
     |-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
     | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
     | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
     | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
-
-    Operations that allocate a new tile (such as arm_sme.get_tile), are used as
-    the roots for tile allocation, with all operations that (transitively)
-    depend on a root being assigned the same tile ID.
   }];
   let methods = [
     InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
         return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
       }]
     >,
-    InterfaceMethod<
-      [{
-        The type of tile this operation allocates. Returns none (std::nullopt)
-        if this operation does not allocate a tile.
-      }],
-      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
-      /*methodName=*/"getAllocatedTileType",
-      /*arguments=*/(ins),
-      /*methodBody=*/[{}],
-      /*defaultImpl=*/ [{
-        // This operation does not allocate a tile.
-        return std::nullopt;
-      }]
-    >,
     InterfaceMethod<
       "Returns the VectorType of the tile used by this operation.",
       /*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   ];
 
   let extraSharedClassDeclaration = [{
-    // A helper to create a new operation and propagate this operations tile ID.
-    template<typename T, typename... Args>
-    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
-      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
-      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
-        tileOp.setTileId($_op.getTileId());
-      return op;
-    }
-
-    // A helper to replace this operation and forward its tile ID (if present).
-    template<typename T, typename... Args>
-    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
-      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
-      rewriter.replaceOp($_op, newOp);
-      return newOp;
-    }
-
     bool isInMemoryTile() {
       auto tileId = getTileId();
       return tileId && tileId.getInt() >= kInMemoryTileIdBase;
     }
   }];
 
-  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+  let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
-  let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates an undefined value of SME virtual tile type";
   let description = [{
-    Allocates a new SME "virtual tile" within a function. The contents of the
-    tile returned from this operation are undefined.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are undefined.
 
     Example 1:
 
     ```mlir
-    // Allocate an 8-bit element "virtual tile"
+    // Create an 8-bit element "virtual tile" value:
     %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
     Example 2:
 
     ```mlir
-    // Allocate two 16-bit element "virtual tiles"
+    // Create two 16-bit element "virtual tile" values:
     %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
     Example 3:
     ```mlir
-    // Allocate an 128-bit element "virtual tile"
+    // Create a 128-bit element "virtual tile" value:
     %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
   }];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
     VectorType getTileType() {
       return ::llvm::cast<VectorType>(getTile().getType());
     }
-
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getTileType());
-    }
-  }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
-  let summary = "SME tile placeholder";
-  let description = [{
-    A placeholder to preserve dataflow while lowering to SME intrinsics (which
-    do not take or return SME virtual tile values). This operation is intended
-    to be DCE'd once all ArmSME operations have been lowered.
-
-    This operation is not intended to be used outside of the ArmSME -> LLVM
-    conversion.
   }];
-  let results = (outs SMETile:$tile);
-  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
-  let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates a zero-initialized value of SME virtual tile type";
   let results = (outs SMETile:$res);
   let description = [{
-    Initialise ZA with 0. This operation is convenient wrapper for the SME
-    `zero` intrinsic and instruction.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are zero-initialized.
 
     Example 1: Zero an 8-bit element ZA tile.
 
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+  Pure,
+  ArmSMETileOpInterface,
+  AllTypesMatch<["tile", "result"]>
+]> {
+  let summary = "Copies an SME tile value";
+  let arguments = (ins SMETile:$tile);
+  let results = (outs SMETile:$result);
+  let description = [{
+    Copies an SME "virtual tile" value to a new SSA value. This operation is
+    primarily intended to be used to normalize the IR prior to tile allocation.
+
+    Example:
+
+    ```mlir
+    %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+  let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
 def TileLoadOp : ArmSME_Op<"tile_load", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
 }
 
 def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
                                list<Type> allowedResultVectorTypes,
                                int numOuterProducts> :
   ArmSME_Op<mnemonic, [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
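The tile-ID ranges in the `ArmSMETileOpInterface` description follow
from how the ZA array is split up. An element type of N bits gives N/8
tiles: one 8-bit tile, up to sixteen 128-bit tiles. Tiles of different
sizes share the same storage, so an allocator has to treat overlapping
tiles as conflicts. The following C++ sketch models that aliasing
under an assumed encoding. Each tile is a 16-bit mask in which bit
(15 - k) stands for the 128-bit tile ZAk.Q. Tile n of a type with T
tiles covers every ZAk.Q with k % T == n. `TileType`, `getNumTiles`,
`getTileMask` and `tilesOverlap` are hypothetical names for
illustration, not functions from this patch.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the five SME tile element sizes (8 to 128 bits).
enum class TileType { ZAB, ZAH, ZAS, ZAD, ZAQ };

// Number of tiles of each type: element size in bits / 8.
constexpr unsigned getNumTiles(TileType type) {
  switch (type) {
  case TileType::ZAB: return 1;
  case TileType::ZAH: return 2;
  case TileType::ZAS: return 4;
  case TileType::ZAD: return 8;
  case TileType::ZAQ: return 16;
  }
  return 0;
}

// Bit (15 - k) of the mask represents the 128-bit tile ZAk.Q. Tile `tileId`
// of a type with T tiles covers every ZAk.Q where k % T == tileId.
std::uint16_t getTileMask(TileType type, unsigned tileId) {
  unsigned numTiles = getNumTiles(type);
  assert(tileId < numTiles && "tile ID out of range for this tile type");
  std::uint16_t mask = 0;
  for (unsigned k = tileId; k < 16; k += numTiles)
    mask |= static_cast<std::uint16_t>(1u << (15 - k));
  return mask;
}

// Two tiles conflict if they share any 128-bit tile of ZA storage.
bool tilesOverlap(TileType a, unsigned idA, TileType b, unsigned idB) {
  return (getTileMask(a, idA) & getTileMask(b, idB)) != 0;
}
```

For example, ZA0.S (0x8888) and ZA4.D (0x0808) overlap, so they could
not hold simultaneously live values. ZA0.S and ZA1.D (0x4040) are
disjoint.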
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874..156744ba57e7 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
 
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
 std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e892..b9d74fec6756 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,16 +124,21 @@ def EnableArmStreaming
   let dependentDialects = ["func::FuncDialect"];
 }
 
-def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
-  let summary = "Allocate SME tiles";
+def TestTileAllocation
+    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+  let summary = "Tests SME tile allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
-    that implement the `ArmSMETileOpInterface`. An error will be emitted when
-    there's no tiles left.
+    that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+    to be used for testing; tile allocation is done as part of the ArmSME to
+    LLVM conversion.
   }];
-  let constructor = "mlir::arm_sme::createTileAllocationPass()";
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
   let dependentDialects = ["func::FuncDialect"];
 }
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e699..a25b844f01ea 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_H
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
 namespace mlir {
 
 class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
 class RewritePatternSet;
 
 namespace arm_sme {
+
 void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+                               bool dumpRanges = false);
+
 } // namespace arm_sme
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92..9ea1c5a5d63f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include <optional>
 
 namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+inline bool isValidSMETileVectorType(Type type) {
+  auto vType = dyn_cast<VectorType>(type);
+  return vType && isValidSMETileVectorType(vType);
+}
+
 /// Returns the type of SME tile this vector type corresponds to, or none if the
 /// vector type does not fit within an SME tile.
 std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,19 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
 /// Creates a vector type for the SME tile of `elementType`.
 VectorType getSMETileTypeForElement(Type elementType);
 
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function);
+
+/// Returns true if `tileOp` can be cloned to resolve conflicts.
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `tileOp` produces a tile result.
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns the tile `OpOperand` for this `tileOp` (or null).
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 1ba1b88fc123..488f747af050 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -245,6 +246,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
     if (!tileOp.isInMemoryTile())
       return failure();
 
+    tileOp->emitWarning(
+        "failed to allocate SME virtual tile to operation, all tile "
+        "operations will go through memory, expect degraded performance");
+
     // Step 1. Create an alloca for the tile at the top of the function (if one
     // does not already exist).
     auto loc = tileOp.getLoc();
@@ -391,20 +396,6 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
   (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
 }
 
-struct GetTileConversion
-    : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
-                                          RequiresSpillsAndFills::No> {
-  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
-        getTile, getTile.getTileType());
-    return success();
-  }
-};
-
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
@@ -436,7 +427,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
     // The base mask is just the mask to zero the first tile (of a size).
     // These masks are derived from:
     // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
-    arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
+    arm_sme::ArmSMETileType tileType =
+        *arm_sme::getSMETileType(zero.getTileType());
     auto baseMaskForSize = [&] {
       switch (tileType) {
       case arm_sme::ArmSMETileType::ZAB:
@@ -488,8 +480,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
         loc, rewriter.getI32IntegerAttr(zeroMask));
 
     // Create a placeholder op to preserve dataflow.
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
-        zero, zero.getVectorType());
+    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
 
     return success();
   }
@@ -746,10 +737,12 @@ struct OuterProductOpConversion
     auto loc = outerProductOp.getLoc();
 
     Value acc = outerProductOp.getAcc();
-    if (!acc)
+    if (!acc) {
       // Initalize accumulator with zero.
-      acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, loc, resultVectorType);
+      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+      zero.setTileId(tileId);
+      acc = zero;
+    }
 
     Value lhsMask = outerProductOp.getLhsMask();
     Value rhsMask = outerProductOp.getRhsMask();
@@ -791,25 +784,27 @@ struct OuterProductWideningOpConversion
     if (!tileId)
       return failure();
 
+    auto loc = op.getLoc();
     Value acc = op.getAcc();
-    if (!acc)
+    if (!acc) {
       // Initalize accumulator with zero.
-      acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, op.getLoc(), op.getResultType());
+      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+      zero.setTileId(tileId);
+      acc = zero;
+    }
 
     Value lhsMask = op.getLhsMask();
     Value rhsMask = op.getRhsMask();
     if (!lhsMask || !rhsMask) {
       auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
       Value allActiveMask = rewriter.create<arith::ConstantOp>(
-          op.getLoc(), DenseElementsAttr::get(predTy, true));
+          loc, DenseElementsAttr::get(predTy, true));
       lhsMask = allActiveMask;
       rhsMask = allActiveMask;
     }
 
-    rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
-                                                rhsMask, adaptor.getLhs(),
-                                                adaptor.getRhs());
+    rewriter.create<OuterProductWideningIntrOp>(
+        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
 
     // The outerproduct intrinsics have no result, replace
     // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -865,15 +860,22 @@ namespace {
 
 struct ConvertArmSMEToLLVMPass
     : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+    this->dumpTileLiveRanges = dumpTileLiveRanges;
+  }
   void runOnOperation() override {
+    auto function = getOperation();
+
+    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+      return signalPassFailure();
+
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
     configureArmSMEToLLVMConversionLegality(target);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialConversion(function, target, std::move(patterns))))
       signalPassFailure();
   }
 };
@@ -883,7 +885,7 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
+      arm_sme::GetTileOp, arm_sme::CopyTileOp, arm_sme::aarch64_sme_zero,
       arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
       arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
       arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
@@ -955,9 +957,10 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                        arm_sme::aarch64_sme_usmopa_wide>,
       OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
                                        arm_sme::aarch64_sme_usmops_wide>,
-      ZeroOpConversion, GetTileConversion>(patterns, converter);
+      ZeroOpConversion>(patterns, converter);
 }
 
-std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
-  return std::make_unique<ConvertArmSMEToLLVMPass>();
+std::unique_ptr<Pass>
+mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+  return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
 }
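With this change, `ConvertArmSMEToLLVMPass` calls `arm_sme::allocateSMETiles` before running the conversion. As a rough, self-contained illustration of the linear-scan idea that the allocator in TileAllocation.cpp (further down) is built on, the sketch below models the 16-bit ZA tile mask and a greedy scan over live ranges sorted by start point. It is not the MLIR implementation: `TileType`, `Range`, `masksFor` and `linearScan` are hypothetical names, the bit layout (bit n stands for ZAn.Q) is an assumption of the sketch, live ranges have no holes, and a range that finds no free tile simply gets an in-memory ID (>= 16) rather than going through the PR's spill heuristics.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical element kinds; the value is the number of ZA tiles of that kind
// (ZA0.B; ZA0-1.H; ZA0-3.S; ZA0-7.D; ZA0-15.Q).
enum class TileType : unsigned { B = 1, H = 2, S = 4, D = 8, Q = 16 };

constexpr unsigned kInMemoryTileIdBase = 16;

// One mask per tile of `type`. Assumption of this sketch: bit n stands for
// ZAn.Q. Tile t of a kind with k tiles covers ZA(t + k*i).Q, so for example
// ZA6.D covers ZA6.Q and ZA14.Q.
std::vector<uint16_t> masksFor(TileType type) {
  const unsigned k = static_cast<unsigned>(type);
  std::vector<uint16_t> masks;
  for (unsigned t = 0; t < k; ++t) {
    uint16_t mask = 0;
    for (unsigned q = t; q < 16; q += k)
      mask |= static_cast<uint16_t>(1u << q);
    masks.push_back(mask);
  }
  return masks;
}

// A live range with a single half-open interval [start, end) (no holes).
struct Range {
  unsigned start;
  unsigned end;
  TileType type;
  std::optional<unsigned> tileId;
};

// Greedy linear scan in order of start point. Ranges that end at or before the
// current start release their tile first; a range that finds no free tile gets
// an in-memory ID (>= 16) instead of evicting anything.
void linearScan(std::vector<Range> &ranges) {
  std::sort(ranges.begin(), ranges.end(),
            [](const Range &a, const Range &b) { return a.start < b.start; });
  uint16_t tilesInUse = 0;
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
  std::vector<Range *> active;
  for (Range &r : ranges) {
    // Release tiles of live ranges that have ended.
    auto expired = [&](Range *a) {
      if (a->end > r.start)
        return false;
      uint16_t mask = masksFor(a->type)[*a->tileId];
      assert((tilesInUse & mask) == mask && "releasing an unallocated tile");
      tilesInUse ^= mask;
      return true;
    };
    active.erase(std::remove_if(active.begin(), active.end(), expired),
                 active.end());

    // Take the first tile whose 128-bit granules are all free.
    const std::vector<uint16_t> masks = masksFor(r.type);
    for (unsigned t = 0; t < masks.size() && !r.tileId; ++t) {
      if ((tilesInUse & masks[t]) == 0) {
        tilesInUse |= masks[t];
        r.tileId = t;
      }
    }
    if (r.tileId)
      active.push_back(&r);
    else
      r.tileId = nextInMemoryTileId++;
  }
}
```

For example, three 32-bit ranges [0, 10), [2, 5) and [5, 8) get tiles 0, 1 and 1: the second range ends exactly where the third begins, so ZA1.S is reused. The real allocator additionally handles live ranges with holes (via `llvm::IntervalMap`) and coalesced values, neither of which is modeled here.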
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 16b61c282749..9f55932c33af 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -196,12 +196,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
       // Initialize tile with zero to satisfy padding. Inactive cols will be
       // zeroed anyway since the loads use zeroing predication. For inactive
       // rows however, no load will occur so these need to be zeroed.
-      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, loc, tileType);
+      initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
     } else {
-      // Allocate a new SME tile.
-      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-          rewriter, loc, tileType);
+      initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
     }
 
     // Create a loop to load the active tile slices from memory.
@@ -212,10 +209,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
             Value currentTile) -> Value {
           // Create 'arm_sme.load_tile_slice' to load tile slice from memory
           // into tile.
-          return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-              rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
-              currentTile, memrefIndices, tileSliceIndex,
-              tileLoadOp.getLayout());
+          return rewriter.create<arm_sme::LoadTileSliceOp>(
+              loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
+              memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
         });
 
     if (failed(forOp))
@@ -292,9 +288,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), numCols);
 
-    // Allocate a new SME tile.
-    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-        rewriter, loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -339,10 +333,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
         /*passthru=*/pad1DOp);
 
     // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-    auto moveSlice =
-        tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
-            rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
-            tileSliceIndex, tileLoadOp.getLayout());
+    auto moveSlice = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+        loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
+        tileLoadOp.getLayout());
     rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -386,8 +379,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
         tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
         tileStoreOp.getMask(),
         [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
-          tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
-              rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+          rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+              tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
               predicate, tileStoreOp.getBase(), memrefIndices,
               tileStoreOp.getLayout());
         });
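The comment in `TileLoadOpConversion` explains why the initial tile is `arm_sme.zero` when a mask is present: the per-slice loads use zeroing predication, so inactive columns are cleared anyway, but no load is issued for inactive rows, which would otherwise keep whatever the tile held before. A plain C++ model of that reasoning is sketched below. The function `loadTileMasked` and the constant `kStale` are hypothetical, and a dense row-major `std::vector` stands in for both the ZA tile and memory; this only illustrates the semantics, not the IR the lowering emits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Stand-in for whatever a freshly obtained (not zeroed) tile happens to hold.
constexpr int kStale = -1;

// Loads the active numRows x numCols region of a row-major `mem` (with `stride`
// elements per row) into a tileSize x tileSize tile, one row slice at a time.
std::vector<int> loadTileMasked(const std::vector<int> &mem, std::size_t stride,
                                std::size_t tileSize, std::size_t numRows,
                                std::size_t numCols, bool zeroInit) {
  std::vector<int> tile(tileSize * tileSize, zeroInit ? 0 : kStale);
  // Only active rows are loaded; inactive rows keep the initial contents.
  for (std::size_t row = 0; row < numRows; ++row) {
    for (std::size_t col = 0; col < tileSize; ++col) {
      // Zeroing predication: inactive columns of a loaded row become 0.
      tile[row * tileSize + col] = col < numCols ? mem[row * stride + col] : 0;
    }
  }
  return tile;
}
```

With `zeroInit == false`, inactive columns of the loaded rows are still 0, but the inactive rows retain `kStale`, which is exactly the case the zero-initialized tile covers.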
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 29fa9085a0a9..cb3a66584487 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -20,6 +20,12 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
+namespace mlir::arm_sme::detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *op) {
+  return verifyOperationHasValidTileId(op);
+}
+} // namespace mlir::arm_sme::detail
+
 //===----------------------------------------------------------------------===//
 // Tablegen Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6a9e02218222..00d764bf5caf 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -116,4 +116,50 @@ VectorType getSMETileTypeForElement(Type elementType) {
   return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
 }
 
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function) {
+  SmallVector<Operation *> worklist;
+  function->walk([&](Operation *op) {
+    auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
+    if (armSMEOp && isOpTriviallyDead(armSMEOp))
+      worklist.push_back(armSMEOp);
+  });
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    for (Value value : op->getOperands()) {
+      if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
+        worklist.push_back(armSMEOp);
+    }
+    rewriter.eraseOp(op);
+  }
+}
+
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
+  return tileOp && tileOp->getNumResults() == 1 &&
+         tileOp->getNumOperands() == 0 && isPure(tileOp);
+}
+
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
+  for (Value result : tileOp->getResults()) {
+    if (arm_sme::isValidSMETileVectorType(result.getType()))
+      return true;
+  }
+  return false;
+}
+
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
+  auto isTileOperandType = [](OpOperand &operand) {
+    return arm_sme::isValidSMETileVectorType(operand.get().getType());
+  };
+  OpOperand *tileOperand =
+      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
+  assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
+         "expected at most one tile operand");
+  if (tileOperand == tileOp->getOpOperands().end())
+    return nullptr;
+  return tileOperand;
+}
+
 } // namespace mlir::arm_sme
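`eraseTriviallyDeadTileOps` is a worklist algorithm: it seeds the worklist with tile ops that are already trivially dead, and each time it erases one it re-queues the tile ops defining that op's operands, since they may have just lost their last use. The toy model below shows the same pattern outside MLIR. The `Op` struct with an explicit use count and side-effect flag is hypothetical and stands in for MLIR's `isOpTriviallyDead` check; unlike the real helper, it does not restrict itself to ops implementing `ArmSMETileOpInterface`.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy IR: each op lists the ops defining its operands and tracks
// how many live users its result has.
struct Op {
  std::vector<Op *> operands;
  unsigned numUses = 0;
  bool hasSideEffects = false;
  bool erased = false;
};

bool isTriviallyDead(const Op &op) {
  return !op.erased && op.numUses == 0 && !op.hasSideEffects;
}

// Erases dead ops, re-visiting operand definitions that may have become dead.
// Returns the number of ops erased.
unsigned eraseDeadOps(const std::vector<Op *> &ops) {
  std::vector<Op *> worklist;
  for (Op *op : ops)
    if (isTriviallyDead(*op))
      worklist.push_back(op);
  unsigned numErased = 0;
  while (!worklist.empty()) {
    Op *op = worklist.back();
    worklist.pop_back();
    if (!isTriviallyDead(*op))
      continue; // Already erased (queued twice) or still in use.
    for (Op *def : op->operands) {
      --def->numUses; // Erasing `op` drops one use of each operand.
      worklist.push_back(def);
    }
    op->erased = true;
    ++numErased;
  }
  return numErased;
}
```

Re-checking `isTriviallyDead` when an op is popped (as the real helper does) makes duplicate worklist entries harmless and avoids erasing an op that still has users.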
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4acb2a8fb7b5..e3cf52078a24 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass allocates SME tiles at the 'func.func' op level for ArmSME
+// This transform allocates SME tiles at the 'func.func' op level for ArmSME
 // operations. It does this using a 16-bit tile mask that has a bit for each
 // 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
 //
@@ -32,39 +32,31 @@
 //   ZA6.D   ZA6.Q, ZA14.Q
 //   ZA7.D   ZA7.Q, ZA15.Q
 //
-// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initalized during the first tile allocation within a function and
-// updated on each subsequent allocation.
-//
 // [1] https://developer.arm.com/documentation/ddi0616/aa
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Analysis/Liveness.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/IntervalMap.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
 
-#define DEBUG_TYPE "allocate-arm-sme-tiles"
-
-namespace mlir {
-namespace arm_sme {
-#define GEN_PASS_DEF_TILEALLOCATION
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_TESTTILEALLOCATION
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
 
 using namespace mlir;
 using namespace mlir::arm_sme;
 
 namespace {
 
-static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
-static constexpr StringLiteral
-    kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
-
 enum class TileMask : unsigned {
   // clang-format off
   kZA0B  = 0xffff, // 1111 1111 1111 1111
@@ -137,172 +129,510 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
-                                          TileMask &tilesInUse) {
-  auto masks = getMasks(tileType);
-  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
-    if ((tilesInUse & tileMask) == TileMask::kNone) {
-      tilesInUse |= tileMask;
-      return tileId;
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
     }
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) != TileMask::kNone &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
+  }
+
+  /// Allocates an in-memory tile ID.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
   }
-  return failure();
-}
 
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
-                             SetVector<Operation *> &dependantOps) {
-  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
-    for (auto [idx, value] : llvm::enumerate(inputValues)) {
-      if (value == rootValue)
-        findDependantOps(exitValues[idx], dependantOps);
+private:
+  TileMask tilesInUse = TileMask::kNone;
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+// Add new intermediate blocks for the true and false destinations of a
+// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
+// branches.
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
     }
+  });
+
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
   };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    condBranch.getFalseDestOperandsMutable().clear();
+    condBranch.getTrueDestOperandsMutable().clear();
+    condBranch.setSuccessor(newTrueBranch, 0);
+    condBranch.setSuccessor(newFalseBranch, 1);
+  }
+}
+
+/// Inserts tile copies for the tile operands of `cf.br` operations.
+void insertCopiesAtBranches(IRRewriter &rewriter,
+                            FunctionOpInterface function) {
+  splitCondBranches(rewriter, function);
+  for (Block &block : function.getBlocks()) {
+    Operation *terminator = block.getTerminator();
+    if (!isa<cf::BranchOp>(terminator))
       continue;
-    dependantOps.insert(user);
-    TypeSwitch<Operation *>(user)
-        .Case<cf::BranchOp>([&](auto branchOp) {
-          // (CF) Follow branch.
-          traverseCorrespondingValues(branchOp.getDestOperands(),
-                                      branchOp.getDest()->getArguments());
-        })
-        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
-          // (CF) Follow true branch.
-          traverseCorrespondingValues(
-              condBranchOp.getTrueOperands(),
-              condBranchOp.getTrueDest()->getArguments());
-          // (CF) Follow false branch.
-          traverseCorrespondingValues(
-              condBranchOp.getFalseOperands(),
-              condBranchOp.getFalseDest()->getArguments());
-        })
-        .Case<LoopLikeOpInterface>([&](auto loopOp) {
-          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
-          traverseCorrespondingValues(loopOp.getInits(),
-                                      loopOp.getRegionIterArgs());
-        })
-        .Case<scf::YieldOp>([&](auto yieldOp) {
-          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
-          auto parent = user->getParentOp();
-          traverseCorrespondingValues(user->getOperands(),
-                                      parent->getResults());
+    rewriter.setInsertionPoint(terminator);
+    for (OpOperand &operand : terminator->getOpOperands()) {
+      if (isValidSMETileVectorType(operand.get().getType())) {
+        auto copy =
+            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+        operand.assign(copy);
+      }
+    }
+  }
+}
+
+/// A range where a tile value is live. The range may contain holes.
+struct LiveRange {
+  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange`. Aborts if the ranges overlap.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  unsigned start() const { return ranges->start(); }
+  unsigned end() const { return ranges->stop(); }
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  std::unique_ptr<RangeSet> ranges;
+  SetVector<Value> values;
+  std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  unsigned index = 0;
+  SetVector<Block *> blocks =
+      getTopologicallySortedBlocks(function.getFunctionBody());
+  DenseMap<Operation *, unsigned> operationToIndexMap;
+  for (Block *block : blocks) {
+    index++; // We want block args to have their own number.
+    for (Operation &op : block->getOperations()) {
+      // This is only correct if no ArmSME ops are inside nested regions.
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        if (&op != nestedOp.getOperation()) {
+          assert(false &&
+                 "ArmSME tile allocation does not support nested regions");
+        }
+      });
+#endif
+      operationToIndexMap.try_emplace(&op, index++);
+    }
+  }
+  return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  DenseMap<Value, LiveRange> liveRanges;
+  auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
+                              LivenessBlockInfo const &livenessInfo,
+                              bool liveAtBlockEntry = false) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    unsigned start =
+        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+    unsigned end = operationToIndexMap.at(lastUseInBlock);
+    it->second.insert(value, start, end);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
+    // Handle block arguments:
+    for (Value argument : block.getArguments())
+      updateLiveRanges(argument, &block.front(), *livenessInfo,
+                       /*liveAtBlockEntry=*/true);
+    // Handle live-ins:
+    for (Value liveIn : livenessInfo->in())
+      updateLiveRanges(liveIn, &block.front(), *livenessInfo,
+                       /*liveAtBlockEntry=*/true);
+    // Handle new definitions:
+    for (Operation &op : block) {
+      for (Value result : op.getResults())
+        updateLiveRanges(result, &op, *livenessInfo);
+    }
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+                                        function_ref<void(Value)> callback) {
+  Block *block = blockArg.getOwner();
+  unsigned argNumber = blockArg.getArgNumber();
+  for (Block *pred : block->getPredecessors()) {
+    TypeSwitch<Operation *>(pred->getTerminator())
+        .Case<cf::BranchOp>([&](auto branch) {
+          Value predecessorOperand = branch.getDestOperands()[argNumber];
+          callback(predecessorOperand);
         })
-        .Default([&](auto) {
-          // Otherwise, assume users of _any_ result are dependant.
-          for (Value result : user->getResults())
-            findDependantOps(result, dependantOps);
+        .Case<cf::CondBranchOp>([&](auto condBranch) {
+          if (condBranch.getFalseDest() == block) {
+            Value predecessorOperand =
+                condBranch.getFalseDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
+          if (condBranch.getTrueDest() == block) {
+            Value predecessorOperand =
+                condBranch.getTrueDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
         });
   }
 }
-struct AssignTileIDsPattern
-    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
-                                PatternRewriter &rewriter) const override {
-    if (tileOp.getTileId())
-      return failure();
-
-    auto func = tileOp->getParentOfType<FunctionOpInterface>();
-    auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
-      if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
-              func->getDiscardableAttr(name)))
-        return unsigned(attr.getInt());
-      return defaultVal;
+
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+  DenseMap<Value, LiveRange *> liveRanges;
+  for (auto &[value, liveRange] : initialLiveRanges) {
+    liveRanges.insert({value, &liveRange});
+  }
+
+  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+    LiveRange *aLiveRange = liveRanges.at(a);
+    LiveRange *bLiveRange = liveRanges.at(b);
+    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+      aLiveRange->unionWith(*bLiveRange);
+      for (Value value : bLiveRange->values)
+        liveRanges[value] = aLiveRange;
+    }
+  };
+
+  // Merge the live ranges of new definitions with their tile operands.
+  auto unifyDefinitionsWithOperands = [&](Value value) {
+    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+    if (!armSMEOp)
+      return;
+    for (auto operand : armSMEOp->getOperands()) {
+      if (isValidSMETileVectorType(operand.getType()))
+        mergeValuesIfNonOverlapping(value, operand);
+    }
+  };
+
+  // Merge the live ranges of block arguments with their predecessors.
+  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+    auto blockArg = dyn_cast<BlockArgument>(value);
+    if (!blockArg)
+      return;
+    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+    });
+  };
+
+  auto applyRule = [&](auto rule) {
+    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+  };
+
+  // Unify as many live ranges as we can. This prevents unnecessary moves.
+  applyRule(unifyBlockArgumentsWithPredecessors);
+  applyRule(unifyDefinitionsWithOperands);
+
+  // Remove duplicate live range entries.
+  SetVector<LiveRange *> uniqueLiveRanges;
+  for (auto [_, liveRange] : liveRanges) {
+    if (!liveRange->empty())
+      uniqueLiveRanges.insert(liveRange);
+  }
+
+  // Sort the new live ranges by starting point (ready for tile allocation).
+  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+            [](LiveRange *a, LiveRange *b) { return *a < *b; });
+  return std::move(coalescedLiveRanges);
+}
+
+/// Greedily allocate tile IDs to live ranges, spilling via simple heuristics.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+  TileAllocator tileAllocator;
+  SetVector<LiveRange *> allocatedRanges;
+
+  auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
+    unsigned memoryTileId = tileAllocator.allocateInMemoryTileId();
+    auto spillActiveRange = [&](LiveRange *range) {
+      unsigned tileId = *range->tileId;
+      range->tileId = memoryTileId;
+      allocatedRanges.remove(range);
+      return tileId;
     };
-    auto setDiscardableIntAttr = [&](StringRef name, auto value) {
-      rewriter.modifyOpInPlace(tileOp, [&] {
-        func->setDiscardableAttr(name,
-                                 rewriter.getI32IntegerAttr((unsigned)value));
-      });
+
+    auto isTrivialSpill = [](LiveRange *allocatedRange) {
+      return allocatedRange->values.size() == 1 &&
+             isTriviallyCloneableTileOp(
+                 allocatedRange->values[0]
+                     .getDefiningOp<ArmSMETileOpInterface>());
     };
 
-    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
-    if (!tileType)
-      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
-    TileMask tilesInUse =
-        static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
-    auto tileId = allocateTileId(*tileType, tilesInUse);
-    bool tileIsInMemory = failed(tileId);
-    if (tileIsInMemory) {
-      // If we could not find a real tile ID, use an in-memory tile ID (ID >=
-      // 16). A later pass will insert the necessary spills and reloads.
-      tileId =
-          getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
-      tileOp->emitWarning(
-          "failed to allocate SME virtual tile to operation, all tile "
-          "operations will go through memory, expect degraded performance");
-    }
+    // Heuristic: Spill trivially cloneable operations (usually free).
+    if (isTrivialSpill(newRange))
+      return memoryTileId;
+    auto trivialSpill = llvm::find_if(allocatedRanges, isTrivialSpill);
+    if (trivialSpill != allocatedRanges.end())
+      return spillActiveRange(*trivialSpill);
+
+    // Heuristic: Spill the live range that ends last.
+    LiveRange *lastActiveLiveRange = *std::max_element(
+        allocatedRanges.begin(), allocatedRanges.end(),
+        [](LiveRange *a, LiveRange *b) { return a->end() < b->end(); });
+    if (lastActiveLiveRange->end() >= newRange->end())
+      return spillActiveRange(lastActiveLiveRange);
+
+    return memoryTileId;
+  };
 
-    // Set all operations dependent on `tileOp` to use the same tile ID.
-    // This is a naive tile allocation scheme, but works for common cases. For
-    // example, as this only allocates tile IDs to existing ops, it can't solve
-    // cases like this (%tileA and %tileB come from different root operations):
-    //
-    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
-    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
-    // } else {
-    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
-    // }
-    //
-    // This case would require allocating a new tile for the result of the
-    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
-    // on the %some_cond).
-    // Find all the ops that (transitively) depend on this tile.
-    SetVector<Operation *> dependantOps;
-    findDependantOps(tileOp->getResult(0), dependantOps);
-    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        auto currentTileId = dependantTileOp.getTileId();
-        if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
-          return dependantTileOp.emitOpError(
-              "already assigned different SME virtual tile!");
+  for (LiveRange *newRange : liveRanges) {
+    // Release tiles from live ranges that have ended.
+    allocatedRanges.remove_if([&](LiveRange *allocatedRange) {
+      if (allocatedRange->end() <= newRange->start()) {
+        tileAllocator.releaseTileId(allocatedRange->getTileType(),
+                                    *allocatedRange->tileId);
+        return true;
       }
-    }
+      return false;
+    });
 
-    // Rewrite IR.
-    if (!tileIsInMemory)
-      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+    // Allocate a tile ID to `newRange`.
+    auto tileId = tileAllocator.allocateTileId(newRange->getTileType());
+    if (succeeded(tileId))
+      newRange->tileId = *tileId;
     else
-      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
-    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        rewriter.modifyOpInPlace(
-            dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
+      newRange->tileId = chooseSpillUsingHeuristics(newRange);
+
+    // Insert the live range into the allocated ranges.
+    if (newRange->tileId < kInMemoryTileIdBase)
+      allocatedRanges.insert(newRange);
+  }
+}
+
+/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
+LogicalResult assignTileIdsAndResolveTrivialConflicts(
+    IRRewriter &rewriter, FunctionOpInterface function,
+    ArrayRef<LiveRange *> allocatedLiveRanges) {
+  for (LiveRange const *liveRange : allocatedLiveRanges) {
+    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
+    auto isAllocatedToSameTile = [&](Value value) {
+      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+          tileOp && tileOp.getTileId() == tileIdAttr)
+        return true;
+      return liveRange->values.contains(value);
+    };
+    for (Value value : liveRange->values) {
+      for (Operation *user : value.getUsers()) {
+        if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+          // Ensure ArmSME ops that don't produce a value still get a tile ID.
+          if (!hasTileResult(tileOp))
+            tileOp.setTileId(tileIdAttr);
+        }
+      }
+      auto copyOp = value.getDefiningOp<CopyTileOp>();
+      if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
+        // Fold redundant copies.
+        rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
+      } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
+        tileOp.setTileId(tileIdAttr);
+        // Rectify operand tile IDs with result tile IDs.
+        OpOperand *tileOperand = getTileOpOperand(tileOp);
+        if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
+          continue;
+        auto operandTileOp =
+            tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+        if (!isTriviallyCloneableTileOp(operandTileOp))
+          return tileOp.emitOpError("failed to rectify tile operand with tile "
+                                    "result (move required)");
+        // Cloning prevents a move/spill (though may require recomputation).
+        rewriter.setInsertionPoint(tileOp);
+        auto clonedOp = operandTileOp.clone();
+        clonedOp.setTileId(tileOp.getTileId());
+        rewriter.insert(clonedOp);
+        if (copyOp)
+          rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
+        else
+          tileOperand->assign(clonedOp->getResult(0));
+      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+        // Validate block arguments.
+        bool tileMismatch = false;
+        forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+          if (tileMismatch)
+            return;
+          if (!isAllocatedToSameTile(predecessorTile)) {
+            blockArg.getOwner()->getParentOp()->emitOpError(
+                "block argument not allocated to the same tile as "
+                "predecessors");
+            tileMismatch = true;
+          }
+        });
+        if (tileMismatch)
+          return failure();
       }
     }
+  }
+  return success();
+}
 
-    return success();
+/// Prints live ranges alongside operation names for debugging.
+void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                    ArrayRef<LiveRange const *> liveRanges,
+                    FunctionOpInterface function) {
+  llvm::errs() << "SME Tile Liveness: @" << function.getName()
+               << "\nKey:\nS - Start\nE - End\n| - Live\n";
+  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
+    llvm::errs() << "^bb" << blockIdx << ":\n";
+    for (Operation &op : block.getOperations()) {
+      unsigned operationIndex = operationToIndexMap.at(&op);
+      for (LiveRange const *range : liveRanges) {
+        char liveness = ' ';
+        for (auto it = range->ranges->begin(); it != range->ranges->end();
+             ++it) {
+          if (it.start() == operationIndex)
+            liveness = (liveness == 'E' ? '|' : 'S');
+          else if (it.stop() == operationIndex)
+            liveness = (liveness == 'S' ? '|' : 'E');
+          else if (operationIndex >= it.start() && operationIndex < it.stop())
+            liveness = '|';
+        }
+        llvm::errs() << liveness;
+      }
+      llvm::errs() << ' ' << op.getName() << '\n';
+    }
   }
-};
+  llvm::errs() << "==========\n";
+}
 
-struct TileAllocationPass
-    : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+struct TestTileAllocationPass
+    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
+  using TestTileAllocationBase::TestTileAllocationBase;
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<AssignTileIDsPattern>(patterns.getContext());
-    GreedyRewriteConfig config;
-    // Setting useTopDownTraversal ensures tiles are allocated in program
-    // order.
-    config.useTopDownTraversal = true;
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(patterns), config))) {
+    if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
       signalPassFailure();
-    }
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
-  return std::make_unique<TileAllocationPass>();
+LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
+                                              bool dumpRanges) {
+  LiveRange::Allocator liveRangeAllocator;
+  IRRewriter rewriter(function.getContext());
+
+  // 1. Insert copy operations at branch operations.
+  insertCopiesAtBranches(rewriter, function);
+
+  // 2. Gather live ranges for each ArmSME tile within the function.
+  Liveness liveness(function);
+  auto operationToIndexMap = generateOperationNumbering(function);
+  auto initialLiveRanges = gatherTileLiveRanges(
+      operationToIndexMap, liveRangeAllocator, liveness, function);
+  if (initialLiveRanges.empty())
+    return success();
+
+  if (dumpRanges) {
+    // Wrangle initial live ranges into a form suitable for printing.
+    auto nonEmpty = llvm::make_filter_range(
+        llvm::make_second_range(initialLiveRanges),
+        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
+    auto initialRanges = llvm::to_vector(llvm::map_range(
+        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
+    std::sort(initialRanges.begin(), initialRanges.end(),
+              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
+    llvm::errs() << "\n========== Initial Live Ranges:\n";
+    dumpLiveRanges(operationToIndexMap, initialRanges, function);
+  }
+
+  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
+  // for tile allocation, e.g. unify the result of an operation with its
+  // operands.
+  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
+
+  if (dumpRanges) {
+    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
+    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
+  }
+
+  // 4. Allocate tile IDs to live ranges.
+  allocateTilesToLiveRanges(coalescedLiveRanges);
+
+  // 5. Assign the tile IDs back to the ArmSME operations.
+  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
+                                                     coalescedLiveRanges))) {
+    return failure();
+  }
+
+  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no users).
+  // This prevents the LLVM conversion from needlessly inserting spills.
+  eraseTriviallyDeadTileOps(rewriter, function);
+  return success();
 }
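The greedy linear scan behind steps 3 and 4 can be illustrated with a small standalone model. This is a sketch under simplifying assumptions, not the PR's implementation: each live range is a single closed interval of operation indices (the PR's `LiveRange` can hold several), every range uses one tile type with four tiles (as for 32-bit elements), and `Range`, `allocateLinearScan`, `kNumTiles` and `kInMemoryTileIdStart` are hypothetical names. When no tile is free this sketch simply spills the current range; as the `@avoidable_spill` test notes, the PR's allocator may pick a different range to spill.

```cpp
#include <algorithm>
#include <bitset>
#include <cstddef>
#include <vector>

// Hypothetical, simplified live range: a single closed interval [start, end]
// of operation indices. (The PR's live ranges may hold several intervals.)
struct Range {
  unsigned start;
  unsigned end;
  int tileId = -1; // -1 until a tile ID has been assigned.
};

constexpr std::size_t kNumTiles = 4;     // Assumption: four 32-bit tiles.
constexpr int kInMemoryTileIdStart = 16; // IDs >= 16 denote in-memory tiles.

// Greedy linear scan: visit ranges by start point, release the tiles of ranges
// that ended before the current one starts, then take the lowest free tile.
// If no tile is free, this sketch spills the current range.
void allocateLinearScan(std::vector<Range> &ranges) {
  std::vector<Range *> order;
  for (Range &range : ranges)
    order.push_back(&range);
  std::stable_sort(order.begin(), order.end(),
                   [](const Range *a, const Range *b) { return a->start < b->start; });

  std::bitset<kNumTiles> inUse;
  std::vector<Range *> active; // Ranges currently holding a physical tile.
  int nextInMemoryId = kInMemoryTileIdStart;

  for (Range *current : order) {
    // Move ranges that have ended to the back of `active` and free their tiles.
    auto firstExpired =
        std::partition(active.begin(), active.end(),
                       [&](const Range *r) { return r->end >= current->start; });
    for (auto it = firstExpired; it != active.end(); ++it)
      inUse.reset(static_cast<std::size_t>((*it)->tileId));
    active.erase(firstExpired, active.end());

    // Pick the lowest-numbered free tile.
    std::size_t tile = 0;
    while (tile < kNumTiles && inUse.test(tile))
      ++tile;
    if (tile == kNumTiles) {
      // Out of tiles: give the current range an in-memory tile ID.
      current->tileId = nextInMemoryId++;
      continue;
    }
    inUse.set(tile);
    current->tileId = static_cast<int>(tile);
    active.push_back(current);
  }
}
```

With four ranges that all end by operation 10, followed by four that start at 11 or later, the second group reuses tiles 0 to 3 (the pattern checked in `@reuse_tiles_after_initial_use`), while a fifth range overlapping four live ones receives ID 16.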
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index f48046a8d799..292294046021 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
-
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file -verify-diagnostics | FileCheck %s
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index a9c1a65a296f..2c3868d7f25c 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \
 // RUN: FileCheck %s  --check-prefix=AFTER-TILE-ALLOC
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \
 // RUN:   -split-input-file -verify-diagnostics | \
 // RUN: FileCheck %s  --check-prefix=AFTER-LLVM-LOWERING
 
@@ -56,6 +56,9 @@ func.func @use_too_many_tiles() {
   %1 = arm_sme.zero : vector<[4]x[4]xi32>
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %2 = arm_sme.zero : vector<[8]x[8]xi16>
+  "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 // AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
@@ -131,18 +134,16 @@ func.func @use_too_many_tiles() {
 /// Note: In this example an entire tile swap is inserted before/after the
 /// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
 /// single tile slice (and can omit the initial load, like in the previous example).
-func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
-  %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %mask = vector.constant_mask [4] : vector<[4]xi1>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
-  "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
+  "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
+  return %loadSlice : vector<[4]x[4]xf32>
 }
 // AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
-//      AFTER-TILE-ALLOC: arm_sme.get_tile
-// AFTER-TILE-ALLOC-SAME:   tile_id = 0
 //      AFTER-TILE-ALLOC: arm_sme.load_tile_slice
 // AFTER-TILE-ALLOC-SAME:   tile_id = 16
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 15767ff1dec3..8ca71e6bf783 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" \
+// RUN: -split-input-file -allow-unregistered-dialect -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index e144bac970a7..eb64a7b6aac5 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
 
 // -----
 
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index b7ba3f728c70..d9e3d66e370e 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,18 +1,16 @@
 // RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
 
-// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
-// once it becomes unused, after lowering to control flow.
+// This tests that dead tile values are removed from control flow.
 
 // -----
 
-// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
-// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
 // CHECK-NOT: vector<[4]x[4]xf32>
-func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
   %c10 = arith.constant 10 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
 ^bb1(%1: index, %2: vector<[4]x[4]xf32>):  // 2 preds: ^bb0, ^bb2
   %3 = arith.cmpi slt, %1, %c10 : index
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
deleted file mode 100644
index 74e7293eaeca..000000000000
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
-// duplicates.
-
-// CHECK-LABEL: @zero_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @zero_tile() {
-  %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
-  %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
-  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
-
-// CHECK-LABEL: @get_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @get_tile() {
-  %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
-  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
-  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 39d9ab6491e3..06be7bd97470 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,19 +1,17 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics
 
 // -----
 
-func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %cond: i1) {
+// Select between tileA and tileB. This is currently unsupported as it would
+// require inserting (runtime) tile moves.
+func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %tileA : vector<[4]x[4]xi32>, %tileB : vector<[4]x[4]xi32>, %cond: i1) {
   %c0 = arith.constant 0 : index
-  %tileA = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %tileB = arm_sme.get_tile : vector<[4]x[4]xi32>
-  // Select between tileA and tileB. This is currently unsupported as it would
-  // require inserting tile move operations during tile allocation.
+  // expected-error@+1 {{op failed to rectify tile operand with tile result (move required)}}
   %tile = scf.if %cond -> vector<[4]x[4]xi32> {
     scf.yield %tileA : vector<[4]x[4]xi32>
   } else {
     scf.yield %tileB : vector<[4]x[4]xi32>
   }
-  // expected-error@+1 {{op already assigned different SME virtual tile!}}
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
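The `move required` error above comes from the allocator's check that each tile operand ends up in the same tile as the op's result. A simplified, hypothetical model of that decision follows; the enum and function names are illustrative, and plain tile IDs plus a boolean stand in for the PR's `isAllocatedToSameTile` and `isTriviallyCloneableTileOp` queries.

```cpp
// Hypothetical outcome of checking one tile operand against its op's result.
enum class RectifyAction { None, CloneOperandDef, Error };

// Simplified model of the operand check in the allocator:
//  - an operand already in the result's tile needs no change;
//  - otherwise, if the operand's defining op is trivially cloneable (such as
//    the arm_sme.zero duplicated in @run_out_of_tiles_but_avoid_spill), it is
//    re-created in the result's tile, trading a move for recomputation;
//  - anything else would need a runtime tile move, which is unsupported.
RectifyAction rectifyTileOperand(int resultTileId, int operandTileId,
                                 bool operandDefIsTriviallyCloneable) {
  if (operandTileId == resultTileId)
    return RectifyAction::None;
  if (operandDefIsTriviallyCloneable)
    return RectifyAction::CloneOperandDef;
  return RectifyAction::Error; // "failed to rectify ... (move required)"
}
```

In this test `%tileA` and `%tileB` are function arguments, so there is presumably no defining tile op that could be cloned, which corresponds to the error case of the model.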
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 2dedcb2fbc24..dd2ee55d2afc 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -1,18 +1,24 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK-BAD
-
-// This file tests some aspects of liveness issues in the SME tile allocator.
-// These tests were designed with a new liveness-based tile allocator in mind
-// (where the names of test cases make more sense), with the current tile
-// allocator these tests all give incorrect results (which is documented by
-// `CHECK-BAD`).
-
-// Incorrect result! The second `move_vector_to_tile_slice` overwrites the first (which is still live).
-//
-// CHECK-BAD-LABEL: @constant_with_multiple_users
-// CHECK-BAD: %[[ZERO_TILE:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation=dump-tile-live-ranges -mlir-disable-threading -split-input-file -verify-diagnostics 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-LIVE-RANGE
+
+// This file tests some simple aspects of using liveness in the SME tile allocator.
+
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @constant_with_multiple_users
+//       CHECK-LIVE-RANGE: ^bb0:
+//       CHECK-LIVE-RANGE: S  arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
+
+// CHECK-LABEL: @constant_with_multiple_users(
+// CHECK-SAME:                                %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
 func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
+  // CHECK-NEXT: %[[ZERO_TILE_0:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[ZERO_TILE_1:.*]] = arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_A]], %[[ZERO_TILE_1]], %{{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_B]], %[[ZERO_TILE_0]], %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
   %zero = arm_sme.zero : vector<[4]x[4]xf32>
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
@@ -23,12 +29,16 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
 
 // -----
 
-// (No tile IDs -- the current tile allocator ignores this case)
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @value_with_multiple_users
+//       CHECK-LIVE-RANGE: ^bb0:
+//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
-// CHECK-BAD-LABEL: @value_with_multiple_users
-// CHECK-BAD-NOT: tile_id
 func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
-  // A future allocator should error here (as `%tile` would need to be copied).
+  // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
@@ -38,12 +48,38 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
 
 // -----
 
-// CHECK-BAD-LABEL: @reuse_tiles_after_initial_use
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @reuse_tiles_after_initial_use
+//       CHECK-LIVE-RANGE: ^bb0:
+//  CHECK-LIVE-RANGE-NEXT: S        arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: |S       arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: ||S      arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: |||S     arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//  CHECK-LIVE-RANGE-NEXT: E|||     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:  E||     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:   E|     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:    E     test.some_use
+//  CHECK-LIVE-RANGE-NEXT:     S    arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     |S   arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     ||S  arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     |||S arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//  CHECK-LIVE-RANGE-NEXT:     E||| test.some_use
+//  CHECK-LIVE-RANGE-NEXT:      E|| test.some_use
+//  CHECK-LIVE-RANGE-NEXT:       E| test.some_use
+//  CHECK-LIVE-RANGE-NEXT:        E test.some_use
+
+// CHECK-LABEL: @reuse_tiles_after_initial_use
 func.func @reuse_tiles_after_initial_use() {
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 0 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 2 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 3 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 2 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 3 : i32}
   %tile_a = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_b = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_c = arm_sme.get_tile : vector<[4]x[4]xf32>
@@ -55,19 +91,13 @@ func.func @reuse_tiles_after_initial_use() {
   "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
   "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
   "test.some_use"(%tile_d) : (vector<[4]x[4]xf32>) -> ()
-  // -> Spills after the fourth tile (unnecessary):
-  // CHECK-BAD: arm_sme.zero {tile_id = 16 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 17 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 18 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 19 : i32}
-  // Unnecessary spills:
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // CHECK: arm_sme.zero {tile_id = 0 : i32}
+  // CHECK: arm_sme.zero {tile_id = 1 : i32}
+  // CHECK: arm_sme.zero {tile_id = 2 : i32}
+  // CHECK: arm_sme.zero {tile_id = 3 : i32}
   %tile_1 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_3 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_4 = arm_sme.zero : vector<[4]x[4]xf32>
   "test.dummy"(): () -> ()
   "test.dummy"(): () -> ()
@@ -81,16 +111,27 @@ func.func @reuse_tiles_after_initial_use() {
 
 // -----
 
-// Incorrect result! Both branches should yield the result via the same tile.
-//
-// CHECK-BAD-LABEL: @non_overlapping_branches
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @non_overlapping_branches
+//       CHECK-LIVE-RANGE: ^bb1:
+//  CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: E cf.br
+//  CHECK-LIVE-RANGE-NEXT: ^bb2:
+//  CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: E cf.br
+
+// CHECK-LABEL: @non_overlapping_branches
 func.func @non_overlapping_branches(%cond: i1) {
+  // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
   %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+    // ^bb1:
     %zero = arm_sme.zero : vector<[4]x[4]xf32>
     scf.yield %zero : vector<[4]x[4]xf32>
   } else {
+    // ^bb2:
     %undef = arm_sme.get_tile : vector<[4]x[4]xf32>
     scf.yield %undef : vector<[4]x[4]xf32>
   }
@@ -100,13 +141,15 @@ func.func @non_overlapping_branches(%cond: i1) {
 
 // -----
 
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @constant_loop_init_with_multiple_users
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// <deliberately omitted>
+
+// CHECK-LABEL: @constant_loop_init_with_multiple_users
 func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
+  // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+  // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
@@ -126,26 +169,46 @@ func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vecto
 
 // -----
 
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @run_out_of_tiles_but_avoid_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @run_out_of_tiles_but_avoid_spill
+//       CHECK-LIVE-RANGE: ^bb2:
+//  CHECK-LIVE-RANGE-NEXT: |S    arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: ||S   arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: |||S  arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+//  CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+
+// Note in the live ranges (above) there are five tile values, but we only have four tiles.
+
+// CHECK-LABEL: @run_out_of_tiles_but_avoid_spill
 func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
   %init = arm_sme.zero : vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
+  // Live = %init
   scf.for %i = %c0 to %c10 step %c1 {
+    // CHECK: arm_sme.zero {tile_id = 1 : i32}
+    // CHECK: arm_sme.zero {tile_id = 2 : i32}
+    // CHECK: arm_sme.zero {tile_id = 3 : i32}
+    // CHECK: arm_sme.zero {tile_id = 0 : i32}
     %tile_a, %tile_b, %tile_c, %tile_d = scf.for %j = %c0 to %c10 step %c1
       iter_args(%iter_a = %init, %iter_b = %init, %iter_c = %init, %iter_d = %init)
         -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32> , vector<[4]x[4]xf32> , vector<[4]x[4]xf32>) {
+        // ^bb2:
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 2 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 3 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+        // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_a = arm_sme.move_vector_to_tile_slice %a, %iter_a, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_b = arm_sme.move_vector_to_tile_slice %b, %iter_b, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_c = arm_sme.move_vector_to_tile_slice %c, %iter_c, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         %new_d = arm_sme.move_vector_to_tile_slice %d, %iter_d, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
         scf.yield %new_a, %new_b, %new_c, %new_d : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
     }
+    // Live = %init, %tile_a, %tile_b, %tile_c, %tile_d (out of tiles!)
+    // This should be resolved by duplicating the arm_sme.zero (from folding
+    // arm_sme.copy_tile operations inserted by the tile allocator).
     "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
@@ -156,24 +219,47 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
 
 // -----
 
-// Incorrect result! Everything other than zero assigned to tile 1 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @avoidable_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32}
+// We should be able to avoid spills like this, but logic handling this case is
+// not implemented yet. Note tile ID >= 16 means a spill/in-memory tile.
+
+//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//  CHECK-LIVE-RANGE-NEXT: @avoidable_spill
+//       CHECK-LIVE-RANGE: ^bb2:
+//  CHECK-LIVE-RANGE-NEXT: ||     test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||S    arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |||S   arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: ||||S  arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+//  CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+//  CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||  E| test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||   E test.some_use
+//  CHECK-LIVE-RANGE-NEXT: ||     arith.addi
+//  CHECK-LIVE-RANGE-NEXT: EE     cf.br
+
+// Note in the live ranges (above) there are two constant live-ins (first two ranges),
+// which gives six overlapping live ranges. The allocator currently will spill the
+// first constant (which results in a real spill at its use); however, this could
+// be avoided by using the knowledge that at the first "test.some_use" there are
+// actually only two live ranges (so we can fix this by duplicating the constant).
+
+// CHECK-LABEL: @avoidable_spill
 func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
+  // CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[4]x[4]xf32>
   %zero = arm_sme.zero : vector<[4]x[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   scf.for %i = %c0 to %c10 step %c1 {
+    // So %zero is spilled here (unnecessarily).
+    // The arm_sme.zero op could be moved into the loop to avoid this.
     "test.some_use"(%zero) : (vector<[4]x[4]xf32>) -> ()
     %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_c = arm_sme.move_vector_to_tile_slice %c, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_d = arm_sme.move_vector_to_tile_slice %d, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
+    // %zero is still live here (due to the backedge)
     "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index cac2dcc24d10..ca339be5fb56 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,canonicalize))" | FileCheck %s
 
 // This test verifies the tile mask operand of the zero intrinsic zeroes
 // the correct tiles. Both integer and floating-point datatypes are checked.
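The `CHECK-LIVE-RANGE` charts in `tile-allocation-liveness.mlir` are printed by `dumpLiveRanges`. A standalone sketch of its per-column character logic follows, assuming each live range is modelled as a hypothetical list of closed `[start, stop]` operation-index intervals rather than the `IntervalMap` the PR iterates; `LiveRangeModel`, `livenessChar` and `renderRow` are illustrative names.

```cpp
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a live range: closed [start, stop] intervals of
// operation indices (the PR iterates an llvm::IntervalMap instead).
using Interval = std::pair<unsigned, unsigned>;
using LiveRangeModel = std::vector<Interval>;

// Mirrors the per-range character logic of dumpLiveRanges:
//   'S' = the range starts at this op, 'E' = it ends at this op,
//   '|' = live across this op (or one interval ends where another begins),
//   ' ' = not live.
char livenessChar(unsigned opIndex, const LiveRangeModel &range) {
  char liveness = ' ';
  for (const auto &[start, stop] : range) {
    if (start == opIndex)
      liveness = (liveness == 'E' ? '|' : 'S');
    else if (stop == opIndex)
      liveness = (liveness == 'S' ? '|' : 'E');
    else if (opIndex >= start && opIndex < stop)
      liveness = '|';
  }
  return liveness;
}

// Renders one row of the chart: one column per live range.
std::string renderRow(unsigned opIndex, const std::vector<LiveRangeModel> &ranges) {
  std::string row;
  for (const LiveRangeModel &range : ranges)
    row += livenessChar(opIndex, range);
  return row;
}
```

If the five operations of `@constant_with_multiple_users` are numbered 0 to 4, the ranges `[0, 4]` and `[1, 3]` render as the rows `S `, `|S`, `||`, `|E` and `E `, matching that test's coalesced chart.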
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index 588b44a36c29..14d9712e971a 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -13,26 +13,29 @@
 /// performance (hence the warning).
 func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b:  memref<?x?xi16>, %c: memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
 
   // CHECK-LABEL: tile_a:
   // CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
   vector.print str "tile_a:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_a : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_b:
   // CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
   vector.print str "tile_b:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_b : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_c:
   // CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
   vector.print str "tile_c:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_c : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_d:
   // CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
index 1794564a6a72..0648e771b889 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
@@ -1,5 +1,6 @@
 // DEFINE: %{entry_point} = main
-// DEFINE: %{compile} = mlir-opt %s -convert-arm-sme-to-llvm -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm),test-lower-to-llvm)"
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:  -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:  -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
index e942c7b8ac05..cdd8afe14142 100644
--- a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_library(MLIRArmSMETestPasses
   MLIRTransforms
   MLIRVectorToArmSME
   MLIRVectorToSCF
+  MLIRSCFToControlFlow
   )
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index 48d4a5859f8a..d3dabaf200fd 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -14,10 +14,12 @@
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -34,6 +36,10 @@ struct TestLowerToArmSMEOptions
       llvm::cl::desc("Fuse outer product operations via "
                      "'-arm-sme-outer-product-fusion' pass"),
       llvm::cl::init(true)};
+  PassOptions::Option<bool> dumpTileLiveRanges{
+      *this, "dump-tile-live-ranges",
+      llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
+      llvm::cl::init(false)};
 };
 
 void buildTestLowerToArmSME(OpPassManager &pm,
@@ -65,20 +71,17 @@ void buildTestLowerToArmSME(OpPassManager &pm,
   pm.addPass(createConvertVectorToSCFPass(
       VectorTransferToSCFOptions().enableFullUnroll()));
 
-  // Allocate tiles for ArmSME operations.
-  //
-  // Later passes may create further ArmSME ops that implement the
-  // ArmSMETileOpInterface, but tiles are allocated for root operations,
-  // all of which should now exist.
-  pm.addPass(arm_sme::createTileAllocationPass());
-
   // Enable streaming-mode and ZA.
   pm.addPass(arm_sme::createEnableArmStreamingPass(
       arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
       /*onlyIfRequiredByOps=*/true));
 
+  // Convert SCF to CF (required for ArmSME tile allocation).
+  pm.addPass(createConvertSCFToCFPass());
+
   // Convert ArmSME to LLVM.
-  pm.addPass(createConvertArmSMEToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(
+      createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
 
   // Sprinkle some cleanups.
   pm.addPass(createCanonicalizerPass());

>From 3e52199b1643f6eb79a159eea158744508b69759 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 11:43:11 +0000
Subject: [PATCH 02/13] Review fixups

---
 .../mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h        |  4 ----
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp                   |  4 ++--
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp  | 10 ++++------
 mlir/test/Dialect/ArmSME/roundtrip.mlir                |  9 +++++++++
 4 files changed, 15 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
index f31062d8c25e..0c26fa69c858 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -5,10 +5,6 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file declares the Target dialect for ArmSME in MLIR.
-//
-//===----------------------------------------------------------------------===//
 
 #ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
 #define MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 00d764bf5caf..6c0ebddb5a2d 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -153,10 +153,10 @@ OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
   auto isTileOperandType = [](OpOperand &operand) {
     return arm_sme::isValidSMETileVectorType(operand.get().getType());
   };
-  OpOperand *tileOperand =
-      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
   assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
          "expected at most one tile operand");
+  OpOperand *tileOperand =
+      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
   if (tileOperand == tileOp->getOpOperands().end())
     return nullptr;
   return tileOperand;
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e3cf52078a24..7b3381652e28 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -131,7 +131,7 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
 
 class TileAllocator {
 public:
-  /// Allocates and returns a tile ID.
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
   FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
     auto masks = getMasks(tileType);
     for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
@@ -198,7 +198,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   }
 }
 
-/// Inserts tile copies `cf.br` operations.
+/// Inserts tile copies at `cf.br` operations.
 void insertCopiesAtBranches(IRRewriter &rewriter,
                             FunctionOpInterface function) {
   splitCondBranches(rewriter, function);
@@ -278,10 +278,8 @@ generateOperationNumbering(FunctionOpInterface function) {
       // This is only correct if all ArmSME have been converted to CF.
 #ifndef NDEBUG
       op.walk([&](ArmSMETileOpInterface nestedOp) {
-        if (&op != nestedOp.getOperation()) {
-          assert(false &&
-                 "ArmSME tile allocation does not support nested regions");
-        }
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
       });
 #endif
       operationToIndexMap.try_emplace(&op, index++);
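The updated `allocateTileId` comment notes that allocation fails once no tiles are left. The sketch below shows one way mask-based allocation over the ZA array can work. It is not the patch's code. It assumes a 16-bit mask with one bit per 128-bit tile ZAk.Q, where tile i of a type with n tiles (B: 1, H: 2, S: 4, D: 8, Q: 16) covers every ZAk.Q with k % n == i. The names `TileType`, `computeTileMasks`, and `SketchTileAllocator` are hypothetical, and `std::optional` stands in for `FailureOr`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical tile types; the value is how many such tiles fit in ZA.
enum class TileType : unsigned { B = 1, H = 2, S = 4, D = 8, Q = 16 };

// Assumed encoding: bit (15 - k) stands for ZAk.Q. Tile i of a type with
// n tiles overlaps every ZAk.Q where k % n == i.
std::vector<uint16_t> computeTileMasks(TileType type) {
  unsigned numTiles = static_cast<unsigned>(type);
  std::vector<uint16_t> masks(numTiles, 0);
  for (unsigned k = 0; k < 16; ++k)
    masks[k % numTiles] |= static_cast<uint16_t>(1u << (15 - k));
  return masks;
}

class SketchTileAllocator {
public:
  // Returns the lowest tile ID whose mask does not overlap a tile in use,
  // or std::nullopt if there are no tiles left.
  std::optional<unsigned> allocateTileId(TileType type) {
    std::vector<uint16_t> masks = computeTileMasks(type);
    for (unsigned tileId = 0; tileId < masks.size(); ++tileId) {
      if ((tilesInUse & masks[tileId]) == 0) {
        tilesInUse |= masks[tileId];
        return tileId;
      }
    }
    return std::nullopt;
  }

  // Frees the ZA storage of a previously allocated tile.
  void releaseTileId(TileType type, unsigned tileId) {
    std::vector<uint16_t> masks = computeTileMasks(type);
    assert(tileId < masks.size() && "invalid tile ID");
    assert((tilesInUse & masks[tileId]) == masks[tileId] &&
           "tile was not allocated");
    tilesInUse &= static_cast<uint16_t>(~masks[tileId]);
  }

private:
  uint16_t tilesInUse = 0;
};
```

Under this layout an H tile aliases two S tiles, so allocating ZA0.H first leaves only ZA1.S and ZA3.S free for 32-bit tiles.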
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ab46c7adca59..6095fdc11ead 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1403,3 +1403,12 @@ func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
   %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
   return %reuslt : vector<[2]x[2]xi64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.copy_tile
+//===----------------------------------------------------------------------===//
+
+func.func @arm_sme_copy_tile(%vec: vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+  %result = arm_sme.copy_tile %vec : vector<[4]x[4]xf32>
+  return %result : vector<[4]x[4]xf32>
+}

>From fc27a9b3029b77384b29346812aa19d69897fc3c Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 14:23:57 +0000
Subject: [PATCH 03/13] More review fixups

---
 mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 6 +++---
 mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir    | 3 +--
 mlir/test/Dialect/ArmSME/canonicalize.mlir            | 4 +---
 3 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 7b3381652e28..bcc6deeabd16 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -414,7 +414,7 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
   return std::move(coalescedLiveRanges);
 }
 
-/// Greedily allocate tile IDs to live ranges spill using simple heuristics.
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
 /// Note: This does not attempt to fill holes in live/allocated ranges.
 void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
   TileAllocator tileAllocator;
@@ -629,8 +629,8 @@ LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
     return failure();
   }
 
-  /// 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
-  /// users). This prevents the LLVM conversion needlessly inserting spills.
+  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+  // users). This prevents the LLVM conversion needlessly inserting spills.
   eraseTriviallyDeadTileOps(rewriter, function);
   return success();
 }
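The corrected comment on `allocateTilesToLiveRanges` describes greedy allocation that spills using simple heuristics and does not fill holes in live ranges. A minimal sketch of that idea as a classic linear scan follows. Assumptions: every range needs one tile from a pool of `numTiles` identical tiles, ranges are single `[start, end]` intervals in operation-numbering order, and the spill heuristic (spill whichever range ends last) is illustrative rather than taken from the patch. Spilled ranges get IDs from 16 upwards, mirroring `kInMemoryTileIdBase`. `LiveRangeSketch` and `allocateTilesSketch` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Mirrors kInMemoryTileIdBase: IDs >= 16 denote in-memory (spilled) tiles.
constexpr unsigned kInMemoryTileIdBase = 16;
constexpr unsigned kUnassigned = ~0u;

// Hypothetical live range, [start, end] inclusive.
struct LiveRangeSketch {
  unsigned start;
  unsigned end;
  unsigned tileId = kUnassigned;
};

void allocateTilesSketch(std::vector<LiveRangeSketch *> ranges,
                         unsigned numTiles) {
  assert(numTiles > 0 && "need at least one tile");
  // Visit ranges in program order (by start point).
  std::sort(ranges.begin(), ranges.end(),
            [](const LiveRangeSketch *a, const LiveRangeSketch *b) {
              return a->start < b->start;
            });
  std::vector<bool> tileInUse(numTiles, false);
  std::vector<LiveRangeSketch *> active;  // ranges holding a real tile
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;

  for (LiveRangeSketch *range : ranges) {
    // Release tiles of active ranges that ended before this range starts.
    std::vector<LiveRangeSketch *> stillActive;
    for (LiveRangeSketch *a : active) {
      if (a->end < range->start)
        tileInUse[a->tileId] = false;
      else
        stillActive.push_back(a);
    }
    active.swap(stillActive);

    // Take the lowest free tile, if any.
    auto freeTile = std::find(tileInUse.begin(), tileInUse.end(), false);
    if (freeTile != tileInUse.end()) {
      *freeTile = true;
      range->tileId = static_cast<unsigned>(freeTile - tileInUse.begin());
      active.push_back(range);
      continue;
    }

    // No tile left: spill whichever range ends last (assumed heuristic).
    auto victimIt = std::max_element(
        active.begin(), active.end(),
        [](const LiveRangeSketch *a, const LiveRangeSketch *b) {
          return a->end < b->end;
        });
    LiveRangeSketch *victim = *victimIt;
    if (victim->end > range->end) {
      range->tileId = victim->tileId;  // take over the victim's tile
      victim->tileId = nextInMemoryTileId++;
      active.erase(victimIt);
      active.push_back(range);
    } else {
      range->tileId = nextInMemoryTileId++;
    }
  }
}
```

With four tiles and five overlapping ranges, the short fifth range takes the tile of the longest-lived range, which is spilled to ID 16; once the short range ends, its tile is reused.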
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 8ca71e6bf783..a62ca080ab8d 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" \
-// RUN: -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index d9e3d66e370e..643dfd4a7cbd 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,9 +1,7 @@
-// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
 
 // This tests that dead tile values are removed from control flow.
 
-// -----
-
 // CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
 // CHECK-NOT: vector<[4]x[4]xf32>
 func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {

>From 3f90f8327408e23838b4c1538ab1bca1a15e874a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 17:04:15 +0000
Subject: [PATCH 04/13] More fixups

---
 .../Dialect/ArmSME/IR/ArmSMEOpInterfaces.h    |   3 +
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |   6 +-
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  52 +++---
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         |   2 +-
 .../Dialect/ArmSME/basic-tile-allocation.mlir |   2 +-
 .../ArmSME/tile-allocation-invalid.mlir       |   4 +-
 .../ArmSME/tile-allocation-liveness.mlir      | 160 ++++++++++--------
 7 files changed, 123 insertions(+), 106 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
index 0c26fa69c858..9153fbb57ea8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -17,7 +17,10 @@ namespace detail {
 LogicalResult verifyArmSMETileOpInterface(Operation *);
 }
 
+// The first in-memory SME tile ID. This is set to 16 as that is the first tile
+// ID larger than any virtual tile ID supported by the SME ISA.
 static constexpr unsigned kInMemoryTileIdBase = 16;
+
 #include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
 } // namespace mlir::arm_sme
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index b9d74fec6756..8c129ea623b6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -126,20 +126,20 @@ def EnableArmStreaming
 
 def TestTileAllocation
     : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
-  let summary = "Tests SME tile allocation";
+  let summary = "Tests SME 'virtual tile' allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
     that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
     to be used for testing, tile allocation is done as part of the ArmSME to
-    LLVM conversion.
+    LLVM conversion (`convert-arm-sme-to-llvm`).
   }];
   let options = [
     Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
            "bool", /*default=*/"false",
            "Dump the live ranges of SME tiles (for debugging)">
   ];
-  let dependentDialects = ["func::FuncDialect"];
+  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
 }
 
 def OuterProductFusion
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 488f747af050..562c58b73e7e 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -885,34 +885,38 @@ struct ConvertArmSMEToLLVMPass
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::GetTileOp, arm_sme::CopyTileOp, arm_sme::aarch64_sme_zero,
-      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
-      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
-      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
-      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
-      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
-      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
-      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
-      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
-      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
-      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
-      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
-      arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
-      arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
-      arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
-      arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
-      arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
-      arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
-      arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
-      arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
-      arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
+      arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
+      arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
+      arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
+      arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
+      arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+      arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
+      arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
+      arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
+      arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
+      arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
+      arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
+      arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
+      arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
+      arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
+      arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
+      arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
+      arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
+      arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
+      arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
+      arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
+      arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
+      arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
+      arm_sme::aarch64_sme_cntsd>();
   target.addLegalDialect<arith::ArithDialect,
                          /* The following are used to lower tile spills/fills */
                          vector::VectorDialect, scf::SCFDialect,
                          memref::MemRefDialect>();
-  target.addLegalOp<UnrealizedConversionCastOp>();
+  // Pseudo operations. These cannot be code-generated but may exist in the
+  // input IR, or be generated during the conversion. They need to be eliminated
+  // before the final conversion to LLVM IR (and likely will be due to DCE).
+  target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
+                    UnrealizedConversionCastOp>();
 }
 
 void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 292294046021..14b1f323da3a 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file | FileCheck %s
 // Test conversion of ArmSME ops to LLVM intrinsics.
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index eb64a7b6aac5..8b46998d56b0 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
 
 // -----
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 06be7bd97470..b3112264cba9 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics
-
-// -----
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
 
 // Select between tileA and tileB. This is currently unsupported as it would
 // require inserting (runtime) tile moves.
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index dd2ee55d2afc..a63312ded1b9 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -3,14 +3,14 @@
 
 // This file tests some simple aspects of using liveness in the SME tile allocator.
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @constant_with_multiple_users
-//       CHECK-LIVE-RANGE: ^bb0:
-//       CHECK-LIVE-RANGE: S  arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
-//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
+//  CHECK-LIVE-RANGE-LABEL: @constant_with_multiple_users
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//        CHECK-LIVE-RANGE: S  arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
 // CHECK-LABEL: @constant_with_multiple_users(
 // CHECK-SAME:                                %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
@@ -29,13 +29,13 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @value_with_multiple_users
-//       CHECK-LIVE-RANGE: ^bb0:
-//  CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |E test.some_use
-//  CHECK-LIVE-RANGE-NEXT: E  test.some_use
+//  CHECK-LIVE-RANGE-LABEL: @value_with_multiple_users
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//   CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
 func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
   // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
@@ -48,31 +48,31 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @reuse_tiles_after_initial_use
-//       CHECK-LIVE-RANGE: ^bb0:
-//  CHECK-LIVE-RANGE-NEXT: S        arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: |S       arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: ||S      arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: |||S     arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
-//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
-//  CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
-//  CHECK-LIVE-RANGE-NEXT: E|||     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:  E||     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:   E|     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:    E     test.some_use
-//  CHECK-LIVE-RANGE-NEXT:     S    arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     |S   arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     ||S  arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     |||S arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
-//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
-//  CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
-//  CHECK-LIVE-RANGE-NEXT:     E||| test.some_use
-//  CHECK-LIVE-RANGE-NEXT:      E|| test.some_use
-//  CHECK-LIVE-RANGE-NEXT:       E| test.some_use
-//  CHECK-LIVE-RANGE-NEXT:        E test.some_use
+//  CHECK-LIVE-RANGE-LABEL: @reuse_tiles_after_initial_use
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//   CHECK-LIVE-RANGE-NEXT: S        arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: |S       arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: ||S      arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: |||S     arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//   CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//   CHECK-LIVE-RANGE-NEXT: ||||     test.dummy
+//   CHECK-LIVE-RANGE-NEXT: E|||     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:  E||     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:   E|     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:    E     test.some_use
+//   CHECK-LIVE-RANGE-NEXT:     S    arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     |S   arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     ||S  arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     |||S arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//   CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//   CHECK-LIVE-RANGE-NEXT:     |||| test.dummy
+//   CHECK-LIVE-RANGE-NEXT:     E||| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:      E|| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:       E| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:        E test.some_use
 
 // CHECK-LABEL: @reuse_tiles_after_initial_use
 func.func @reuse_tiles_after_initial_use() {
@@ -111,16 +111,16 @@ func.func @reuse_tiles_after_initial_use() {
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @non_overlapping_branches
-//       CHECK-LIVE-RANGE: ^bb1:
-//  CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
-//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: E cf.br
-//  CHECK-LIVE-RANGE-NEXT: ^bb2:
-//  CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
-//  CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: E cf.br
+//  CHECK-LIVE-RANGE-LABEL: @non_overlapping_branches
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+//   CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: E cf.br
+//   CHECK-LIVE-RANGE-NEXT: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: E cf.br
 
 // CHECK-LABEL: @non_overlapping_branches
 func.func @non_overlapping_branches(%cond: i1) {
@@ -141,8 +141,20 @@ func.func @non_overlapping_branches(%cond: i1) {
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// <deliberately omitted>
+// Here %vecA and %vecB are not merged into the same live range (as they are unknown values).
+// This means that %vecA and %vecB are allocated to different tiles, which is not legal.
+func.func @overlapping_branches(%cond: i1, %vecA: vector<[4]x[4]xf32>, %vecB: vector<[4]x[4]xf32>) {
+  // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
+  %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+    scf.yield %vecA : vector<[4]x[4]xf32>
+  } else {
+    scf.yield %vecB : vector<[4]x[4]xf32>
+  }
+  "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
 
 // CHECK-LABEL: @constant_loop_init_with_multiple_users
 func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
@@ -169,14 +181,14 @@ func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vecto
 
 // -----
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @run_out_of_tiles_but_avoid_spill
-//       CHECK-LIVE-RANGE: ^bb2:
-//  CHECK-LIVE-RANGE-NEXT: |S    arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: ||S   arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: |||S  arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
-//  CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+//  CHECK-LIVE-RANGE-LABEL: @run_out_of_tiles_but_avoid_spill
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: |S    arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: ||S   arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: |||S  arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
 
 // Note in the live ranges (above) there is five tile values, but we only have four tiles.
 
@@ -222,20 +234,20 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
 // We should be able to avoid spills like this, but logic handling this case is
 // not implemented yet. Note tile ID >= 16 means a spill/in-memory tile.
 
-//       CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-//  CHECK-LIVE-RANGE-NEXT: @avoidable_spill
-//       CHECK-LIVE-RANGE: ^bb2:
-//  CHECK-LIVE-RANGE-NEXT: ||     test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||S    arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |||S   arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: ||||S  arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
-//  CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
-//  CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||  E| test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||   E test.some_use
-//  CHECK-LIVE-RANGE-NEXT: ||     arith.addi
-//  CHECK-LIVE-RANGE-NEXT: EE     cf.br
+//  CHECK-LIVE-RANGE-LABEL: @avoidable_spill
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: ||     test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||S    arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |||S   arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: ||||S  arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+//   CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+//   CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||  E| test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||   E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: ||     arith.addi
+//   CHECK-LIVE-RANGE-NEXT: EE     cf.br
 
 // Note in the live ranges (above) there is two constant live-ins (first two ranges),
 // which gives six overlapping live ranges. The allocator currently will spill the

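The `CHECK-LIVE-RANGE` lines in the test above use a compact notation: each column is one live range, `S` marks the operation where it starts, `E` where it ends, and `|` where it is live in between. The hypothetical helper below rebuilds that notation for a single block from inclusive `[start, end]` operation indices, as a reading aid only. It is not the patch's dump code, and it does not model ranges that are live into or out of the block.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical interval over operation indices within one block (inclusive).
struct Interval {
  unsigned start;
  unsigned end;
};

// Renders one line per operation: one character per live range, then a
// space, then the operation name.
std::vector<std::string> renderLiveRanges(const std::vector<std::string> &ops,
                                          const std::vector<Interval> &ranges) {
  std::vector<std::string> lines;
  for (unsigned i = 0; i < ops.size(); ++i) {
    std::string line;
    for (const Interval &r : ranges) {
      assert(r.start <= r.end && "malformed interval");
      if (i == r.start)
        line += 'S';  // range starts at this op
      else if (i == r.end)
        line += 'E';  // range ends at this op
      else if (i > r.start && i < r.end)
        line += '|';  // range is live across this op
      else
        line += ' ';  // range is not live here
    }
    line += ' ';
    line += ops[i];
    lines.push_back(line);
  }
  return lines;
}
```

Feeding it two ranges `{0, 4}` and `{1, 3}` over the five operations of `@constant_with_multiple_users` reproduces the five `CHECK-LIVE-RANGE` lines of that test, from `S  arm_sme.zero` to `E  test.some_use`.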
>From ab2b92e325b143dd02dbdd44e1d8592d258f0d30 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 2 May 2024 15:44:52 +0000
Subject: [PATCH 05/13] Add test for tile copies

---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |  5 +-
 .../ArmSME/Transforms/TileAllocation.cpp      |  7 +-
 .../ArmSME/tile-allocation-copies.mlir        | 82 +++++++++++++++++++
 3 files changed, 92 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8c129ea623b6..81730166d3c0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -137,7 +137,10 @@ def TestTileAllocation
   let options = [
     Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
            "bool", /*default=*/"false",
-           "Dump the live ranges of SME tiles (for debugging)">
+           "Dump the live ranges of SME tiles (for debugging)">,
+    Option<"tileCopiesOnly", "tile-copies-only", "bool", /*default=*/"false",
+           "Only insert tile copies needed for tile allocation "
+           "(but do not allocate any tiles)">
   ];
   let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index bcc6deeabd16..ce8f34152475 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -575,7 +575,12 @@ struct TestTileAllocationPass
     : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
   using TestTileAllocationBase::TestTileAllocationBase;
   void runOnOperation() override {
-    if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
+    FunctionOpInterface function = getOperation();
+    if (tileCopiesOnly) {
+      IRRewriter rewriter(function);
+      return insertCopiesAtBranches(rewriter, function);
+    }
+    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
new file mode 100644
index 000000000000..5e43efe93b55
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation=tile-copies-only -split-input-file | FileCheck %s
+
+// This file tests inserting copies for SME tile allocation. Copies are
+// inserted at `cf.br` ops (the predecessors of block arguments). Conditional
+// branches are split to prevent conflicts (see cond_branch_with_backedge).
+
+// CHECK-LABEL: func.func @simple_branch(
+//  CHECK-SAME:   %[[TILE:.*]]: vector<[4]x[4]xf32>)
+//       CHECK:   %[[COPY:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+//       CHECK:   cf.br ^bb1(%[[COPY]] : vector<[4]x[4]xf32>)
+//       CHECK: ^bb1(%[[BLOCK_ARG:.*]]: vector<[4]x[4]xf32>):
+
+func.func @simple_branch(%tile : vector<[4]x[4]xf32>) {
+  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+  return
+}
+
+// -----
+
+// Note: The ^POINTLESS_SHIM_FOR_BB2 block is added because cond_br splitting
+// does not check whether a copy is actually needed on each edge (the empty
+// block is harmless though; it will fold away later).
+
+// CHECK-LABEL: func.func @cond_branch(
+//  CHECK-SAME:   %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+//       CHECK:   cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[POINTLESS_SHIM_FOR_BB2:[[:alnum:]]+]]
+//       CHECK: ^[[POINTLESS_SHIM_FOR_BB2]]:
+//       CHECK:   cf.br ^[[BB2:.*]]
+//       CHECK: ^[[BB1_COPIES]]:
+//       CHECK:   arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+//       CHECK:   cf.br ^[[BB1:.*]]
+func.func @cond_branch(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+  return
+^bb2:
+  return
+}
+
+// -----
+
+// A reduced real-world example that shows why we must split conditional branches.
+
+// CHECK-LABEL: @cond_branch_with_backedge(
+//  CHECK-SAME:    %{{[[:alnum:]]+}}: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+//  CHECK-SAME:    %[[TILEC:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILED:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+//       CHECK: ^bb1(%[[CURRENT_INDEX:.*]]: index, %[[ITER_TILE:.*]]: vector<[4]x[4]xf32>):
+//       CHECK:   %[[CONTINUE_LOOP:.*]] = arith.cmpi
+//       CHECK:   cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
+//       CHECK: ^[[BB3_COPIES]]:
+//  CHECK-NEXT:   arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   cf.br ^[[BB3:[[:alnum:]]+]]
+//       CHECK: ^[[BB3]](%{{.*}}: vector<[4]x[4]xf32>):
+//  CHECK-NEXT:   return
+
+func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector<[4]x[4]xf32>, %tileC: vector<[4]x[4]xf32>, %tileD: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // Live here: %tileA, %tileB, %tileC, %tileD
+  cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+  %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  // %iterTile dies at the `cf.cond_br`, but %tileB, %tileC, %tileD are live out (in the ^bb2 case).
+  // If we inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
+  // However, note that the copies are only needed if we take the ^bb3 path. So, if we add
+  // a new block along that path we can insert the copies without any conflicts.
+  cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+  %nextIndex = arith.addi %currentIndex, %c1 : index
+  cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+  // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+  return
+}
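To make the two preprocessing steps exercised by these tests concrete, here is a toy sketch on a hypothetical mini-CFG. The `Block` and `Edge` structs are invented for illustration and are not MLIR types, and every forwarded operand is assumed to be a tile. Each `cond_br` edge is redirected through a fresh block holding an unconditional `br`, and then a copy is inserted for every operand forwarded by a `br`. This reproduces the shape of the `@cond_branch` expectations above, including the empty shim block on the edge that forwards nothing.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical toy CFG (not MLIR's API): a block ends in a br (one edge)
// or a cond_br (two edges); each edge forwards operands as block arguments.
struct Edge {
  int dest;                           // index of the successor block
  std::vector<std::string> operands;  // values forwarded to the successor
};

struct Block {
  std::vector<std::string> ops;  // non-terminator ops, as text
  bool isCondBranch = false;
  std::vector<Edge> successors;  // empty for a return
};

// Redirect every cond_br edge to a new block containing `br ^dest(operands)`.
void splitCondBranches(std::vector<Block> &blocks) {
  size_t originalCount = blocks.size();
  for (size_t i = 0; i < originalCount; ++i) {
    if (!blocks[i].isCondBranch)
      continue;
    for (size_t s = 0; s < blocks[i].successors.size(); ++s) {
      Block shim;
      shim.successors.push_back(blocks[i].successors[s]);  // copy the edge
      blocks.push_back(shim);  // may reallocate, so only use indices below
      blocks[i].successors[s] = Edge{static_cast<int>(blocks.size() - 1), {}};
    }
  }
}

// Insert a copy before each unconditional br for every forwarded operand.
void insertCopiesAtBranches(std::vector<Block> &blocks) {
  for (Block &block : blocks) {
    if (block.isCondBranch)
      continue;  // cond_br edges were already split into plain br blocks
    for (Edge &edge : block.successors)
      for (std::string &operand : edge.operands) {
        std::string copy = operand + ".copy";
        block.ops.push_back(copy + " = copy_tile " + operand);
        operand = copy;  // the br now forwards the copy instead
      }
  }
}
```

Running `splitCondBranches` and then `insertCopiesAtBranches` on a `cond_br` to `^bb1(%tile)` / `^bb2` yields two shim blocks, and only the one leading to `^bb1` receives a `copy_tile`.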

>From 17503dc413d40baed2b38497d575f53c67287e3b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 2 May 2024 16:41:34 +0000
Subject: [PATCH 06/13] Wrap op modifications in `rewriter.modifyOpInPlace()`

---
 .../ArmSME/Transforms/TileAllocation.cpp      | 28 +++++++++++--------
 1 file changed, 17 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index ce8f34152475..af8079c8dee5 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -191,10 +191,12 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
                condBranch.getTrueDestOperands());
     insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
                condBranch.getFalseDestOperands());
-    condBranch.getFalseDestOperandsMutable().clear();
-    condBranch.getTrueDestOperandsMutable().clear();
-    condBranch.setSuccessor(newTrueBranch, 0);
-    condBranch.setSuccessor(newFalseBranch, 1);
+    rewriter.modifyOpInPlace(condBranch, [&] {
+      condBranch.getFalseDestOperandsMutable().clear();
+      condBranch.getTrueDestOperandsMutable().clear();
+      condBranch.setSuccessor(newTrueBranch, 0);
+      condBranch.setSuccessor(newFalseBranch, 1);
+    });
   }
 }
 
@@ -211,7 +213,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
       if (isValidSMETileVectorType(operand.get().getType())) {
         auto copy =
             rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
-        operand.assign(copy);
+        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
       }
     }
   }
@@ -494,7 +496,8 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
           // Ensure ArmSME ops that don't produce a value still get a tile ID.
           if (!hasTileResult(tileOp))
-            tileOp.setTileId(tileIdAttr);
+            rewriter.modifyOpInPlace(tileOp,
+                                     [&] { tileOp.setTileId(tileIdAttr); });
         }
       }
       auto copyOp = value.getDefiningOp<CopyTileOp>();
@@ -502,7 +505,7 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         // Fold redundant copies.
         rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
       } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
-        tileOp.setTileId(tileIdAttr);
+        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
         // Rectify operand tile IDs with result tile IDs.
         OpOperand *tileOperand = getTileOpOperand(tileOp);
         if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
@@ -515,12 +518,15 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         // Cloning prevents a move/spill (though may require recomputation).
         rewriter.setInsertionPoint(tileOp);
         auto clonedOp = operandTileOp.clone();
-        clonedOp.setTileId(tileOp.getTileId());
+        rewriter.modifyOpInPlace(
+            clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
         rewriter.insert(clonedOp);
-        if (copyOp)
+        if (copyOp) {
           rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
-        else
-          tileOperand->assign(clonedOp->getResult(0));
+        } else {
+          rewriter.modifyOpInPlace(
+              tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
+        }
       } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
         // Validate block arguments.
         bool tileMismatch = false;
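The point of wrapping these mutations: in MLIR, in-place changes made through a rewriter are expected to be bracketed so that listeners (such as a pattern driver) are notified, and `RewriterBase::modifyOpInPlace` does this by calling `startOpModification` before the callback and `finalizeOpModification` after it. The sketch below mimics that pairing with hypothetical `ToyOp`, `ToyListener` and `ToyRewriter` types; it is an illustration of the pattern, not MLIR's implementation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for an operation, a listener and a rewriter.
struct ToyOp {
  std::string name;
  int tileId = -1;
};

struct ToyListener {
  std::vector<std::string> modified;  // names of ops reported as modified
  void notifyOperationModified(ToyOp &op) { modified.push_back(op.name); }
};

struct ToyRewriter {
  ToyListener *listener = nullptr;

  void startOpModification(ToyOp &) {}  // a real rewriter may record state here
  void finalizeOpModification(ToyOp &op) {
    if (listener)
      listener->notifyOperationModified(op);
  }

  // Bracket an arbitrary in-place mutation so the listener hears about it.
  template <typename Fn>
  void modifyOpInPlace(ToyOp &op, Fn &&fn) {
    startOpModification(op);
    fn();
    finalizeOpModification(op);
  }
};
```

A direct assignment such as `op.tileId = 3;` goes unnoticed, whereas `rewriter.modifyOpInPlace(op, [&] { op.tileId = 0; });` is reported exactly once.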

>From 3527c9ee65ad8514d845a132cdcc4b938768bfdf Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 3 May 2024 15:58:44 +0000
Subject: [PATCH 07/13] More docs + naming

---
 .../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp  |  6 +-
 .../ArmSME/Transforms/TileAllocation.cpp      | 62 ++++++++++++++-----
 2 files changed, 50 insertions(+), 18 deletions(-)

diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 562c58b73e7e..3dbc8e9916df 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -406,11 +406,11 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
 ///  AFTER:
 ///  ```mlir
 ///     "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
-///     %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
+///     %v = arm_sme.get_tile : vector<[4]x[4]xi32>
 ///  ```
 ///
-///  The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
-///  once all ArmSME ops have been converted to LLVM intrinsics.
+///  The 'arm_sme.get_tile' (which models the return) will fold away once all
+///  ArmSME ops have been converted to LLVM intrinsics.
 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index af8079c8dee5..bc41386f5903 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -164,9 +164,23 @@ class TileAllocator {
   unsigned nextInMemoryTileId = kInMemoryTileIdBase;
 };
 
-// Add new intermediate blocks for the true and false destinations of a
-// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
-// branches.
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+///  ^bb1_copy:
+///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ^bb2_copy:
+///    cf.br ^bb2
+///  ```
 void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   SmallVector<cf::CondBranchOp> worklist;
   function.walk([&](cf::CondBranchOp condBranch) {
@@ -200,7 +214,18 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   }
 }
 
-/// Inserts tile copies at `cf.br` operations.
+/// Splits conditional branches (see `splitCondBranches`), then inserts tile
+/// copies at `cf.br` operations.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+///  ```
 void insertCopiesAtBranches(IRRewriter &rewriter,
                             FunctionOpInterface function) {
   splitCondBranches(rewriter, function);
@@ -219,7 +244,9 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
   }
 }
 
-/// A range where a tile value is live. The range may contain holes.
+/// A live range for a (collection of) tile values. A live range is built up of
+/// intervals [start, end) which represent parts of the program where the value
+/// needs to be live (i.e. in an SME virtual tile).
 struct LiveRange {
   using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                      llvm::IntervalMapHalfOpenInfo<unsigned>>;
@@ -296,33 +323,38 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                      LiveRange::Allocator &liveRangeAllocator,
                      Liveness &liveness, FunctionOpInterface function) {
   DenseMap<Value, LiveRange> liveRanges;
-  auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
-                              LivenessBlockInfo const &livenessInfo,
-                              bool liveAtBlockEntry = false) {
+  /// Defines or updates a live range for an SME tile value. Live-ins may update
+  /// an existing live range (rather than define a new one).
+  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
+                                          LivenessBlockInfo const &livenessInfo,
+                                          bool liveAtBlockEntry = false) {
     if (!isValidSMETileVectorType(value.getType()))
       return;
-    auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+    // Find or create a live range for `value`.
+    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
+    LiveRange &valueLiveRange = it->second;
     auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
     unsigned start =
         operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
     unsigned end = operationToIndexMap.at(lastUseInBlock);
-    it->second.insert(value, start, end);
+    valueLiveRange.insert(value, start, end);
   };
 
   for (Block &block : function.getBlocks()) {
     LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
     // Handle block arguments:
     for (Value argument : block.getArguments())
-      updateLiveRanges(argument, &block.front(), *livenessInfo,
-                       /*liveAtBlockEntry=*/true);
+      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
     // Handle live-ins:
     for (Value liveIn : livenessInfo->in())
-      updateLiveRanges(liveIn, &block.front(), *livenessInfo,
-                       /*liveAtBlockEntry=*/true);
+      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
     // Handle new definitions:
     for (Operation &op : block) {
       for (Value result : op.getResults())
-        updateLiveRanges(result, &op, *livenessInfo);
+        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
     }
   }
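The `LiveRange` documented above is built from half-open intervals `[start, end)` over operation indices. A simplified model, assuming a plain vector instead of `llvm::IntervalMap` and ignoring which value each interval belongs to, shows the two operations the allocator relies on: testing two ranges for overlap, and coalescing ranges that do not overlap. `ToyLiveRange` and `tryMerge` are hypothetical names.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Toy live range: a list of half-open intervals [start, end).
struct ToyLiveRange {
  std::vector<std::pair<unsigned, unsigned>> intervals;

  void insert(unsigned start, unsigned end) { intervals.push_back({start, end}); }

  // Two half-open intervals overlap iff each starts before the other ends.
  bool overlaps(const ToyLiveRange &other) const {
    for (auto [aStart, aEnd] : intervals)
      for (auto [bStart, bEnd] : other.intervals)
        if (aStart < bEnd && bStart < aEnd)
          return true;
    return false;
  }

  // Coalesce `other` into this range; refused if the two overlap.
  bool tryMerge(const ToyLiveRange &other) {
    if (overlaps(other))
      return false;
    intervals.insert(intervals.end(), other.intervals.begin(),
                     other.intervals.end());
    return true;
  }
};
```

Because intervals are half-open, `[0, 5)` and `[5, 8)` merely touch and can share a tile, while `[0, 5)` and `[4, 6)` conflict.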
 

>From 4f0ca31405c0bc4756fd96b316046cb0478c1379 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 May 2024 10:47:47 +0000
Subject: [PATCH 08/13] Add CF allocation test for cond_br with backedge

---
 .../ArmSME/tile-allocation-liveness.mlir      | 72 +++++++++++++++++++
 1 file changed, 72 insertions(+)

diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index a63312ded1b9..b46c0cfc9a27 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -279,3 +279,75 @@ func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<
   }
   return
 }
+
+// -----
+
+// This test is a follow-up to the test of the same name in `tile-allocation-copies.mlir`.
+// It shows the live ranges that make splitting the conditional branch necessary.
+
+//  CHECK-LIVE-RANGE-LABEL: @cond_branch_with_backedge
+//        CHECK-LIVE-RANGE: ^bb1:
+//  CHECK-LIVE-RANGE--NEXT:  ||| |           arith.cmpi
+//  CHECK-LIVE-RANGE--NEXT:  EEE E           cf.cond_br
+//
+//  CHECK-LIVE-RANGE--NEXT: ^[[BB3_COPIES:[[:alnum:]]+]]:
+//  CHECK-LIVE-RANGE--NEXT:  ||| ES          arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT:  E||  |S         arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT:   E|  ||S        arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT:    E  |||S       arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT:       EEEE       cf.br
+//
+// It is important to note that the first three live ranges in ^bb1 do not end
+// at the `cf.cond_br`; they are live-out via the backedge bb1 -> bb2 -> bb1.
+// This means that if we placed the `arm_sme.copy_tile` ops before the `cf.cond_br`
+// then those live ranges would not end at the copies, resulting in unwanted
+// overlapping live ranges (and hence tile spills).
+//
+// With the conditional branch split and the copies placed in the BB3_COPIES
+// block, the first three live ranges end at the copy operations (as the
+// BB3_COPIES block is on the path out of the loop and has no backedge). This
+// means there are no overlaps and the live ranges all merge, as shown below.
+//
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb1:
+//  CHECK-LIVE-RANGE--NEXT: |||| arith.cmpi
+//  CHECK-LIVE-RANGE--NEXT: EEEE cf.cond_br
+//
+//  CHECK-LIVE-RANGE--NEXT: ^[[BB3_COPIES]]:
+//  CHECK-LIVE-RANGE--NEXT: |||| arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT: |||| arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT: |||| arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT: |||| arm_sme.copy_tile
+//  CHECK-LIVE-RANGE--NEXT: EEEE cf.br
+
+// CHECK-LABEL: @cond_branch_with_backedge
+// CHECK-NOT: tile_id = 16
+// CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 2 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 3 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-NOT: tile_id = 16
+func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileD = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // Live here: %tileA, %tileB, %tileC, %tileD
+  cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+  %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+  %nextIndex = arith.addi %currentIndex, %c1 : index
+  cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+  // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+  return
+}
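The greedy linear scan described for this patch series can be sketched as follows, under loudly simplifying assumptions: each (coalesced) live range is a single interval `[start, end)`, all ranges need the same tile type (e.g. the four 32-bit tiles), and when no tile is free the incoming range is spilled to the next in-memory ID starting at 16 (mirroring `kInMemoryTileIdBase` and the "tile ID >= 16 means a spill" convention). The real allocator may choose a different range to spill; `Range` and `linearScan` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

constexpr unsigned kInMemoryTileIdBase = 16;  // IDs >= 16 mean "spilled"

// Hypothetical simplification: one interval [start, end) per live range.
struct Range {
  unsigned start, end;
  unsigned tileId = 0;  // filled in by linearScan
};

void linearScan(std::vector<Range> &ranges, unsigned numTiles) {
  // Visit ranges in program order (stable, so ties keep their input order).
  std::vector<Range *> order;
  for (Range &r : ranges)
    order.push_back(&r);
  std::stable_sort(order.begin(), order.end(), [](const Range *a, const Range *b) {
    return a->start < b->start;
  });

  std::vector<bool> tileInUse(numTiles, false);
  std::vector<Range *> active;  // ranges currently holding a hardware tile
  unsigned nextInMemoryTileId = kInMemoryTileIdBase;

  for (Range *current : order) {
    // Release tiles whose ranges ended at or before the current start point.
    for (auto it = active.begin(); it != active.end();) {
      if ((*it)->end <= current->start) {
        tileInUse[(*it)->tileId] = false;
        it = active.erase(it);
      } else {
        ++it;
      }
    }
    // Take the lowest free tile, or fall back to an in-memory (spill) ID.
    auto freeTile = std::find(tileInUse.begin(), tileInUse.end(), false);
    if (freeTile != tileInUse.end()) {
      current->tileId = static_cast<unsigned>(freeTile - tileInUse.begin());
      *freeTile = true;
      active.push_back(current);
    } else {
      current->tileId = nextInMemoryTileId++;
    }
  }
}
```

With four ranges that are live together on four tiles, they receive IDs 0 to 3 (as the `@cond_branch_with_backedge` checks expect), a fifth overlapping range gets 16, and a range starting after they all end reuses tile 0.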

>From cc7bc3a336f095b0b0be8b990cd99ac35bd8edd1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 May 2024 11:06:40 +0000
Subject: [PATCH 09/13] Fix comment

---
 mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
index 5e43efe93b55..da7f4843f75f 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -66,8 +66,8 @@ func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector
 ^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
   %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
   // Live here: %iterTile, %tileB, %tileC, %tileD
-  // %iterTile dies at the `cf.cond_br`, but %tileB, %tileC, %tileD are live out (in the ^bb2 case).
-  // If we inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
+  // %iterTile, %tileB, %tileC, %tileD are live out (in the ^bb2 case). If we
+  // inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
   // However, note that the copies are only needed if we take the ^bb3 path. So, if we add
   // a new block along that path we can insert the copies without any conflicts.
   cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)

>From 3136b73b9f29ecac064d30a26d3303a4da5b09d1 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 8 May 2024 12:25:08 +0000
Subject: [PATCH 10/13] Use diagnostics to point at tile operand if move is
 required

---
 .../Dialect/ArmSME/Transforms/TileAllocation.cpp    | 13 +++++++++----
 .../Dialect/ArmSME/tile-allocation-invalid.mlir     |  4 +++-
 .../Dialect/ArmSME/tile-allocation-liveness.mlir    |  7 +++++--
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index bc41386f5903..116a02a9faca 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -544,9 +544,14 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
           continue;
         auto operandTileOp =
             tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
-        if (!isTriviallyCloneableTileOp(operandTileOp))
-          return tileOp.emitOpError("failed to rectify tile operand with tile "
-                                    "result (move required)");
+        if (!isTriviallyCloneableTileOp(operandTileOp)) {
+          auto error =
+              tileOp.emitOpError("tile operand allocated to different SME "
+                                 "virtial tile (move required)");
+          error.attachNote(tileOperand->get().getLoc())
+              << "tile operand is: " << tileOperand->get();
+          return error;
+        }
         // Cloning prevents a move/spill (though may require recomputation).
         rewriter.setInsertionPoint(tileOp);
         auto clonedOp = operandTileOp.clone();
@@ -567,7 +572,7 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
             return;
           if (!isAllocatedToSameTile(predecessorTile)) {
             blockArg.getOwner()->getParentOp()->emitOpError(
-                "block argument not allocated to the same tile as "
+                "block argument not allocated to the same SME virtial tile as "
                 "predecessors");
             tileMismatch = true;
           }
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index b3112264cba9..6b5e44365bf5 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -2,9 +2,11 @@
 
 // Select between tileA and tileB. This is currently unsupported as it would
 // require inserting (runtime) tile moves.
+
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xi32>'}}
 func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %tileA : vector<[4]x[4]xi32>, %tileB : vector<[4]x[4]xi32>, %cond: i1) {
   %c0 = arith.constant 0 : index
-  // expected-error@+1 {{op failed to rectify tile operand with tile result (move required)}}
+  // expected-error@below {{op tile operand allocated to different SME virtial tile (move required)}}
   %tile = scf.if %cond -> vector<[4]x[4]xi32> {
     scf.yield %tileA : vector<[4]x[4]xi32>
   } else {
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index b46c0cfc9a27..9deeac486373 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -37,8 +37,9 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
 //   CHECK-LIVE-RANGE-NEXT: |E test.some_use
 //   CHECK-LIVE-RANGE-NEXT: E  test.some_use
 
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xf32>'}}
 func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
-  // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
+  // expected-error@below {{op tile operand allocated to different SME virtial tile (move required)}}
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
@@ -143,8 +144,10 @@ func.func @non_overlapping_branches(%cond: i1) {
 
 // Here %vecA and %vecB are not merged into the same live range (as they are unknown values).
 // This means that %vecA and %vecB are both allocated to different tiles (which is not legal).
+
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xf32>'}}
 func.func @overlapping_branches(%cond: i1, %vecA: vector<[4]x[4]xf32>, %vecB: vector<[4]x[4]xf32>) {
-  // expected-error@below {{op failed to rectify tile operand with tile result (move required)}}
+  // expected-error@below {{op tile operand allocated to different SME virtial tile (move required)}}
   %tile = scf.if %cond -> vector<[4]x[4]xf32> {
     scf.yield %vecA : vector<[4]x[4]xf32>
   } else {
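The conflict-resolution rule in this patch can be summarized as: if a tile operand was allocated to a different tile than the op's result, clone its defining op onto the result's tile when that op is trivially cloneable (no operands and Pure, per the `isTriviallyCloneableTileOp` criteria), and otherwise report that a move would be required. Below is a toy model of that decision; `ToyTileOp`, `Rectified` and `rectifyOperand` are hypothetical, and which real ops meet the cloneability criteria is assumed in the usage note rather than taken from the ArmSME dialect.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical toy model of a tile-producing op (not MLIR's API).
struct ToyTileOp {
  std::string name;
  unsigned numOperands = 0;
  bool isPure = false;
  unsigned tileId = 0;
};

// No operands and no side effects: cloning elsewhere cannot change semantics.
bool isTriviallyCloneable(const ToyTileOp &op) {
  return op.numOperands == 0 && op.isPure;
}

struct Rectified {
  std::optional<ToyTileOp> clone;  // set when a clone resolved the conflict
  std::string error;               // set when a runtime move would be needed
};

Rectified rectifyOperand(const ToyTileOp &user, const ToyTileOp &operandDef) {
  if (operandDef.tileId == user.tileId)
    return {};  // already on the same tile: nothing to do
  if (!isTriviallyCloneable(operandDef))
    return {std::nullopt, "move required"};
  ToyTileOp clone = operandDef;
  clone.tileId = user.tileId;  // re-materialize the value on the user's tile
  return {clone, ""};
}
```

For example, an operand defined by an op assumed to be operand-free and Pure (say, a zeroing op) on tile 1 is simply re-created on the user's tile 0, whereas an operand defined by a memory-reading op with operands produces the "move required" error.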

>From 1373a8e59e2a827e93b75d11a6707065aa73ecbf Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Wed, 8 May 2024 18:00:41 +0000
Subject: [PATCH 11/13] More stuff

---
 .../mlir/Dialect/ArmSME/Transforms/Passes.td  |   4 +-
 .../include/mlir/Dialect/ArmSME/Utils/Utils.h |  11 +-
 mlir/lib/Dialect/ArmSME/IR/Utils.cpp          |   2 +
 .../ArmSME/Transforms/TileAllocation.cpp      | 202 ++++++++++++------
 .../ArmSME/tile-allocation-copies.mlir        |   2 +-
 5 files changed, 148 insertions(+), 73 deletions(-)

diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 81730166d3c0..869a031d6cae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -138,8 +138,8 @@ def TestTileAllocation
     Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
            "bool", /*default=*/"false",
            "Dump the live ranges of SME tiles (for debugging)">,
-    Option<"tileCopiesOnly", "tile-copies-only", "bool", /*default=*/"false",
-           "Only insert tile copies needed for tile allocation "
+    Option<"preprocessOnly", "preprocess-only", "bool", /*default=*/"false",
+           "Only preprocess IR so it is ready for tile allocation "
            "(but do not allocate any tiles)">
   ];
   let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 9ea1c5a5d63f..5bd12a7ba6ff 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -74,7 +74,16 @@ VectorType getSMETileTypeForElement(Type elementType);
 void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
                                FunctionOpInterface function);
 
-/// Returns true if `tileOp` can be cloned to resolve conflicts.
+/// Returns true if `tileOp` is trivially cloneable. A tile operation is
+/// trivially cloneable if:
+///
+///  1. It has no operands (and only a single tile result)
+///  2. It is 'Pure'
+///
+/// This ensures that the cloned operation will not share any dependencies with
+/// the original operation (which could also need to be considered), and that
+/// inserting the cloned operation at a different point in the program won't
+/// change the semantics of the program (as it has no side effects).
 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
 
 /// Returns true if `tileOp` produces a tile result.
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6c0ebddb5a2d..4d5cd1648e42 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -150,6 +150,8 @@ bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
 }
 
 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
+  if (!tileOp)
+    return nullptr;
   auto isTileOperandType = [](OpOperand &operand) {
     return arm_sme::isValidSMETileVectorType(operand.get().getType());
   };
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 116a02a9faca..f4d1570daa9e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -7,11 +7,17 @@
 //===----------------------------------------------------------------------===//
 //
 // This transform allocates SME tiles at the 'func.func' op level for ArmSME
-// operations. It does this using a 16-bit tile mask that has a bit for each
-// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
+// operations. It roughly implements a linear scan register allocator, similar
+// to the one outlined in [1], but with simplifications and assumptions made for
+// our use case. Note that this is a greedy allocator, so it may not always find
+// the optimal allocation of tiles.
+//
+// The allocator operates at the CF dialect level. It is the responsibility of
+// users to ensure the IR has been lowered to CF before invoking the tile
+// allocator.
 //
 // The 128-bit tiles overlap with other element tiles as follows (see section
-// B2.3.2 of SME spec [1]):
+// B2.3.2 of SME spec [2]):
 //
 //   Tile    Overlaps
 //   ---------------------------------------------------------------------------
@@ -32,7 +38,10 @@
 //   ZA6.D   ZA6.Q, ZA14.Q
 //   ZA7.D   ZA7.Q, ZA15.Q
 //
-// [1] https://developer.arm.com/documentation/ddi0616/aa
+// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
+//      Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
+//     https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
+// [2] https://developer.arm.com/documentation/ddi0616/aa
 //
 //===----------------------------------------------------------------------===//
 
@@ -214,8 +223,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
   }
 }
 
-/// Splits conditional branches (see `splitCondBranches`), then inserts tile
-/// copies at `cf.br` operations.
+/// Inserts tile copies at `cf.br` operations.
 ///
 ///  BEFORE:
 ///  ```mlir
@@ -228,7 +236,6 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
 ///  ```
 void insertCopiesAtBranches(IRRewriter &rewriter,
                             FunctionOpInterface function) {
-  splitCondBranches(rewriter, function);
   for (Block &block : function.getBlocks()) {
     Operation *terminator = block.getTerminator();
     if (!isa<cf::BranchOp>(terminator))
@@ -244,6 +251,20 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
   }
 }
 
+/// Prepares the IR for tile allocation. It does this by first 'splitting'
+/// conditional branches (see `splitCondBranches`), then inserting tile copies
+/// at branch operations. Conditional branches are split to prevent the copies
+/// needed for them from overlapping between the true and false paths of the
+/// branch (see `tile-allocation-copies.mlir` and
+/// `tile-allocation-liveness.mlir` for examples). The copies break up live
+/// ranges and ensure when moving out of SSA the semantics of the program are
+/// persevered.
+void preprocessForTileAllocation(IRRewriter &rewriter,
+                                 FunctionOpInterface function) {
+  splitCondBranches(rewriter, function);
+  insertCopiesAtBranches(rewriter, function);
+}
+
 /// A live range for a (collection of) tile values. A live range is built up of
 /// intervals [start, end) which represent parts of the program where the value
 /// needs to be live (i.e. in an SME virtual tile).
@@ -295,6 +316,9 @@ struct LiveRange {
 };
 
 /// Number operations within a function to allow computing live ranges.
+/// Operations are numbered consecutively within blocks, and the blocks are
+/// topologically sorted (using forward edges). This function is only correct if
+/// the IR has been lowered to CF, i.e. no ArmSME ops are nested (asserted).
 DenseMap<Operation *, unsigned>
 generateOperationNumbering(FunctionOpInterface function) {
   unsigned index = 0;
@@ -304,7 +328,6 @@ generateOperationNumbering(FunctionOpInterface function) {
   for (Block *block : blocks) {
     index++; // We want block args to have their own number.
     for (Operation &op : block->getOperations()) {
-      // This is only correct if all ArmSME have been converted to CF.
 #ifndef NDEBUG
       op.walk([&](ArmSMETileOpInterface nestedOp) {
         assert(&op == nestedOp.getOperation() &&
@@ -324,7 +347,9 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                      Liveness &liveness, FunctionOpInterface function) {
   DenseMap<Value, LiveRange> liveRanges;
   /// Defines or updates a live range for an SME tile value. Live-ins may update
-  /// an existing live range (rather than define a new one).
+  /// an existing live range (rather than define a new one). Note: If
+  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
+  /// the block.
   auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
                                           LivenessBlockInfo const &livenessInfo,
                                           bool liveAtBlockEntry = false) {
@@ -335,10 +360,10 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
     LiveRange &valueLiveRange = it->second;
     auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
     // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
-    unsigned start =
+    unsigned startOpIdx =
         operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
-    unsigned end = operationToIndexMap.at(lastUseInBlock);
-    valueLiveRange.insert(value, start, end);
+    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
+    valueLiveRange.insert(value, startOpIdx, endOpIdx);
   };
 
   for (Block &block : function.getBlocks()) {
@@ -511,6 +536,20 @@ void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
   }
 }
 
+/// Assigns a tile ID to an MLIR value.
+void assignTileIdToValue(IRRewriter &rewriter, Value value,
+                         IntegerAttr tileIdAttr) {
+  if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
+    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
+  for (Operation *user : value.getUsers()) {
+    if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+      // Ensure ArmSME ops that don't produce a value still get a tile ID.
+      if (!hasTileResult(tileOp))
+        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
+    }
+  }
+}
+
 /// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
 LogicalResult assignTileIdsAndResolveTrivialConflicts(
     IRRewriter &rewriter, FunctionOpInterface function,
@@ -523,63 +562,88 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
         return true;
       return liveRange->values.contains(value);
     };
-    for (Value value : liveRange->values) {
-      for (Operation *user : value.getUsers()) {
-        if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
-          // Ensure ArmSME ops that don't produce a value still get a tile ID.
-          if (!hasTileResult(tileOp))
-            rewriter.modifyOpInPlace(tileOp,
-                                     [&] { tileOp.setTileId(tileIdAttr); });
-        }
-      }
+
+    /// Eliminates copies where the operand has the same tile ID.
+    auto foldRedundantCopies = [&](Value value) -> LogicalResult {
       auto copyOp = value.getDefiningOp<CopyTileOp>();
-      if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
-        // Fold redundant copies.
-        rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
-      } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
-        rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
-        // Rectify operand tile IDs with result tile IDs.
-        OpOperand *tileOperand = getTileOpOperand(tileOp);
-        if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
-          continue;
-        auto operandTileOp =
-            tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
-        if (!isTriviallyCloneableTileOp(operandTileOp)) {
-          auto error =
-              tileOp.emitOpError("tile operand allocated to different SME "
-                                 "virtial tile (move required)");
-          error.attachNote(tileOperand->get().getLoc())
-              << "tile operand is: " << tileOperand->get();
-          return error;
+      if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
+        return failure();
+      rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
+      return success();
+    };
+
+    /// Validates that every value passed to a tile block argument (from each
+    /// predecessor) has been assigned the same tile ID as the argument.
+    auto validateBlockArguments = [&](Value value) {
+      auto blockArg = dyn_cast<BlockArgument>(value);
+      if (!blockArg) {
+        // Not a block argument (nothing to validate).
+        return success();
+      }
+      bool tileMismatch = false;
+      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+        if (tileMismatch)
+          return;
+        if (!isAllocatedToSameTile(predecessorTile)) {
+          blockArg.getOwner()->getParentOp()->emitOpError(
+              "block argument not allocated to the same SME virtual tile as "
+              "predecessors");
+          tileMismatch = true;
         }
-        // Cloning prevents a move/spill (though may require recomputation).
-        rewriter.setInsertionPoint(tileOp);
-        auto clonedOp = operandTileOp.clone();
+      });
+      return success(/*isSuccess=*/!tileMismatch);
+    };
+
+    /// Attempts to resolve (trivial) tile ID conflicts.
+    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
+      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+      OpOperand *tileOperand = getTileOpOperand(tileOp);
+      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
+        // Operand already allocated to the correct tile.
+        // No conflict to resolve.
+        return success();
+      }
+      auto operandTileOp =
+          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+      if (!isTriviallyCloneableTileOp(operandTileOp)) {
+        auto error =
+            tileOp.emitOpError("tile operand allocated to different SME "
+                               "virtual tile (move required)");
+        error.attachNote(tileOperand->get().getLoc())
+            << "tile operand is: " << tileOperand->get();
+        return error;
+      }
+      // Cloning prevents a move/spill (though may require recomputation).
+      rewriter.setInsertionPoint(tileOp);
+      auto clonedOp = operandTileOp.clone();
+      rewriter.modifyOpInPlace(clonedOp,
+                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
+      rewriter.insert(clonedOp);
+      if (isa<CopyTileOp>(tileOp)) {
+        rewriter.replaceAllUsesWith(tileOp->getResult(0),
+                                    clonedOp->getResult(0));
+      } else {
         rewriter.modifyOpInPlace(
-            clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
-        rewriter.insert(clonedOp);
-        if (copyOp) {
-          rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
-        } else {
-          rewriter.modifyOpInPlace(
-              tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
-        }
-      } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
-        // Validate block arguments.
-        bool tileMismatch = false;
-        forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
-          if (tileMismatch)
-            return;
-          if (!isAllocatedToSameTile(predecessorTile)) {
-            blockArg.getOwner()->getParentOp()->emitOpError(
-                "block argument not allocated to the same SME virtial tile as "
-                "predecessors");
-            tileMismatch = true;
-          }
-        });
-        if (tileMismatch)
-          return failure();
+            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
       }
+      return success();
+    };
+
+    for (Value value : liveRange->values) {
+      // 1. Assign the tile ID to the value.
+      assignTileIdToValue(rewriter, value, tileIdAttr);
+
+      // 2. Attempt to eliminate redundant tile copies.
+      if (succeeded(foldRedundantCopies(value)))
+        continue;
+
+      // 3. Validate tile block arguments.
+      if (failed(validateBlockArguments(value)))
+        return failure();
+
+      // 4. Attempt to resolve (trivial) tile ID conflicts.
+      if (failed(resolveTrivialTileConflicts(value)))
+        return failure();
     }
   }
   return success();
@@ -619,9 +683,9 @@ struct TestTileAllocationPass
   using TestTileAllocationBase::TestTileAllocationBase;
   void runOnOperation() override {
     FunctionOpInterface function = getOperation();
-    if (tileCopiesOnly) {
+    if (preprocessOnly) {
       IRRewriter rewriter(function);
-      return insertCopiesAtBranches(rewriter, function);
+      return preprocessForTileAllocation(rewriter, function);
     }
     if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
       signalPassFailure();
@@ -634,8 +698,8 @@ LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
   LiveRange::Allocator liveRangeAllocator;
   IRRewriter rewriter(function.getContext());
 
-  // 1. Insert copy operations at branch operations.
-  insertCopiesAtBranches(rewriter, function);
+  // 1. Preprocess the IR for tile allocation.
+  preprocessForTileAllocation(rewriter, function);
 
   // 2. Gather live ranges for each ArmSME tile within the function.
   Liveness liveness(function);
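The greedy linear scan described in the `TileAllocation.cpp` header can be sketched in isolation as follows. This is an illustrative simplification, not the patch's implementation: each live range is reduced to a single half-open interval, all tiles are assumed to be the same size (so the overlap table between ZA tile types is ignored), and `SimpleLiveRange` and `linearScanAllocate` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Simplified live range: a single half-open interval [start, end) of
// operation indices. (The patch uses a set of intervals per live range.)
struct SimpleLiveRange {
  unsigned start;
  unsigned end;
  std::optional<unsigned> tileId;
};

// Greedy linear scan: visit ranges in order of their start point, release the
// tiles of active ranges whose end has been passed, then take the lowest free
// tile. Returns false if some range gets no tile (a real allocator would need
// to spill or move there).
bool linearScanAllocate(std::vector<SimpleLiveRange *> ranges,
                        unsigned numTiles) {
  std::stable_sort(ranges.begin(), ranges.end(),
                   [](const SimpleLiveRange *a, const SimpleLiveRange *b) {
                     return a->start < b->start;
                   });
  std::vector<bool> tileInUse(numTiles, false);
  std::vector<SimpleLiveRange *> active;
  for (SimpleLiveRange *range : ranges) {
    // Expire ranges that ended at or before this point ([start, end)).
    for (auto it = active.begin(); it != active.end();) {
      if ((*it)->end <= range->start) {
        tileInUse[*(*it)->tileId] = false;
        it = active.erase(it);
      } else {
        ++it;
      }
    }
    // Greedily take the first free tile; earlier choices are never revisited.
    auto freeTile = std::find(tileInUse.begin(), tileInUse.end(), false);
    if (freeTile == tileInUse.end())
      return false;
    *freeTile = true;
    range->tileId = static_cast<unsigned>(freeTile - tileInUse.begin());
    active.push_back(range);
  }
  return true;
}
```

For example, with two tiles, ranges [0, 4), [1, 3) and [3, 6) get tiles 0, 1 and 1: the second range ends at 3, so its tile is free again for the third. Because the scan never revisits a decision, it can fail where a smarter allocator would succeed, which is the trade-off the header comment calls out.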
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
index da7f4843f75f..3d44950f463e 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-arm-sme-tile-allocation=tile-copies-only -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation=preprocess-only -split-input-file | FileCheck %s
 
 // This file tests inserting copies for the SME tile allocation. Copies are
 // inserted at `cf.br` ops (the predecessors to block arguments). Conditional
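The operation numbering in `generateOperationNumbering` and the `liveAtBlockEntry` adjustment in `gatherTileLiveRanges` can be illustrated with the sketch below. It assumes a hypothetical mini-IR in which each block is a list of uniquely named operations and the blocks are already topologically sorted. `numberOperations` and `blockInterval` are illustrative names, not functions from the patch.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical mini-IR: a block is a list of uniquely named operations, and
// the blocks are assumed to be topologically sorted already.
using MockBlock = std::vector<std::string>;

// Numbers operations consecutively within blocks. One index is skipped before
// each block so that the block's arguments get their own program point (the
// index just before the block's first operation).
std::map<std::string, unsigned>
numberOperations(const std::vector<MockBlock> &blocks) {
  std::map<std::string, unsigned> opToIndex;
  unsigned index = 0;
  for (const MockBlock &block : blocks) {
    index++; // Reserve a number for the block arguments.
    for (const std::string &op : block)
      opToIndex[op] = index++;
  }
  return opToIndex;
}

// Half-open interval [start, end) of operation indices.
struct Interval {
  unsigned start;
  unsigned end;
};

// Interval for a value within one block, from its first use (or definition)
// to its last use in that block. If `liveAtBlockEntry` is true, the caller
// passes the block's first operation, and the interval starts one earlier,
// at the block-argument slot.
Interval blockInterval(const std::map<std::string, unsigned> &opToIndex,
                       const std::string &firstUseOrDef,
                       const std::string &lastUseInBlock,
                       bool liveAtBlockEntry) {
  unsigned start = opToIndex.at(firstUseOrDef) - (liveAtBlockEntry ? 1u : 0u);
  unsigned end = opToIndex.at(lastUseInBlock);
  assert(start <= end && "last use must not precede the first use/def");
  return Interval{start, end};
}
```

For two blocks `{a, b}` and `{c, d}`, this numbers a=1, b=2, c=4 and d=5. Index 3 is reserved for the second block's arguments, so a value that is live on entry to that block and last used at `d` gets the interval [3, 5).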

>From 16c1fc78f40e724d907e5f8eda27656f2cc27703 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 9 May 2024 10:09:16 +0000
Subject: [PATCH 12/13] More stuff

---
 .../ArmSME/tile-allocation-copies.mlir        | 93 +++++++++++++++++--
 1 file changed, 85 insertions(+), 8 deletions(-)

diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
index 3d44950f463e..6d9cbf36a162 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -43,20 +43,27 @@ func.func @cond_branch(%cond: i1, %tile: vector<[4]x[4]xf32>) {
 // Reduction of a real-world example that shows why we must split conditional branches.
 
 // CHECK-LABEL: @cond_branch_with_backedge(
-//  CHECK-SAME:    %{{[[:alnum:]]+}}: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+//  CHECK-SAME:    %[[TILEA:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
 //  CHECK-SAME:    %[[TILEC:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILED:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+//       CHECK:   %[[BB1_COPY_0:.*]] = arm_sme.copy_tile %[[TILEA]] : vector<[4]x[4]xf32>
+//       CHECK:   cf.br ^bb1(%{{[[:alnum:]]+}}, %[[BB1_COPY_0]]
 //       CHECK: ^bb1(%[[CURRENT_INDEX:.*]]: index, %[[ITER_TILE:.*]]: vector<[4]x[4]xf32>):
 //       CHECK:   %[[CONTINUE_LOOP:.*]] = arith.cmpi
-//       CHECK:   cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
+//       CHECK:   cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2_COPIES:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
 //       CHECK: ^[[BB3_COPIES]]:
-//  CHECK-NEXT:   arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
-//  CHECK-NEXT:   arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
-//  CHECK-NEXT:   arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
-//  CHECK-NEXT:   arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
-//  CHECK-NEXT:   cf.br ^[[BB3:[[:alnum:]]+]]
+//  CHECK-NEXT:   %[[BB3_COPY_0:.*]] = arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   %[[BB3_COPY_1:.*]] = arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   %[[BB3_COPY_2:.*]] = arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   %[[BB3_COPY_3:.*]] = arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   cf.br ^[[BB3:[[:alnum:]]+]](%[[BB3_COPY_0]], %[[BB3_COPY_1]], %[[BB3_COPY_2]], %[[BB3_COPY_3]]
+//       CHECK: ^[[BB2_COPIES]]:
+//  CHECK-NEXT:   cf.br ^[[BB2:[[:alnum:]]+]]
+//       CHECK: ^[[BB2]]:
+//  CHECK-NEXT:   %[[NEXT_TILE:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}}, %[[ITER_TILE]]
+//       CHECK:   %[[BB1_COPY_1:.*]] = arm_sme.copy_tile %[[NEXT_TILE]] : vector<[4]x[4]xf32>
+//       CHECK:   cf.br ^bb1(%{{[[:alnum:]]+}}, %[[BB1_COPY_1]]
 //       CHECK: ^[[BB3]](%{{.*}}: vector<[4]x[4]xf32>):
 //  CHECK-NEXT:   return
-
 func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector<[4]x[4]xf32>, %tileC: vector<[4]x[4]xf32>, %tileD: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -80,3 +87,73 @@ func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector
   // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
   return
 }
+
+// -----
+
+// CHECK-LABEL: @tile_dominance
+// CHECK-NOT: arm_sme.copy_tile
+func.func @tile_dominance(%arg0: vector<[4]x[4]xf32>) {
+  cf.br ^bb1
+^bb1:  // 2 preds: ^bb0, ^bb4
+  "test.some_use"(%arg0) : (vector<[4]x[4]xf32>) -> ()
+  return
+^bb2:  // no predecessors
+  %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.br ^bb3
+^bb3:  // pred: ^bb2
+  "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
+  return
+^bb4:  // no predecessors
+  cf.br ^bb1
+^bb5:  // no predecessors
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @cond_branch_true_and_false_tile_args(
+//  CHECK-SAME:   %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+//  CHECK-NEXT:   cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[BB2_COPIES:[[:alnum:]]+]]
+//       CHECK: ^[[BB2_COPIES]]:
+//  CHECK-NEXT:   %[[COPY_0:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   cf.br ^[[BB2:[[:alnum:]]+]](%[[COPY_0]]
+//       CHECK: ^[[BB1_COPIES]]:
+//  CHECK-NEXT:   %[[COPY_1:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+//  CHECK-NEXT:   cf.br ^[[BB1:[[:alnum:]]+]](%[[COPY_1]]
+//       CHECK: ^[[BB1]]{{.*}}:
+//  CHECK-NEXT:   return
+//       CHECK: ^[[BB2]]{{.*}}:
+//  CHECK-NEXT:   return
+func.func @cond_branch_true_and_false_tile_args(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg0: vector<[4]x[4]xf32>):
+  return
+^bb2(%blockArg1: vector<[4]x[4]xf32>):
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @multiple_predecessors
+//      CHECK: ^bb1:
+// CHECK-NEXT:   %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
+// CHECK-NEXT:   %[[COPY_0:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT:   cf.br ^bb3(%[[COPY_0]] : vector<[4]x[4]xf32>)
+//      CHECK: ^bb2:
+// CHECK-NEXT:   %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xf32>
+// CHECK-NEXT:   %[[COPY_1:.*]] = arm_sme.copy_tile %[[ZERO]] : vector<[4]x[4]xf32>
+// CHECK-NEXT:   cf.br ^bb3(%[[COPY_1]] : vector<[4]x[4]xf32>)
+//      CHECK: ^bb3({{.*}}):
+// CHECK-NEXT:   return
+func.func @multiple_predecessors(%cond: i1)
+{
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.br ^bb3(%tile : vector<[4]x[4]xf32>)
+^bb2:
+  %zero = arm_sme.zero : vector<[4]x[4]xf32>
+  cf.br ^bb3(%zero : vector<[4]x[4]xf32>)
+^bb3(%blockArg: vector<[4]x[4]xf32>): // 2 preds: ^bb1, ^bb2
+  return
+}
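The effect of `splitCondBranches` on the control-flow graph, as exercised by `@cond_branch_true_and_false_tile_args`, can be sketched with a toy CFG. `MockBlock`, its `forwardsTileValues` flag, and this `splitCondBranches` are hypothetical. Splitting only branches that forward tile values is an assumption inferred from the tests above: in `@multiple_predecessors`, whose `cf.cond_br` passes no tiles, `^bb1` still starts with `arm_sme.get_tile`, so no new blocks appear to be inserted there.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy block: successor indices plus a flag saying whether its
// terminator forwards tile values. Two successors model a `cf.cond_br`.
struct MockBlock {
  std::vector<std::size_t> successors;
  bool forwardsTileValues = false;
};

// Routes each edge of a tile-forwarding conditional branch through a fresh
// block holding only an unconditional branch (`cf.br`) to the original
// successor. Copies can then be inserted per edge, so the copies for the true
// and false paths never share a block.
void splitCondBranches(std::vector<MockBlock> &blocks) {
  const std::size_t numOriginalBlocks = blocks.size();
  for (std::size_t b = 0; b < numOriginalBlocks; ++b) {
    assert(blocks[b].successors.size() <= 2 &&
           "a mock terminator has at most two successors");
    if (blocks[b].successors.size() != 2 || !blocks[b].forwardsTileValues)
      continue;
    for (std::size_t s = 0; s < 2; ++s) {
      const std::size_t target = blocks[b].successors[s];
      const std::size_t newBlock = blocks.size();
      // Index-based access: push_back may reallocate `blocks`.
      blocks.push_back(MockBlock{{target}, false});
      blocks[b].successors[s] = newBlock;
    }
  }
}
```

In this sketch, a block `{1, 2}` that forwards tiles ends up branching to two new blocks, which in turn branch to blocks 1 and 2; a conditional branch that forwards no tiles is left untouched.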

>From a91685a5872c8e9971e900080032d50d61be7d17 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 9 May 2024 10:41:51 +0000
Subject: [PATCH 13/13] More stuff

---
 .../ArmSME/Transforms/TileAllocation.cpp      | 26 ++++++++++++++++---
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index f4d1570daa9e..2b20ef59e3d3 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -233,6 +233,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
 ///  AFTER:
 ///  ```mlir
 ///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+///  cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
 ///  ```
 void insertCopiesAtBranches(IRRewriter &rewriter,
                             FunctionOpInterface function) {
@@ -258,7 +259,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
 /// branch (see `tile-allocation-copies.mlir` and
 /// `tile-allocation-liveness.mlir` for examples). The copies break up live
 /// ranges and ensure when moving out of SSA the semantics of the program are
-/// persevered.
+/// preserved.
 void preprocessForTileAllocation(IRRewriter &rewriter,
                                  FunctionOpInterface function) {
   splitCondBranches(rewriter, function);
@@ -266,8 +267,10 @@ void preprocessForTileAllocation(IRRewriter &rewriter,
 }
 
 /// A live range for a (collection of) tile values. A live range is built up of
-/// intervals [start, end) which represent parts of the program where the value
-/// needs to be live (i.e. in an SME virtual tile).
+/// non-overlapping intervals [start, end) which represent parts of the program
+/// where a value in the range needs to be live (i.e. in an SME virtual tile).
+/// Note that as the intervals are non-overlapping all values within a live
+/// range can be allocated to the same SME virtual tile.
 struct LiveRange {
   using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
                                      llvm::IntervalMapHalfOpenInfo<unsigned>>;
@@ -310,8 +313,14 @@ struct LiveRange {
     return *getSMETileType(cast<VectorType>(values[0].getType()));
   }
 
-  std::unique_ptr<RangeSet> ranges;
+  /// The values contained in this live range.
   SetVector<Value> values;
+
+  /// A set of (non-overlapping) intervals that mark where any value in `values`
+  /// is live.
+  std::unique_ptr<RangeSet> ranges;
+
+  /// The tile ID (or none) assigned to this live range.
   std::optional<unsigned> tileId;
 };
 
@@ -345,6 +354,7 @@ DenseMap<Value, LiveRange>
 gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
                      LiveRange::Allocator &liveRangeAllocator,
                      Liveness &liveness, FunctionOpInterface function) {
+  assert(!operationToIndexMap.empty() && "expected operation numbering");
   DenseMap<Value, LiveRange> liveRanges;
   /// Defines or updates a live range for an SME tile value. Live-ins may update
   /// an existing live range (rather than define a new one). Note: If
@@ -420,6 +430,9 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
     liveRanges.insert({value, &liveRange});
   }
 
+  // Merge the live ranges of values `a` and `b` into one (if they do not
+  // overlap). After this, the values `a` and `b` will both point to the same
+  // live range (which will contain multiple values).
   auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
     LiveRange *aLiveRange = liveRanges.at(a);
     LiveRange *bLiveRange = liveRanges.at(b);
@@ -695,6 +708,11 @@ struct TestTileAllocationPass
 
 LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
                                               bool dumpRanges) {
+  if (function.empty()) {
+    // TODO: Also return early if the function contains no ArmSME ops?
+    return success();
+  }
+
   LiveRange::Allocator liveRangeAllocator;
   IRRewriter rewriter(function.getContext());
 
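The invariant documented on `LiveRange` (non-overlapping intervals, so every value in a range can share one SME virtual tile) and the `mergeValuesIfNonOverlapping` step can be sketched as below. Plain `std::set` and `std::vector` stand in for `SetVector` and `llvm::IntervalMap`, and `SketchLiveRange`, `overlaps` and `mergeIfNonOverlapping` are hypothetical names. The coalescing in the patch also makes both values point at the shared range; this sketch approximates that by emptying `b`.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical live range: the values it holds, plus half-open intervals
// [start, end) of operation indices where one of those values is live.
struct SketchLiveRange {
  std::set<std::string> values;
  std::vector<std::pair<unsigned, unsigned>> intervals;
};

// True if any interval of `a` intersects any interval of `b`.
bool overlaps(const SketchLiveRange &a, const SketchLiveRange &b) {
  for (auto [aStart, aEnd] : a.intervals)
    for (auto [bStart, bEnd] : b.intervals)
      if (aStart < bEnd && bStart < aEnd) // Half-open intersection test.
        return true;
  return false;
}

// Merges `b` into `a` if their intervals do not overlap. The merged intervals
// stay non-overlapping, so all of its values can share one tile. Returns false
// (and changes nothing) if the ranges overlap or are the same range.
bool mergeIfNonOverlapping(SketchLiveRange &a, SketchLiveRange &b) {
  if (&a == &b || overlaps(a, b))
    return false;
  a.values.insert(b.values.begin(), b.values.end());
  a.intervals.insert(a.intervals.end(), b.intervals.begin(), b.intervals.end());
  b = SketchLiveRange{}; // `b`'s values now belong to `a`.
  return true;
}
```

For example, a tile live over [0, 4) and its copy live over [4, 7) can be merged, because half-open intervals that only touch do not overlap, while a third value live over [3, 5) cannot join the merged range.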


