[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 29 08:12:22 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This patch rewrites the ArmSME tile allocator to use liveness information, which lets it make better tile allocation decisions and improves the correctness of the ArmSME dialect. The algorithm used here is a linear scan over live ranges, where live ranges are assigned to tiles as they appear in the program (chronologically). Live ranges release their assigned tile ID when the current program point is past their end. This is a greedy algorithm, chosen mainly to keep the implementation relatively straightforward and because it seems to be sufficient for most kernels (e.g. matmuls) that use ArmSME. The general steps are roughly those from https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf, though a few simplifications and assumptions have been made for our use case.
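
As a rough illustration of the greedy linear scan described above, here is a minimal, standalone sketch. It is not the code from this patch: `LiveRange`, `linearScanAllocate`, and `numTiles` are hypothetical names, live ranges are assumed to be half-open intervals `[start, end)` over numbered program points, and ranges that find no free tile are simply left unassigned (how the real allocator handles that case is not shown here).

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical live range over program points, half-open: [start, end).
struct LiveRange {
  unsigned start;
  unsigned end;
  std::optional<unsigned> tileId; // Filled in by the allocator.
};

// Greedy linear scan (illustrative only). Ranges are visited in order of
// their start point; a range whose end is at or before the current start
// releases its tile. Ranges that find no free tile are left unassigned.
inline void linearScanAllocate(std::vector<LiveRange> &ranges,
                               unsigned numTiles) {
  std::stable_sort(ranges.begin(), ranges.end(),
                   [](const LiveRange &a, const LiveRange &b) {
                     return a.start < b.start;
                   });

  std::vector<bool> tileInUse(numTiles, false);
  std::vector<std::size_t> active; // Indices of ranges holding a tile.

  for (std::size_t i = 0; i < ranges.size(); ++i) {
    LiveRange &current = ranges[i];
    current.tileId.reset();

    // Expire ranges that have ended and release their tile IDs.
    std::vector<std::size_t> stillActive;
    for (std::size_t j : active) {
      assert(ranges[j].tileId && "active ranges always hold a tile");
      if (ranges[j].end <= current.start)
        tileInUse[*ranges[j].tileId] = false;
      else
        stillActive.push_back(j);
    }
    active.swap(stillActive);

    // Assign the lowest free tile ID, if there is one.
    for (unsigned id = 0; id < numTiles; ++id) {
      if (tileInUse[id])
        continue;
      tileInUse[id] = true;
      current.tileId = id;
      active.push_back(i);
      break;
    }
  }
}
```

For example, with two tiles and the ranges `[0, 4)`, `[1, 3)`, `[2, 5)`, `[4, 6)`, the first two ranges get tiles 0 and 1, the third finds no free tile, and the fourth reuses tile 0 because both earlier holders have ended by program point 4.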

Hopefully, the only changes needed for a user of the ArmSME dialect are that:

- `-allocate-arm-sme-tiles` will no longer be a standalone pass 
  - `-test-arm-sme-tile-allocation` is only for unit tests 
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf` 
   - SME tile allocation is now part of the LLVM conversion

By integrating this into the `ArmSME -> LLVM` conversion we can allow high-level (value-based) ArmSME operations to be side-effect-free, as we can guarantee nothing will rearrange ArmSME operations before we emit intrinsics (which could invalidate the tile allocation).

The hope is for ArmSME operations to have no hidden state/side effects, which allows dialects such as `vector` and `arith` to be lowered to SME easily, without making assumptions about how the input IR looks, as the semantics of the operations will be the same. That is, there are no (new) side effects and the IR follows the rules of SSA (a value will never change).

The aim is correctness, so we have a base for working on optimizations.

---

Patch is 129.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90448.diff


30 Files Affected:

- (modified) mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h (+3-1) 
- (modified) mlir/include/mlir/Conversion/Passes.td (+6-1) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h (+1-5) 
- (added) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h (+28) 
- (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+47-94) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (-3) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+11-6) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h (+9) 
- (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+20) 
- (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+36-33) 
- (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+11-18) 
- (modified) mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp (+6) 
- (modified) mlir/lib/Dialect/ArmSME/IR/Utils.cpp (+46) 
- (modified) mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp (+487-157) 
- (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+23-4) 
- (modified) mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir (+8-7) 
- (modified) mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir (+2-1) 
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+6) 
- (renamed) mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir (+170-95) 
- (modified) mlir/test/Dialect/ArmSME/canonicalize.mlir (+4-6) 
- (removed) mlir/test/Dialect/ArmSME/cse.mlir (-30) 
- (modified) mlir/test/Dialect/ArmSME/enable-arm-za.mlir (+6-14) 
- (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+4-3) 
- (modified) mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir (+5-7) 
- (added) mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir (+269) 
- (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+30-15) 
- (modified) mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir (+5-2) 
- (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir (+2-1) 
- (modified) mlir/test/lib/Dialect/ArmSME/CMakeLists.txt (+1) 
- (modified) mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp (+11-8) 


``````````diff
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab499983..403f811a2569a0 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
 void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..e6d678dc1b12b3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
 // ArmSMEToLLVM
 //===----------------------------------------------------------------------===//
 
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
   let summary = "Lower the operations from the ArmSME dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
     "arm_sme::ArmSMEDialect",
     "LLVM::LLVMDialect"
   ];
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a74..dac54712c7f47a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 00000000000000..f31062d8c25ed7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,28 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the Target dialect for ArmSME in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+static constexpr unsigned kInMemoryTileIdBase = 16;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d2a..9178655f010c9a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
 
 def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   let description = [{
-    An interface for operations that use or allocate Arm SME tiles. These
-    operations need to be assigned a tile ID, an i32 attribute, which specifies
-    which virtual tile within the ZA storage to use. The number of tiles
-    available depends on the type of the tile. This is summarized below:
+    An interface for operations that use Arm SME tiles. These operations need to
+    be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+    within the ZA storage to use. The number of tiles available depends on the
+    type of the tile. This is summarized below:
 
     | Tile Vector Types                                                       | Possible Tile IDs   |
     |-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
     | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>`                          | 0 to 3 (inclusive)  |
     | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>`                          | 0 to 7 (inclusive)  |
     | `vector<[1]x[1]xi128>`                                                  | 0 to 15 (inclusive) |
-
-    Operations that allocate a new tile (such as arm_sme.get_tile), are used as
-    the roots for tile allocation, with all operations that (transitively)
-    depend on a root being assigned the same tile ID.
   }];
   let methods = [
     InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
         return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
       }]
     >,
-    InterfaceMethod<
-      [{
-        The type of tile this operation allocates. Returns none (std::nullopt)
-        if this operation does not allocate a tile.
-      }],
-      /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
-      /*methodName=*/"getAllocatedTileType",
-      /*arguments=*/(ins),
-      /*methodBody=*/[{}],
-      /*defaultImpl=*/ [{
-        // This operation does not allocate a tile.
-        return std::nullopt;
-      }]
-    >,
     InterfaceMethod<
       "Returns the VectorType of the tile used by this operation.",
       /*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
   ];
 
   let extraSharedClassDeclaration = [{
-    // A helper to create a new operation and propagate this operations tile ID.
-    template<typename T, typename... Args>
-    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
-      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
-      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
-        tileOp.setTileId($_op.getTileId());
-      return op;
-    }
-
-    // A helper to replace this operation and forward its tile ID (if present).
-    template<typename T, typename... Args>
-    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
-      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
-      rewriter.replaceOp($_op, newOp);
-      return newOp;
-    }
-
     bool isInMemoryTile() {
       auto tileId = getTileId();
       return tileId && tileId.getInt() >= kInMemoryTileIdBase;
     }
   }];
 
-  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+  let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
   Op<ArmSME_Dialect, mnemonic, traits> {}
 
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
-  let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates an undefined value of SME virtual tile type";
   let description = [{
-    Allocates a new SME "virtual tile" within a function. The contents of the
-    tile returned from this operation are undefined.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are undefined.
 
     Example 1:
 
     ```mlir
-    // Allocate an 8-bit element "virtual tile"
+    // Create an 8-bit element "virtual tile" value:
     %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
     Example 2:
 
     ```mlir
-    // Allocate two 16-bit element "virtual tiles"
+    // Create two 16-bit element "virtual tiles" values:
     %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
     ```
 
     Example 3:
     ```mlir
-    // Allocate an 128-bit element "virtual tile"
+    // Create an 128-bit element "virtual tile" value:
     %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
     ```
   }];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
     VectorType getTileType() {
       return ::llvm::cast<VectorType>(getTile().getType());
     }
-
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getTileType());
-    }
-  }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
-  let summary = "SME tile placeholder";
-  let description = [{
-    A placeholder to preserve dataflow while lowering to SME intrinsics (which
-    do not take or return SME virtual tile values). This operation is intended
-    to be DCE'd once all ArmSME operations have been lowered.
-
-    This operation is not intended to be used outside of the ArmSME -> LLVM
-    conversion.
   }];
-  let results = (outs SMETile:$tile);
-  let assemblyFormat = "attr-dict `:` type($tile)";
 }
 
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
-  let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates a zero-initialized value of SME virtual tile type";
   let results = (outs SMETile:$res);
   let description = [{
-    Initialise ZA with 0. This operation is convenient wrapper for the SME
-    `zero` intrinsic and instruction.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are zero-initialized.
 
     Example 1: Zero an 8-bit element ZA tile.
 
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getRes().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+  Pure,
+  ArmSMETileOpInterface,
+  AllTypesMatch<["tile", "result"]>
+]> {
+  let summary = "Copies an SME tile value";
+  let arguments = (ins SMETile:$tile);
+  let results = (outs SMETile:$result);
+  let description = [{
+    Copies an SME "virtual tile" value to a new SSA value. This operation is
+    primarily intended to be used to normalize the IR prior to tile allocation.
+
+    Example:
+
+    ```mlir
+    %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+    ```
+  }];
+  let extraClassDeclaration = [{
+    VectorType getTileType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+  let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
 def TileLoadOp : ArmSME_Op<"tile_load", [
   ArmSMETileOpInterface,
   AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
     VectorType getVectorType() {
       return ::llvm::cast<VectorType>(getResult().getType());
     }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getVectorType());
-    }
     VectorType getTileType() {
       return getVectorType();
     }
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
     ```
   }];
   let arguments = (ins
-    Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+    Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
     SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
     ArmSME_TileSliceLayoutAttr:$layout
   );
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
 }
 
 def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     AllTypesMatch<["tile", "result"]>,
     TypesMatchWith<
       "type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
 }
 
 def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
-    ArmSMETileOpInterface,
+    ArmSMETileOpInterface, Pure,
     TypesMatchWith<
       "type of 'result' matches type of 'tile' slice",
       "tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
 
 def OuterProductOp :
   ArmSME_Op<"outerproduct", [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
                                list<Type> allowedResultVectorTypes,
                                int numOuterProducts> :
   ArmSME_Op<mnemonic, [
+    Pure,
     ArmSMETileOpInterface,
     AttrSizedOperandSegments,
     AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874ec..156744ba57e7b2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
 
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
 std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e89267..b9d74fec6756e3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,16 +124,21 @@ def EnableArmStreaming
   let dependentDialects = ["func::FuncDialect"];
 }
 
-def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
-  let summary = "Allocate SME tiles";
+def TestTileAllocation
+    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+  let summary = "Tests SME tile allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
-    that implement the `ArmSMETileOpInterface`. An error will be emitted when
-    there's no tiles left.
+    that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+    to be used for testing, tile allocation is done as part of the ArmSME to
+    LLVM conversion.
   }];
-  let constructor = "mlir::arm_sme::createTileAllocationPass()";
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
   let dependentDialects = ["func::FuncDialect"];
 }
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e69992..a25b844f01eaa6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_H
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
 namespace mlir {
 
 class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
 class RewritePatternSet;
 
 namespace arm_sme {
+
 void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+                               bool dumpRanges = false);
+
 } // namespace arm_sme
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f5..9ea1c5a5d63fe5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include <optional>
 
 namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+inline bool isValidSMETileVectorType(Type type) {
+  auto vType = dyn_cast<VectorType>(type);
+  return vType && isValidSMETileVectorType(vType);
+}
+
 /// Returns the type of SME tile this vector type corresponds to, or none if the
 /// vector type does not fit within an SME tile.
 std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,19 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
 /// Creates a vector type for the SME tile of `elementType`.
 VectorType getSMETileTypeForElement(Type elementType);
 
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterfa...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/90448

