[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri May 3 09:00:02 PDT 2024
https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/90448
>From 13d2408bdc08b871bf2db7f1b6077cf4af649b26 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 26 Apr 2024 15:09:20 +0000
Subject: [PATCH 1/7] [mlir][ArmSME] Use liveness information in the tile
allocator
This patch rewrites the ArmSME tile allocator to use liveness
information to make better tile allocation decisions and improve the
correctness of the ArmSME dialect. The algorithm used here is a linear
scan over live ranges, where live ranges are assigned to tiles as they
appear in the program (chronologically). Live ranges release their
assigned tile ID once the current program point passes their end.
This is a greedy algorithm, chosen mainly because it keeps the
implementation relatively straightforward, and because it seems to be
sufficient for most kernels (e.g. matmuls) that use ArmSME. The general
steps of this algorithm are roughly taken from
https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf,
though a few simplifications and assumptions have been made for our use
case.
Hopefully, the only changes needed for users of the ArmSME dialect are
that:
- `-allocate-arm-sme-tiles` will no longer be a standalone pass
- `-test-arm-sme-tile-allocation` is only for unit tests
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf`
- SME tile allocation is now part of the LLVM conversion
By integrating this into the `ArmSME -> LLVM` conversion, we can allow
high-level (value-based) ArmSME operations to be side-effect-free, as we
can guarantee nothing will rearrange ArmSME operations before we emit
intrinsics (which could invalidate the tile allocation).
The hope is for ArmSME operations to have no hidden state or side
effects, which makes it easy to lower dialects such as `vector` and
`arith` to SME without making assumptions about how the input IR looks,
as the semantics of the operations will be the same. That is, there are
no (new) side effects, and the IR follows the rules of SSA (a value
never changes). The aim is correctness first, so that we have a solid
base for working on optimizations.
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h | 4 +-
mlir/include/mlir/Conversion/Passes.td | 7 +-
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h | 6 +-
.../Dialect/ArmSME/IR/ArmSMEOpInterfaces.h | 28 +
.../mlir/Dialect/ArmSME/IR/ArmSMEOps.td | 141 ++--
.../mlir/Dialect/ArmSME/Transforms/Passes.h | 3 -
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 17 +-
.../Dialect/ArmSME/Transforms/Transforms.h | 9 +
.../include/mlir/Dialect/ArmSME/Utils/Utils.h | 20 +
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 69 +-
.../Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp | 29 +-
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp | 6 +
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 46 ++
.../ArmSME/Transforms/TileAllocation.cpp | 644 +++++++++++++-----
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 27 +-
.../ArmSMEToLLVM/tile-spills-and-fills.mlir | 15 +-
.../Conversion/ArmSMEToLLVM/unsupported.mlir | 3 +-
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 6 +
...cation.mlir => basic-tile-allocation.mlir} | 265 ++++---
mlir/test/Dialect/ArmSME/canonicalize.mlir | 10 +-
mlir/test/Dialect/ArmSME/cse.mlir | 30 -
mlir/test/Dialect/ArmSME/enable-arm-za.mlir | 20 +-
.../Dialect/ArmSME/outer-product-fusion.mlir | 7 +-
.../ArmSME/tile-allocation-invalid.mlir | 12 +-
.../ArmSME/tile-allocation-liveness.mlir | 196 ++++--
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir | 45 +-
.../Linalg/CPU/ArmSME/use-too-many-tiles.mlir | 7 +-
.../ArmSME/Emulated/test-setArmSVLBits.mlir | 3 +-
mlir/test/lib/Dialect/ArmSME/CMakeLists.txt | 1 +
.../lib/Dialect/ArmSME/TestLowerToArmSME.cpp | 19 +-
30 files changed, 1129 insertions(+), 566 deletions(-)
create mode 100644 mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
rename mlir/test/Dialect/ArmSME/{tile-allocation.mlir => basic-tile-allocation.mlir} (52%)
delete mode 100644 mlir/test/Dialect/ArmSME/cse.mlir
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab499983..403f811a2569a0 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
#include <memory>
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir {
class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab95..e6d678dc1b12b3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
// ArmSMEToLLVM
//===----------------------------------------------------------------------===//
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
"arm_sme::ArmSMEDialect",
"LLVM::LLVMDialect"
];
+ let options = [
+ Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+ "bool", /*default=*/"false",
+ "Dump the live ranges of SME tiles (for debugging)">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a74..dac54712c7f47a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 00000000000000..f31062d8c25ed7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,28 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the op interfaces for the ArmSME dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+static constexpr unsigned kInMemoryTileIdBase = 16;
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d2a..9178655f010c9a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
let description = [{
- An interface for operations that use or allocate Arm SME tiles. These
- operations need to be assigned a tile ID, an i32 attribute, which specifies
- which virtual tile within the ZA storage to use. The number of tiles
- available depends on the type of the tile. This is summarized below:
+ An interface for operations that use Arm SME tiles. These operations need to
+ be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+ within the ZA storage to use. The number of tiles available depends on the
+ type of the tile. This is summarized below:
| Tile Vector Types | Possible Tile IDs |
|-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
| `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
-
- Operations that allocate a new tile (such as arm_sme.get_tile), are used as
- the roots for tile allocation, with all operations that (transitively)
- depend on a root being assigned the same tile ID.
}];
let methods = [
InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
}]
>,
- InterfaceMethod<
- [{
- The type of tile this operation allocates. Returns none (std::nullopt)
- if this operation does not allocate a tile.
- }],
- /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
- /*methodName=*/"getAllocatedTileType",
- /*arguments=*/(ins),
- /*methodBody=*/[{}],
- /*defaultImpl=*/ [{
- // This operation does not allocate a tile.
- return std::nullopt;
- }]
- >,
InterfaceMethod<
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
];
let extraSharedClassDeclaration = [{
- // A helper to create a new operation and propagate this operations tile ID.
- template<typename T, typename... Args>
- T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
- auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
- if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
- tileOp.setTileId($_op.getTileId());
- return op;
- }
-
- // A helper to replace this operation and forward its tile ID (if present).
- template<typename T, typename... Args>
- T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
- auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
- rewriter.replaceOp($_op, newOp);
- return newOp;
- }
-
bool isInMemoryTile() {
auto tileId = getTileId();
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
}
}];
- let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+ let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
}
//===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
Op<ArmSME_Dialect, mnemonic, traits> {}
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
- let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+ let summary = "Creates an undefined value of SME virtual tile type";
let description = [{
- Allocates a new SME "virtual tile" within a function. The contents of the
- tile returned from this operation are undefined.
+ Creates a new SME "virtual tile" value within a function. The contents of
+ the tile returned from this operation are undefined.
Example 1:
```mlir
- // Allocate an 8-bit element "virtual tile"
+ // Create an 8-bit element "virtual tile" value:
%za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
```
Example 2:
```mlir
- // Allocate two 16-bit element "virtual tiles"
+ // Create two 16-bit element "virtual tile" values:
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
```
Example 3:
```mlir
- // Allocate an 128-bit element "virtual tile"
+ // Create a 128-bit element "virtual tile" value:
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
```
}];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
VectorType getTileType() {
return ::llvm::cast<VectorType>(getTile().getType());
}
-
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getTileType());
- }
- }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
- let summary = "SME tile placeholder";
- let description = [{
- A placeholder to preserve dataflow while lowering to SME intrinsics (which
- do not take or return SME virtual tile values). This operation is intended
- to be DCE'd once all ArmSME operations have been lowered.
-
- This operation is not intended to be used outside of the ArmSME -> LLVM
- conversion.
}];
- let results = (outs SMETile:$tile);
- let assemblyFormat = "attr-dict `:` type($tile)";
}
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
- let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+ let summary = "Creates a zero-initialized value of SME virtual tile type";
let results = (outs SMETile:$res);
let description = [{
- Initialise ZA with 0. This operation is convenient wrapper for the SME
- `zero` intrinsic and instruction.
+ Creates a new SME "virtual tile" value within a function. The contents of
+ the tile returned from this operation are zero-initialized.
Example 1: Zero an 8-bit element ZA tile.
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
}
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getVectorType());
- }
VectorType getTileType() {
return getVectorType();
}
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
let assemblyFormat = "attr-dict `:` type($res)";
}
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+ Pure,
+ ArmSMETileOpInterface,
+ AllTypesMatch<["tile", "result"]>
+]> {
+ let summary = "Copies an SME tile value";
+ let arguments = (ins SMETile:$tile);
+ let results = (outs SMETile:$result);
+ let description = [{
+ Copies an SME "virtual tile" value to a new SSA value. This operation is
+ primarily intended to be used to normalize the IR prior to tile allocation.
+
+ Example:
+
+ ```mlir
+ %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getResult().getType());
+ }
+ }];
+ let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
def TileLoadOp : ArmSME_Op<"tile_load", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getVectorType());
- }
VectorType getTileType() {
return getVectorType();
}
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
}
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
- ArmSMETileOpInterface,
+ ArmSMETileOpInterface, Pure,
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
- ArmSMETileOpInterface,
+ ArmSMETileOpInterface, Pure,
TypesMatchWith<
"type of 'result' matches type of 'tile' slice",
"tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
def OuterProductOp :
ArmSME_Op<"outerproduct", [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- // The outerproduct op allocates a new tile if no accumulator is passed.
- if (!getAcc())
- return arm_sme::getSMETileType(getResultType());
- return std::nullopt;
- }
VectorType getTileType() {
return getResultType();
}
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
list<Type> allowedResultVectorTypes,
int numOuterProducts> :
ArmSME_Op<mnemonic, [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- // The outerproduct op allocates a new tile if no accumulator is passed.
- if (!getAcc())
- return arm_sme::getSMETileType(getResultType());
- return std::nullopt;
- }
VectorType getTileType() {
return getResultType();
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874ec..156744ba57e7b2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e89267..b9d74fec6756e3 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,16 +124,21 @@ def EnableArmStreaming
let dependentDialects = ["func::FuncDialect"];
}
-def TileAllocation
- : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
- let summary = "Allocate SME tiles";
+def TestTileAllocation
+ : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+ let summary = "Tests SME tile allocation";
let description = [{
This pass does tile allocation for SME "virtual tiles". It is run at the
'func.func' op level, and assigns tile IDs (via an attribute) to all ops
- that implement the `ArmSMETileOpInterface`. An error will be emitted when
- there's no tiles left.
+ that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+ to be used for testing; tile allocation is done as part of the ArmSME to
+ LLVM conversion.
}];
- let constructor = "mlir::arm_sme::createTileAllocationPass()";
+ let options = [
+ Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+ "bool", /*default=*/"false",
+ "Dump the live ranges of SME tiles (for debugging)">
+ ];
let dependentDialects = ["func::FuncDialect"];
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e69992..a25b844f01eaa6 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
namespace mlir {
class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
class RewritePatternSet;
namespace arm_sme {
+
void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+ bool dumpRanges = false);
+
} // namespace arm_sme
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f5..9ea1c5a5d63fe5 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include <optional>
namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);
+inline bool isValidSMETileVectorType(Type type) {
+ auto vType = dyn_cast<VectorType>(type);
+ return vType && isValidSMETileVectorType(vType);
+}
+
/// Returns the type of SME tile this vector type corresponds to, or none if the
/// vector type does not fit within an SME tile.
std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,19 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
VectorType getSMETileTypeForElement(Type elementType);
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+ FunctionOpInterface function);
+
+/// Returns true if `tileOp` can be cloned to resolve conflicts.
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `tileOp` produces a tile result.
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns the tile `OpOperand` for this `tileOp` (or null).
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 1ba1b88fc1234b..488f747af050f5 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -245,6 +246,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
if (!tileOp.isInMemoryTile())
return failure();
+ tileOp->emitWarning(
+ "failed to allocate SME virtual tile to operation, all tile "
+ "operations will go through memory, expect degraded performance");
+
// Step 1. Create an alloca for the tile at the top of the function (if one
// does not already exist).
auto loc = tileOp.getLoc();
@@ -391,20 +396,6 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
(addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
-struct GetTileConversion
- : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
- RequiresSpillsAndFills::No> {
- using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
- getTile, getTile.getTileType());
- return success();
- }
-};
-
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
@@ -436,7 +427,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
// The base mask is just the mask to zero the first tile (of a size).
// These masks are derived from:
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
- arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
+ arm_sme::ArmSMETileType tileType =
+ *arm_sme::getSMETileType(zero.getTileType());
auto baseMaskForSize = [&] {
switch (tileType) {
case arm_sme::ArmSMETileType::ZAB:
@@ -488,8 +480,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
loc, rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
- rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
- zero, zero.getVectorType());
+ rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
return success();
}
@@ -746,10 +737,12 @@ struct OuterProductOpConversion
auto loc = outerProductOp.getLoc();
Value acc = outerProductOp.getAcc();
- if (!acc)
+ if (!acc) {
// Initalize accumulator with zero.
- acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, loc, resultVectorType);
+ auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+ zero.setTileId(tileId);
+ acc = zero;
+ }
Value lhsMask = outerProductOp.getLhsMask();
Value rhsMask = outerProductOp.getRhsMask();
@@ -791,25 +784,27 @@ struct OuterProductWideningOpConversion
if (!tileId)
return failure();
+ auto loc = op.getLoc();
Value acc = op.getAcc();
- if (!acc)
+ if (!acc) {
// Initalize accumulator with zero.
- acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, op.getLoc(), op.getResultType());
+ auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+ zero.setTileId(tileId);
+ acc = zero;
+ }
Value lhsMask = op.getLhsMask();
Value rhsMask = op.getRhsMask();
if (!lhsMask || !rhsMask) {
auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
Value allActiveMask = rewriter.create<arith::ConstantOp>(
- op.getLoc(), DenseElementsAttr::get(predTy, true));
+ loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
- rhsMask, adaptor.getLhs(),
- adaptor.getRhs());
+ rewriter.create<OuterProductWideningIntrOp>(
+ loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -865,15 +860,22 @@ namespace {
struct ConvertArmSMEToLLVMPass
: public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+ ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+ this->dumpTileLiveRanges = dumpTileLiveRanges;
+ }
void runOnOperation() override {
+ auto function = getOperation();
+
+ if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+ return signalPassFailure();
+
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
configureArmSMEToLLVMConversionLegality(target);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
};
@@ -883,7 +885,7 @@ struct ConvertArmSMEToLLVMPass
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
- arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
+ arm_sme::GetTileOp, arm_sme::CopyTileOp, arm_sme::aarch64_sme_zero,
arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
@@ -955,9 +957,10 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
arm_sme::aarch64_sme_usmopa_wide>,
OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
arm_sme::aarch64_sme_usmops_wide>,
- ZeroOpConversion, GetTileConversion>(patterns, converter);
+ ZeroOpConversion>(patterns, converter);
}
-std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
- return std::make_unique<ConvertArmSMEToLLVMPass>();
+std::unique_ptr<Pass>
+mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+ return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 16b61c282749cf..9f55932c33af66 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -196,12 +196,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive
// rows however, no load will occur so these need to be zeroed.
- initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, loc, tileType);
+ initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
} else {
- // Allocate a new SME tile.
- initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
}
// Create a loop to load the active tile slices from memory.
@@ -212,10 +209,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value currentTile) -> Value {
// Create 'arm_sme.load_tile_slice' to load tile slice from memory
// into tile.
- return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
- currentTile, memrefIndices, tileSliceIndex,
- tileLoadOp.getLayout());
+ return rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
+ memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
});
if (failed(forOp))
@@ -292,9 +288,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), numCols);
- // Allocate a new SME tile.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -339,10 +333,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
/*passthru=*/pad1DOp);
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
- auto moveSlice =
- tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
- rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
- tileSliceIndex, tileLoadOp.getLayout());
+ auto moveSlice = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
+ tileLoadOp.getLayout());
rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
@@ -386,8 +379,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
tileStoreOp.getMask(),
[&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
- tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
- rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+ rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+ tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
predicate, tileStoreOp.getBase(), memrefIndices,
tileStoreOp.getLayout());
});
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 29fa9085a0a963..cb3a6658448720 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -20,6 +20,12 @@
using namespace mlir;
using namespace mlir::arm_sme;
+namespace mlir::arm_sme::detail {
+/// Verifier for ops implementing ArmSMETileOpInterface (hooked up via the
+/// interface definition -- confirm in the .td). Delegates entirely to
+/// verifyOperationHasValidTileId.
+LogicalResult verifyArmSMETileOpInterface(Operation *op) {
+ return verifyOperationHasValidTileId(op);
+}
+} // namespace mlir::arm_sme::detail
+
//===----------------------------------------------------------------------===//
// Tablegen Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6a9e0221822267..00d764bf5caff9 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -116,4 +116,50 @@ VectorType getSMETileTypeForElement(Type elementType) {
return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
}
+/// Erases the ArmSME tile ops in `function` that are trivially dead, then
+/// (transitively) the ArmSME tile ops defining their operands once those
+/// become trivially dead too.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ SmallVector<Operation *> worklist;
+ function->walk([&](Operation *op) {
+ auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
+ if (armSMEOp && isOpTriviallyDead(armSMEOp))
+ worklist.push_back(armSMEOp);
+ });
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ if (!isOpTriviallyDead(op))
+ continue;
+ for (Value value : op->getOperands()) {
+ auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>();
+ // Never queue the same producer twice (e.g. when `op` uses the same
+ // value in several operands): once the first entry is erased, the second
+ // would be a dangling pointer. An erased op had no uses, so it can never
+ // be re-queued afterwards.
+ if (armSMEOp && !llvm::is_contained(worklist, armSMEOp.getOperation()))
+ worklist.push_back(armSMEOp);
+ }
+ rewriter.eraseOp(op);
+ }
+}
+
+/// Returns true if `tileOp` is non-null, pure, has no operands and exactly one
+/// result, i.e. it can be re-created (cloned) anywhere without any inputs.
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
+ return tileOp && tileOp->getNumResults() == 1 &&
+ tileOp->getNumOperands() == 0 && isPure(tileOp);
+}
+
+/// Returns true if any result of `tileOp` has a valid SME tile vector type.
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
+ for (Value result : tileOp->getResults()) {
+ if (arm_sme::isValidSMETileVectorType(result.getType()))
+ return true;
+ }
+ return false;
+}
+
+/// Returns the operand of `tileOp` that has a valid SME tile vector type, or
+/// nullptr if there is none. Asserts that there is at most one such operand.
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
+ auto isTileOperandType = [](OpOperand &operand) {
+ return arm_sme::isValidSMETileVectorType(operand.get().getType());
+ };
+ OpOperand *tileOperand =
+ llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
+ assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
+ "expected at most one tile operand");
+ if (tileOperand == tileOp->getOpOperands().end())
+ return nullptr;
+ return tileOperand;
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4acb2a8fb7b539..e3cf52078a24eb 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
-// This pass allocates SME tiles at the 'func.func' op level for ArmSME
+// This transform allocates SME tiles at the 'func.func' op level for ArmSME
// operations. It does this using a 16-bit tile mask that has a bit for each
// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
//
@@ -32,39 +32,31 @@
// ZA6.D ZA6.Q, ZA14.Q
// ZA7.D ZA7.Q, ZA15.Q
//
-// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initalized during the first tile allocation within a function and
-// updated on each subsequent allocation.
-//
// [1] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
-#define DEBUG_TYPE "allocate-arm-sme-tiles"
-
-namespace mlir {
-namespace arm_sme {
-#define GEN_PASS_DEF_TILEALLOCATION
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
-static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
-static constexpr StringLiteral
- kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
-
enum class TileMask : unsigned {
// clang-format off
kZA0B = 0xffff, // 1111 1111 1111 1111
@@ -137,172 +129,510 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
}
}
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
- TileMask &tilesInUse) {
- auto masks = getMasks(tileType);
- for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
- if ((tilesInUse & tileMask) == TileMask::kNone) {
- tilesInUse |= tileMask;
- return tileId;
+/// Tracks which 128-bit ZA tile granules (ZA0.Q-ZA15.Q) are in use and hands
+/// out tile IDs per tile type. Also hands out in-memory tile IDs (starting at
+/// kInMemoryTileIdBase) for values that cannot be given a real tile.
+class TileAllocator {
+public:
+ /// Allocates and returns the lowest tile ID of `tileType` whose mask does not
+ /// overlap any granule in use, or failure if there is no such tile.
+ FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+ if ((tilesInUse & tileMask) == TileMask::kNone) {
+ tilesInUse |= tileMask;
+ return tileId;
+ }
}
+ return failure();
+ }
+
+ /// Releases a previously allocated tile ID.
+ void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+ TileMask tileMask = getMasks(tileType)[tileId];
+ // Every granule of the tile must currently be in use: for a partially
+ // allocated mask the XOR below would *set* the free granules and corrupt
+ // `tilesInUse`.
+ assert((tilesInUse & tileMask) == tileMask &&
+ "cannot release unallocated tile!");
+ tilesInUse ^= tileMask;
+ }
+
+ /// Allocates an in-memory tile ID.
+ unsigned allocateInMemoryTileId() {
+ // Note: We never release in-memory tile IDs. We could, which may allow
+ // reusing an allocation, but as we _never_ want to spill an SME tile this
+ // is not optimized.
+ return nextInMemoryTileId++;
}
- return failure();
-}
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
- SetVector<Operation *> &dependantOps) {
- auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
- for (auto [idx, value] : llvm::enumerate(inputValues)) {
- if (value == rootValue)
- findDependantOps(exitValues[idx], dependantOps);
+private:
+ /// One bit per 128-bit granule; a set bit means the granule is in use.
+ TileMask tilesInUse = TileMask::kNone;
+ /// The next in-memory tile ID to hand out.
+ unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Adds new intermediate blocks for the true and false destinations of each
+/// `cf.cond_br` that has an SME tile operand. Each new block ends in an
+/// unconditional `cf.br` forwarding the original successor operands, which is
+/// where insertCopiesAtBranches places tile copies. This prevents spurious
+/// liveness overlaps due to copies at branches.
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+ SmallVector<cf::CondBranchOp> worklist;
+ function.walk([&](cf::CondBranchOp condBranch) {
+ if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+ return isValidSMETileVectorType(value.getType());
+ })) {
+ worklist.push_back(condBranch);
}
+ });
+
+ // Terminates `source` with `cf.br dest(args)`.
+ auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+ rewriter.setInsertionPointToEnd(source);
+ rewriter.create<cf::BranchOp>(loc, dest, args);
};
- for (Operation *user : rootValue.getUsers()) {
- if (dependantOps.contains(user))
+
+ for (auto condBranch : worklist) {
+ auto loc = condBranch.getLoc();
+ Block *block = condBranch->getBlock();
+ // Splitting at `block->end()` moves no operations: it creates new, empty
+ // blocks placed after `block`.
+ auto newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+ condBranch.getTrueDestOperands());
+ insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+ condBranch.getFalseDestOperands());
+ // The successor operands are now carried by the new `cf.br`s; retarget the
+ // `cf.cond_br` to the (argument-less) intermediate blocks.
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ }
+}
+
+/// Inserts tile copies at `cf.br` operations: each SME tile operand of a
+/// `cf.br` is replaced with the result of a new CopyTileOp of that operand.
+/// `cf.cond_br`s with tile operands are first split (splitCondBranches) so
+/// that their tile operands are also passed through `cf.br`s.
+void insertCopiesAtBranches(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ splitCondBranches(rewriter, function);
+ for (Block &block : function.getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ // Only unconditional branches need copies (see splitCondBranches).
+ if (!isa<cf::BranchOp>(terminator))
continue;
- dependantOps.insert(user);
- TypeSwitch<Operation *>(user)
- .Case<cf::BranchOp>([&](auto branchOp) {
- // (CF) Follow branch.
- traverseCorrespondingValues(branchOp.getDestOperands(),
- branchOp.getDest()->getArguments());
- })
- .Case<cf::CondBranchOp>([&](auto condBranchOp) {
- // (CF) Follow true branch.
- traverseCorrespondingValues(
- condBranchOp.getTrueOperands(),
- condBranchOp.getTrueDest()->getArguments());
- // (CF) Follow false branch.
- traverseCorrespondingValues(
- condBranchOp.getFalseOperands(),
- condBranchOp.getFalseDest()->getArguments());
- })
- .Case<LoopLikeOpInterface>([&](auto loopOp) {
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
- traverseCorrespondingValues(loopOp.getInits(),
- loopOp.getRegionIterArgs());
- })
- .Case<scf::YieldOp>([&](auto yieldOp) {
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
- auto parent = user->getParentOp();
- traverseCorrespondingValues(user->getOperands(),
- parent->getResults());
+ rewriter.setInsertionPoint(terminator);
+ for (OpOperand &operand : terminator->getOpOperands()) {
+ if (isValidSMETileVectorType(operand.get().getType())) {
+ auto copy =
+ rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ operand.assign(copy);
+ }
+ }
+ }
+}
+
+/// A range where a tile value is live. The range may contain holes.
+/// Intervals are half-open ([start, end)) over the indices produced by
+/// generateOperationNumbering. After coalescing, one LiveRange may cover
+/// several SSA values (see `values`).
+struct LiveRange {
+ // NOTE(review): the key type is uint64_t while the traits (and start()/
+ // end()) use `unsigned`; consider a single key type.
+ using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+ llvm::IntervalMapHalfOpenInfo<unsigned>>;
+ using Allocator = RangeSet::Allocator;
+ // Placeholder value mapped to every interval; only interval membership is
+ // meaningful.
+ static constexpr uint8_t kValidLiveRange = 0xff;
+
+ LiveRange(Allocator &allocator)
+ : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+ /// Returns true if this range overlaps with `otherRange`.
+ bool overlaps(LiveRange const &otherRange) const {
+ return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+ *otherRange.ranges)
+ .valid();
+ }
+
+ /// Unions this live range (intervals and values) with `otherRange`.
+ /// Callers must check overlaps() first: every interval maps to
+ /// kValidLiveRange, so IntervalMap coalesces overlapping intervals silently
+ /// rather than reporting them.
+ void unionWith(LiveRange const &otherRange) {
+ for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+ ++it)
+ ranges->insert(it.start(), it.stop(), kValidLiveRange);
+ values.set_union(otherRange.values);
+ }
+
+ /// Inserts an interval [start, end) for `value` into this range. An empty
+ /// interval (start == end) only records `value`.
+ void insert(Value value, unsigned start, unsigned end) {
+ values.insert(value);
+ if (start != end)
+ ranges->insert(start, end, kValidLiveRange);
+ }
+
+ bool empty() const { return ranges->empty(); }
+ /// First index of the range (inclusive). Only meaningful if !empty().
+ unsigned start() const { return ranges->start(); }
+ /// Last index of the range (exclusive). Only meaningful if !empty().
+ unsigned end() const { return ranges->stop(); }
+ /// Orders live ranges by their start index.
+ bool operator<(LiveRange const &other) const {
+ return start() < other.start();
+ }
+
+ /// Returns the SME tile type of this range, taken from its first value
+ /// (expects `values` to be non-empty and of SME tile vector type).
+ ArmSMETileType getTileType() const {
+ return *getSMETileType(cast<VectorType>(values[0].getType()));
+ }
+
+ /// The live intervals (nodes come from the Allocator given on construction).
+ std::unique_ptr<RangeSet> ranges;
+ /// The SSA values sharing this live range.
+ SetVector<Value> values;
+ /// The assigned tile ID; IDs >= kInMemoryTileIdBase are in-memory tiles.
+ std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+/// Blocks are visited in topological order, and each block reserves one index
+/// just before its first operation so block arguments get their own number.
+/// NOTE(review): operations in blocks not returned by
+/// getTopologicallySortedBlocks (if any, e.g. unreachable blocks) get no
+/// number, so later `.at()` lookups on them would fail -- confirm.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+ unsigned index = 0;
+ SetVector<Block *> blocks =
+ getTopologicallySortedBlocks(function.getFunctionBody());
+ DenseMap<Operation *, unsigned> operationToIndexMap;
+ for (Block *block : blocks) {
+ index++; // We want block args to have their own number.
+ for (Operation &op : block->getOperations()) {
+ // This numbering is only correct if no ArmSME tile ops are nested in
+ // regions (i.e. SCF has been lowered to CF). Checked in debug builds only.
+#ifndef NDEBUG
+ op.walk([&](ArmSMETileOpInterface nestedOp) {
+ if (&op != nestedOp.getOperation()) {
+ assert(false &&
+ "ArmSME tile allocation does not support nested regions");
+ }
+ });
+#endif
+ operationToIndexMap.try_emplace(&op, index++);
+ }
+ }
+ return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+/// Returns one (not yet coalesced) LiveRange per SME tile value, made of the
+/// per-block intervals in which that value is live.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+ LiveRange::Allocator &liveRangeAllocator,
+ Liveness &liveness, FunctionOpInterface function) {
+ DenseMap<Value, LiveRange> liveRanges;
+ // Adds the interval in which `value` is live within one block: from
+ // `firstUseOrDef` up to the end operation reported by Liveness (exclusive,
+ // so a value last used by op X does not overlap a value defined by X).
+ auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
+ LivenessBlockInfo const &livenessInfo,
+ bool liveAtBlockEntry = false) {
+ if (!isValidSMETileVectorType(value.getType()))
+ return;
+ auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+ auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+ // Values live at block entry start at the index reserved for the block
+ // itself (one before its first operation).
+ unsigned start =
+ operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+ unsigned end = operationToIndexMap.at(lastUseInBlock);
+ it->second.insert(value, start, end);
+ };
+
+ for (Block &block : function.getBlocks()) {
+ LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
+ // Handle block arguments:
+ for (Value argument : block.getArguments())
+ updateLiveRanges(argument, &block.front(), *livenessInfo,
+ /*liveAtBlockEntry=*/true);
+ // Handle live-ins:
+ for (Value liveIn : livenessInfo->in())
+ updateLiveRanges(liveIn, &block.front(), *livenessInfo,
+ /*liveAtBlockEntry=*/true);
+ // Handle new definitions:
+ for (Operation &op : block) {
+ for (Value result : op.getResults())
+ updateLiveRanges(result, &op, *livenessInfo);
+ }
+ }
+
+ return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+/// For each predecessor terminated by a `cf.br` or `cf.cond_br`, `callback` is
+/// invoked with the value that terminator passes for `blockArg` (for a
+/// `cf.cond_br`, once per destination that is the owning block). Predecessors
+/// with any other terminator are ignored.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+ function_ref<void(Value)> callback) {
+ Block *block = blockArg.getOwner();
+ unsigned argNumber = blockArg.getArgNumber();
+ for (Block *pred : block->getPredecessors()) {
+ TypeSwitch<Operation *>(pred->getTerminator())
+ .Case<cf::BranchOp>([&](auto branch) {
+ Value predecessorOperand = branch.getDestOperands()[argNumber];
+ callback(predecessorOperand);
})
- .Default([&](auto) {
- // Otherwise, assume users of _any_ result are dependant.
- for (Value result : user->getResults())
- findDependantOps(result, dependantOps);
+ .Case<cf::CondBranchOp>([&](auto condBranch) {
+ if (condBranch.getFalseDest() == block) {
+ Value predecessorOperand =
+ condBranch.getFalseDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
+ if (condBranch.getTrueDest() == block) {
+ Value predecessorOperand =
+ condBranch.getTrueDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
});
}
}
-struct AssignTileIDsPattern
- : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
- PatternRewriter &rewriter) const override {
- if (tileOp.getTileId())
- return failure();
-
- auto func = tileOp->getParentOfType<FunctionOpInterface>();
- auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
- if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
- func->getDiscardableAttr(name)))
- return unsigned(attr.getInt());
- return defaultVal;
+
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+/// Live ranges are only merged if they do not overlap. Returns the remaining
+/// non-empty live ranges sorted by start index. The returned pointers point
+/// into `initialLiveRanges`, which must outlive them and not be modified.
+/// NOTE(review): rules are applied in DenseMap (pointer-hash) order, so the
+/// merge order -- and possibly the allocation -- may vary between runs.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+ // Maps each value to its current (possibly merged) live range.
+ DenseMap<Value, LiveRange *> liveRanges;
+ for (auto &[value, liveRange] : initialLiveRanges) {
+ liveRanges.insert({value, &liveRange});
+ }
+
+ // Merges b's live range into a's (and redirects all of b's values) if the
+ // two are distinct and do not overlap.
+ auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+ LiveRange *aLiveRange = liveRanges.at(a);
+ LiveRange *bLiveRange = liveRanges.at(b);
+ if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+ aLiveRange->unionWith(*bLiveRange);
+ for (Value value : bLiveRange->values)
+ liveRanges[value] = aLiveRange;
+ }
+ };
+
+ // Merge the live ranges of new definitions with their tile operands.
+ auto unifyDefinitionsWithOperands = [&](Value value) {
+ auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ if (!armSMEOp)
+ return;
+ for (auto operand : armSMEOp->getOperands()) {
+ if (isValidSMETileVectorType(operand.getType()))
+ mergeValuesIfNonOverlapping(value, operand);
+ }
+ };
+
+ // Merge the live ranges of block arguments with their predecessors.
+ auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+ auto blockArg = dyn_cast<BlockArgument>(value);
+ if (!blockArg)
+ return;
+ forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+ mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+ });
+ };
+
+ auto applyRule = [&](auto rule) {
+ llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+ };
+
+ // Unify as many live ranges as we can. This prevents unnecessary moves.
+ applyRule(unifyBlockArgumentsWithPredecessors);
+ applyRule(unifyDefinitionsWithOperands);
+
+ // Remove duplicate live range entries.
+ SetVector<LiveRange *> uniqueLiveRanges;
+ for (auto [_, liveRange] : liveRanges) {
+ if (!liveRange->empty())
+ uniqueLiveRanges.insert(liveRange);
+ }
+
+ // Sort the new live ranges by starting point (ready for tile allocation).
+ auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+ std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+ [](LiveRange *a, LiveRange *b) { return *a < *b; });
+ // Keep the std::move: takeVector()'s vector type may differ from the return
+ // type, in which case C++17 would copy rather than move implicitly.
+ return std::move(coalescedLiveRanges);
+}
+
+/// Greedily allocates tile IDs to live ranges (expected to be sorted by start
+/// index), spilling ranges to in-memory tile IDs using simple heuristics when
+/// no tile of the required type is free.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+ TileAllocator tileAllocator;
+ SetVector<LiveRange *> allocatedRanges;
+
+ auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
+ unsigned memoryTileId = tileAllocator.allocateInMemoryTileId();
+ // An allocated range may only give up its tile to `newRange` if its tile
+ // type is at least as large (has no more tiles) as `newRange`'s. Tiles
+ // nest (see the table at the top of this file), so releasing such a tile
+ // always frees a whole tile of `newRange`'s type. Handing over a tile ID of
+ // a different type directly would yield an invalid ID (e.g. ZA1 for an
+ // 8-bit tile) or corrupt the in-use mask.
+ size_t numNewRangeTiles = getMasks(newRange->getTileType()).size();
+ auto canGiveTileTo = [&](LiveRange *allocatedRange) {
+ return getMasks(allocatedRange->getTileType()).size() <= numNewRangeTiles;
+ };
+ // Moves `range` to memory and returns a real tile ID for `newRange`.
+ auto spillActiveRange = [&](LiveRange *range) {
+ tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
+ range->tileId = memoryTileId;
+ allocatedRanges.remove(range);
+ FailureOr<unsigned> tileId =
+ tileAllocator.allocateTileId(newRange->getTileType());
+ assert(succeeded(tileId) && "spilled tile should fit the new range");
+ return *tileId;
};
- auto setDiscardableIntAttr = [&](StringRef name, auto value) {
- rewriter.modifyOpInPlace(tileOp, [&] {
- func->setDiscardableAttr(name,
- rewriter.getI32IntegerAttr((unsigned)value));
- });
+
+ auto isTrivialSpill = [](LiveRange *allocatedRange) {
+ return allocatedRange->values.size() == 1 &&
+ isTriviallyCloneableTileOp(
+ allocatedRange->values[0]
+ .getDefiningOp<ArmSMETileOpInterface>());
};
- std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
- if (!tileType)
- return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
- TileMask tilesInUse =
- static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
- auto tileId = allocateTileId(*tileType, tilesInUse);
- bool tileIsInMemory = failed(tileId);
- if (tileIsInMemory) {
- // If we could not find a real tile ID, use an in-memory tile ID (ID >=
- // 16). A later pass will insert the necessary spills and reloads.
- tileId =
- getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
- tileOp->emitWarning(
- "failed to allocate SME virtual tile to operation, all tile "
- "operations will go through memory, expect degraded performance");
- }
+ // Heuristic: Spill trivially copyable operations (usually free).
+ if (isTrivialSpill(newRange))
+ return memoryTileId;
+ auto trivialSpill = llvm::find_if(allocatedRanges, [&](LiveRange *range) {
+ return canGiveTileTo(range) && isTrivialSpill(range);
+ });
+ if (trivialSpill != allocatedRanges.end())
+ return spillActiveRange(*trivialSpill);
+
+ // Heuristic: Spill the (compatible) live range that ends last.
+ LiveRange *lastActiveLiveRange = nullptr;
+ for (LiveRange *range : allocatedRanges) {
+ if (canGiveTileTo(range) &&
+ (!lastActiveLiveRange || range->end() > lastActiveLiveRange->end()))
+ lastActiveLiveRange = range;
+ }
+ if (lastActiveLiveRange && lastActiveLiveRange->end() >= newRange->end())
+ return spillActiveRange(lastActiveLiveRange);
+
+ return memoryTileId;
+ };
- // Set all operations dependent on `tileOp` to use the same tile ID.
- // This is a naive tile allocation scheme, but works for common cases. For
- // example, as this only allocates tile IDs to existing ops, it can't solve
- // cases like this (%tileA and %tileB come from different root operations):
- //
- // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
- // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
- // } else {
- // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
- // }
- //
- // This case would require allocating a new tile for the result of the
- // scf.if, and moving the contents of %tileA or %tileB to result tile (based
- // on the %some_cond).
- // Find all the ops that (transitively) depend on this tile.
- SetVector<Operation *> dependantOps;
- findDependantOps(tileOp->getResult(0), dependantOps);
- auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
- for (auto *op : dependantOps) {
- if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- auto currentTileId = dependantTileOp.getTileId();
- if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
- return dependantTileOp.emitOpError(
- "already assigned different SME virtual tile!");
+ for (LiveRange *newRange : liveRanges) {
+ // Release tiles from live ranges that have ended.
+ allocatedRanges.remove_if([&](LiveRange *allocatedRange) {
+ if (allocatedRange->end() <= newRange->start()) {
+ tileAllocator.releaseTileId(allocatedRange->getTileType(),
+ *allocatedRange->tileId);
+ return true;
}
- }
+ return false;
+ });
- // Rewrite IR.
- if (!tileIsInMemory)
- setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+ // Allocate a tile ID to `newRange`.
+ auto tileId = tileAllocator.allocateTileId(newRange->getTileType());
+ if (succeeded(tileId))
+ newRange->tileId = *tileId;
else
- setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
- rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
- for (auto *op : dependantOps) {
- if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- rewriter.modifyOpInPlace(
- dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
+ newRange->tileId = chooseSpillUsingHeuristics(newRange);
+
+ // Insert the live range into the allocated ranges. In-memory tile IDs do
+ // not occupy a real tile, so they are never tracked or released.
+ if (*newRange->tileId < kInMemoryTileIdBase)
+ allocatedRanges.insert(newRange);
+ }
+}
+
+/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
+/// For every value of every live range this:
+/// - sets the range's tile ID on the defining ArmSME tile op, and on ArmSME
+/// users that produce no tile result,
+/// - folds CopyTileOps whose source is already on the same tile,
+/// - clones a trivially cloneable operand producer when an op's tile operand
+/// is on a different tile than its result.
+/// Fails (emitting an error) if a tile move would be required, or if a block
+/// argument is not on the same tile as its predecessor values.
+LogicalResult assignTileIdsAndResolveTrivialConflicts(
+ IRRewriter &rewriter, FunctionOpInterface function,
+ ArrayRef<LiveRange *> allocatedLiveRanges) {
+ for (LiveRange const *liveRange : allocatedLiveRanges) {
+ auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
+ // A value is on this range's tile if its defining op already carries the
+ // tile ID, or if it belongs to this live range.
+ auto isAllocatedToSameTile = [&](Value value) {
+ if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ tileOp && tileOp.getTileId() == tileIdAttr)
+ return true;
+ return liveRange->values.contains(value);
+ };
+ for (Value value : liveRange->values) {
+ for (Operation *user : value.getUsers()) {
+ if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+ // Ensure ArmSME ops that don't produce a value still get a tile ID.
+ if (!hasTileResult(tileOp))
+ tileOp.setTileId(tileIdAttr);
+ }
+ }
+ auto copyOp = value.getDefiningOp<CopyTileOp>();
+ if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
+ // Fold redundant copies (the now-unused copy is left in place,
+ // presumably to be removed by eraseTriviallyDeadTileOps).
+ rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
+ } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
+ tileOp.setTileId(tileIdAttr);
+ // Rectify operand tile IDs with result tile IDs.
+ OpOperand *tileOperand = getTileOpOperand(tileOp);
+ if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
+ continue;
+ auto operandTileOp =
+ tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+ if (!isTriviallyCloneableTileOp(operandTileOp))
+ return tileOp.emitOpError("failed to rectify tile operand with tile "
+ "result (move required)");
+ // Cloning prevents a move/spill (though may require recomputation).
+ rewriter.setInsertionPoint(tileOp);
+ auto clonedOp = operandTileOp.clone();
+ clonedOp.setTileId(tileOp.getTileId());
+ rewriter.insert(clonedOp);
+ if (copyOp)
+ rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
+ else
+ tileOperand->assign(clonedOp->getResult(0));
+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ // Validate block arguments.
+ bool tileMismatch = false;
+ forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+ if (tileMismatch)
+ return;
+ if (!isAllocatedToSameTile(predecessorTile)) {
+ blockArg.getOwner()->getParentOp()->emitOpError(
+ "block argument not allocated to the same tile as "
+ "predecessors");
+ tileMismatch = true;
+ }
+ });
+ if (tileMismatch)
+ return failure();
}
}
+ }
+ return success();
+}
- return success();
+/// Prints live ranges alongside operation names for debugging (to stderr).
+/// Each row is an operation (blocks in function order); each column is one
+/// entry of `liveRanges`: 'S' where an interval starts, 'E' where one ends,
+/// '|' inside an interval (or where one interval ends and the next starts).
+void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+ ArrayRef<LiveRange const *> liveRanges,
+ FunctionOpInterface function) {
+ llvm::errs() << "SME Tile Liveness: @" << function.getName()
+ << "\nKey:\nS - Start\nE - End\n| - Live\n";
+ for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
+ llvm::errs() << "^bb" << blockIdx << ":\n";
+ for (Operation &op : block.getOperations()) {
+ unsigned operationIndex = operationToIndexMap.at(&op);
+ for (LiveRange const *range : liveRanges) {
+ char liveness = ' ';
+ for (auto it = range->ranges->begin(); it != range->ranges->end();
+ ++it) {
+ if (it.start() == operationIndex)
+ liveness = (liveness == 'E' ? '|' : 'S');
+ else if (it.stop() == operationIndex)
+ liveness = (liveness == 'S' ? '|' : 'E');
+ else if (operationIndex >= it.start() && operationIndex < it.stop())
+ liveness = '|';
+ }
+ llvm::errs() << liveness;
+ }
+ llvm::errs() << ' ' << op.getName() << '\n';
+ }
}
-};
+ llvm::errs() << "==========\n";
+}
-struct TileAllocationPass
- : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+/// Test-only pass (unit tests) that runs arm_sme::allocateSMETiles on its
+/// operation, optionally dumping tile live ranges (`dumpTileLiveRanges`,
+/// presumably a pass option), and fails the pass if allocation fails.
+struct TestTileAllocationPass
+ : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
+ using TestTileAllocationBase::TestTileAllocationBase;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- patterns.add<AssignTileIDsPattern>(patterns.getContext());
- GreedyRewriteConfig config;
- // Setting useTopDownTraversal ensures tiles are allocated in program
- // order.
- config.useTopDownTraversal = true;
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(patterns), config))) {
+ if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
signalPassFailure();
- }
}
};
} // namespace
-std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
- return std::make_unique<TileAllocationPass>();
+/// Allocates tile IDs to the ArmSME tile ops in `function` via a linear scan
+/// over tile live ranges. If `dumpRanges` is set, the initial and coalesced
+/// live ranges are printed to stderr. Returns failure if a tile conflict could
+/// not be resolved without a move (an error is emitted on the offending op).
+LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
+ bool dumpRanges) {
+ LiveRange::Allocator liveRangeAllocator;
+ IRRewriter rewriter(function.getContext());
+
+ // 1. Insert copy operations at branch operations.
+ insertCopiesAtBranches(rewriter, function);
+
+ // 2. Gather live ranges for each ArmSME tile within the function.
+ Liveness liveness(function);
+ auto operationToIndexMap = generateOperationNumbering(function);
+ auto initialLiveRanges = gatherTileLiveRanges(
+ operationToIndexMap, liveRangeAllocator, liveness, function);
+ if (initialLiveRanges.empty())
+ return success();
+
+ if (dumpRanges) {
+ // Wrangle initial live ranges into a form suitable for printing.
+ auto nonEmpty = llvm::make_filter_range(
+ llvm::make_second_range(initialLiveRanges),
+ [&](LiveRange const &liveRange) { return !liveRange.empty(); });
+ auto initialRanges = llvm::to_vector(llvm::map_range(
+ nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
+ std::sort(initialRanges.begin(), initialRanges.end(),
+ [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
+ llvm::errs() << "\n========== Initial Live Ranges:\n";
+ dumpLiveRanges(operationToIndexMap, initialRanges, function);
+ }
+
+ // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
+ // for tile allocation, e.g. unify the result of an operation with its
+ // operands. The resulting pointers point into `initialLiveRanges`.
+ auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
+
+ if (dumpRanges) {
+ llvm::errs() << "\n========== Coalesced Live Ranges:\n";
+ dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
+ }
+
+ // 4. Allocate tile IDs to live ranges.
+ allocateTilesToLiveRanges(coalescedLiveRanges);
+
+ // 5. Assign the tile IDs back to the ArmSME operations.
+ if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
+ coalescedLiveRanges))) {
+ return failure();
+ }
+
+ // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no users).
+ // This prevents the LLVM conversion from needlessly inserting spills.
+ eraseTriviallyDeadTileOps(rewriter, function);
+ return success();
}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 81087cc02099fb..2922940460219e 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
-
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file -verify-diagnostics | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
//===----------------------------------------------------------------------===//
@@ -25,6 +24,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -36,6 +36,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
return
}
@@ -47,6 +48,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -58,6 +60,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
return
}
@@ -69,6 +72,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+ "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
return
}
@@ -80,6 +84,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -91,6 +96,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
return
}
@@ -102,6 +108,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -113,6 +120,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
return
}
@@ -124,6 +132,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -135,6 +144,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
return
}
@@ -146,6 +156,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -157,6 +168,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
return
}
@@ -168,6 +180,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+ "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
return
}
@@ -179,6 +192,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -190,6 +204,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
return
}
@@ -201,6 +216,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -212,6 +228,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
return
}
@@ -441,7 +458,8 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : v
func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
- arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -452,7 +470,8 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
- arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
return
}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 7a9e6b42157549..803f15b4b7d05f 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \
// RUN: FileCheck %s --check-prefix=AFTER-TILE-ALLOC
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \
// RUN: -split-input-file -verify-diagnostics | \
// RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING
@@ -56,6 +56,9 @@ func.func @use_too_many_tiles() {
%1 = arm_sme.zero : vector<[4]x[4]xi32>
// expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%2 = arm_sme.zero : vector<[8]x[8]xi16>
+ "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> ()
return
}
// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
@@ -109,18 +112,16 @@ func.func @use_too_many_tiles() {
/// Note: In this example an entire tile swap is inserted before/after the
/// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
/// single tile slice (and can omit the initial load, like in the previous example).
-func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
- %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
%c0 = arith.constant 0 : index
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%mask = vector.constant_mask [4] : vector<[4]xi1>
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+ "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
return %loadSlice : vector<[4]x[4]xf32>
}
// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
-// AFTER-TILE-ALLOC: arm_sme.get_tile
-// AFTER-TILE-ALLOC-SAME: tile_id = 0
// AFTER-TILE-ALLOC: arm_sme.load_tile_slice
// AFTER-TILE-ALLOC-SAME: tile_id = 16
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 59665c471921d5..1d7c266c84f0a9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" \
+// RUN: -split-input-file -allow-unregistered-dialect -verify-diagnostics
//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 6c393bc38af9c7..a2f2beff78c409 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -20,6 +20,7 @@
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -30,6 +31,7 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -60,6 +62,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
%pad = arith.constant 0 : i32
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -94,6 +97,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
%c3 = arith.constant 3 : index
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -104,6 +108,7 @@ func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32
%pad = arith.constant 0 : i32
// expected-error at +1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -113,6 +118,7 @@ func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?x
%c0 = arith.constant 0 : index
// expected-error at +1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
similarity index 52%
rename from mlir/test/Dialect/ArmSME/tile-allocation.mlir
rename to mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index 9c368dd4fa23f8..9c42cbfe5357cb 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
// -----
+// Note: Tile IDs >= 16 are in-memory tile IDs (i.e. spills).
+
// CHECK-LABEL: mixed_tiles
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
func.func @mixed_tiles() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
// CHECK-NEXT: tile_id = 0
@@ -18,76 +19,71 @@ func.func @mixed_tiles() {
// CHECK-NEXT: tile_id = 7
%za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
// ZA15.Q is still free.
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_b
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_b() {
// CHECK-NEXT: tile_id = 0
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- return
-}
-
-// -----
-
-func.func @za_b__out_of_tiles() {
- %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
+ "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+ "test.some_use"(%next_tile) : (vector<[16]x[16]xi8>) -> ()
return
}
// -----
+// CHECK-LABEL: za_b_overlapping_za_q
func.func @za_b_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za0_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
func.func @za0_h() {
// CHECK-NEXT: tile_id = 0
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
return
}
// -----
// CHECK-LABEL: za_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h() {
// CHECK-NEXT: tile_id = 0
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// CHECK-NEXT: tile_id = 1
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- return
-}
-
-// -----
-
-// CHECK-LABEL: za_h__out_of_tiles
-func.func @za_h__out_of_tiles() {
- // CHECK-NEXT: tile_id = 0
- %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- // CHECK-NEXT: tile_id = 1
- %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%next_tile) : (vector<[8]x[8]xi16>) -> ()
return
}
// -----
// CHECK-LABEL: za_h_overlapping_za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h_overlapping_za_s() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
// CHECK-NEXT: tile_id = 0
@@ -98,13 +94,15 @@ func.func @za_h_overlapping_za_s() {
// ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
// CHECK-NEXT: tile_id = 3
%za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
// CHECK-LABEL: za_h_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h_overlapping_za_d() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
// CHECK-NEXT: tile_id = 0
@@ -121,40 +119,65 @@ func.func @za_h_overlapping_za_d() {
// ZA7.Q, ZA15.Q
// CHECK-NEXT: tile_id = 7
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
+// CHECK-LABEL: za_h_overlapping_za_q
func.func @za_h_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // CHECK-NEXT: tile_id = 1
+ %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
+ %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
+ %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
+ %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
+ %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
+ %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
+ %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za0_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
func.func @za0_s() {
// CHECK-NEXT: tile_id = 0
%za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
// CHECK-LABEL: za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_s() {
// CHECK-NEXT: tile_id = 0
%za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
@@ -164,25 +187,20 @@ func.func @za_s() {
%za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-NEXT: tile_id = 3
%za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- return
-}
-
-// -----
-
-func.func @za_s__out_of_tiles() {
- %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za2_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%next_tile) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
// CHECK-LABEL: za_s_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_s_overlapping_za_d() {
// ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
// CHECK-NEXT: tile_id = 0
@@ -199,44 +217,77 @@ func.func @za_s_overlapping_za_d() {
// ZA7.Q, ZA15.Q
// CHECK-NEXT: tile_id = 7
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za2_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
+// CHECK-LABEL: za_s_overlapping_za_q
func.func @za_s_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ // CHECK-NEXT: tile_id = 1
%za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 2
%za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
%za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
%za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 6
%za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
%za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
%za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 10
%za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
%za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 14
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za0_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
func.func @za0_d() {
// CHECK-NEXT: tile_id = 0
%za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
// CHECK-LABEL: za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_d() {
// CHECK-NEXT: tile_id = 0
%za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
@@ -254,62 +305,90 @@ func.func @za_d() {
%za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// CHECK-NEXT: tile_id = 7
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- return
-}
-
-// -----
-
-func.func @za_d__out_of_tiles() {
- %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za2_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za4_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za6_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%next_tile) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
+// CHECK-LABEL: za_d_overlapping_za_q
func.func @za_d_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 1
%za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 2
%za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
%za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 4
%za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
%za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 6
%za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
%za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
%za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 10
%za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
%za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 12
%za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 14
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za4_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za12_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za0_q
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
func.func @za0_q() {
// CHECK-NEXT: tile_id = 0
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_q) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_q
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_q() {
// CHECK-NEXT: tile_id = 0
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
@@ -343,29 +422,25 @@ func.func @za_q() {
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
// CHECK-NEXT: tile_id = 15
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- return
-}
-
-// -----
-
-func.func @za_q__out_of_tiles() {
- %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za4_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za8_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za12_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index b7ba3f728c705a..d9e3d66e370ef3 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,18 +1,16 @@
// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
-// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
-// once it becomes unused, after lowering to control flow.
+// This tests that dead tile values are removed from control flow.
// -----
-// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
-// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
// CHECK-NOT: vector<[4]x[4]xf32>
-func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
^bb1(%1: index, %2: vector<[4]x[4]xf32>): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %1, %c10 : index
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
deleted file mode 100644
index 74e7293eaeca5f..00000000000000
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
-// duplicates.
-
-// CHECK-LABEL: @zero_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @zero_tile() {
- %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
- %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
- "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
- return
-}
-
-// CHECK-LABEL: @get_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @get_tile() {
- %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
- %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
- "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
- "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
- return
-}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index a20203d7e55796..d3325513a84829 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,10 +1,9 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=IN-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=OUT-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=INOUT-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=PRESERVES-ZA
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za | FileCheck %s -check-prefix=IN-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za | FileCheck %s -check-prefix=OUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za | FileCheck %s -check-prefix=INOUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za | FileCheck %s -check-prefix=PRESERVES-ZA
// CHECK-LABEL: @declaration
func.func private @declaration()
@@ -22,11 +21,4 @@ func.func private @declaration()
// DISABLE-ZA-LABEL: @arm_new_za
// DISABLE-ZA-NOT: arm_new_za
// DISABLE-ZA-SAME: attributes {arm_streaming}
-// NO-ARM-STREAMING-LABEL: @arm_new_za
-// NO-ARM-STREAMING-NOT: arm_new_za
-// NO-ARM-STREAMING-NOT: arm_streaming
-// NO-ARM-STREAMING-NOT: arm_in_za
-// NO-ARM-STREAMING-NOT: arm_out_za
-// NO-ARM-STREAMING-NOT: arm_inout_za
-// NO-ARM-STREAMING-NOT: arm_preserves_za
func.func @arm_new_za() { return }
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 01f54a4cf18661..4887d611643fba 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file | FileCheck %s
// CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
@@ -929,6 +929,7 @@ func.func @outerproduct_widening_4way__missing_acc(
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
// Missing accumulator breaks use-def chain.
%3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
+ "test.some_use"(%2) : (vector<[4]x[4]xi32>) -> ()
return %3 : vector<[4]x[4]xi32>
}
@@ -1014,7 +1015,7 @@ func.func @outerproduct_widening_2way__cant_erase(
%acc = arith.constant dense<1.0> : vector<[4]x[4]xf32>
%0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
- "fake.use"(%0) : (vector<[4]x[4]xf32>) -> ()
+ "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
return %1 : vector<[4]x[4]xf32>
@@ -1048,7 +1049,7 @@ func.func @outerproduct_widening_4way__multi_use_cant_erase(
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
- "fake.use"(%1) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
%3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 39d9ab6491e3b4..06be7bd974707b 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,19 +1,17 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics
// -----
-func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %cond: i1) {
+// Select between tileA and tileB. This is currently unsupported as it would
+// require inserting (runtime) tile moves.
+func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %tileA : vector<[4]x[4]xi32>, %tileB : vector<[4]x[4]xi32>, %cond: i1) {
%c0 = arith.constant 0 : index
- %tileA = arm_sme.get_tile : vector<[4]x[4]xi32>
- %tileB = arm_sme.get_tile : vector<[4]x[4]xi32>
- // Select between tileA and tileB. This is currently unsupported as it would
- // require inserting tile move operations during tile allocation.
+ // expected-error at +1 {{op failed to rectify tile operand with tile result (move required)}}
%tile = scf.if %cond -> vector<[4]x[4]xi32> {
scf.yield %tileA : vector<[4]x[4]xi32>
} else {
scf.yield %tileB : vector<[4]x[4]xi32>
}
- // expected-error at +1 {{op already assigned different SME virtual tile!}}
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 2dedcb2fbc24e4..dd2ee55d2afc8c 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -1,18 +1,24 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK-BAD
-
-// This file tests some aspects of liveness issues in the SME tile allocator.
-// These tests were designed with a new liveness-based tile allocator in mind
-// (where the names of test cases make more sense), with the current tile
-// allocator these tests all give incorrect results (which is documented by
-// `CHECK-BAD`).
-
-// Incorrect result! The second `move_vector_to_tile_slice` overwrites the first (which is still live).
-//
-// CHECK-BAD-LABEL: @constant_with_multiple_users
-// CHECK-BAD: %[[ZERO_TILE:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation=dump-tile-live-ranges -mlir-disable-threading -split-input-file -verify-diagnostics 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-LIVE-RANGE
+
+// This file tests some simple aspects of using liveness in the SME tile allocator.
+
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE-NEXT: @constant_with_multiple_users
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @constant_with_multiple_users(
+// CHECK-SAME: %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
+ // CHECK-NEXT: %[[ZERO_TILE_0:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[ZERO_TILE_1:.*]] = arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_A]], %[[ZERO_TILE_1]], %{{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_B]], %[[ZERO_TILE_0]], %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
%zero = arm_sme.zero : vector<[4]x[4]xf32>
%tile_a = arm_sme.move_vector_to_tile_slice %a, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_b = arm_sme.move_vector_to_tile_slice %b, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
@@ -23,12 +29,16 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
// -----
-// (No tile IDs -- the current tile allocator ignores this case)
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE-NEXT: @value_with_multiple_users
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
-// CHECK-BAD-LABEL: @value_with_multiple_users
-// CHECK-BAD-NOT: tile_id
func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
- // A future allocator should error here (as `%tile` would need to be copied).
+ // expected-error at below {{op failed to rectify tile operand with tile result (move required)}}
%tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
"test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
@@ -38,12 +48,38 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
// -----
-// CHECK-BAD-LABEL: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE-NEXT: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @reuse_tiles_after_initial_use
func.func @reuse_tiles_after_initial_use() {
- // CHECK-BAD: arm_sme.get_tile {tile_id = 0 : i32}
- // CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
- // CHECK-BAD: arm_sme.get_tile {tile_id = 2 : i32}
- // CHECK-BAD: arm_sme.get_tile {tile_id = 3 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 2 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 3 : i32}
%tile_a = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_b = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_c = arm_sme.get_tile : vector<[4]x[4]xf32>
@@ -55,19 +91,13 @@ func.func @reuse_tiles_after_initial_use() {
"test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_d) : (vector<[4]x[4]xf32>) -> ()
- // -> Spills after the fourth tile (unnecessary):
- // CHECK-BAD: arm_sme.zero {tile_id = 16 : i32}
- // CHECK-BAD: arm_sme.zero {tile_id = 17 : i32}
- // CHECK-BAD: arm_sme.zero {tile_id = 18 : i32}
- // CHECK-BAD: arm_sme.zero {tile_id = 19 : i32}
- // Unnecessary spills:
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // CHECK: arm_sme.zero {tile_id = 0 : i32}
+ // CHECK: arm_sme.zero {tile_id = 1 : i32}
+ // CHECK: arm_sme.zero {tile_id = 2 : i32}
+ // CHECK: arm_sme.zero {tile_id = 3 : i32}
%tile_1 = arm_sme.zero : vector<[4]x[4]xf32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_3 = arm_sme.zero : vector<[4]x[4]xf32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_4 = arm_sme.zero : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
@@ -81,16 +111,27 @@ func.func @reuse_tiles_after_initial_use() {
// -----
-// Incorrect result! Both branches should yield the result via the same tile.
-//
-// CHECK-BAD-LABEL: @non_overlapping_branches
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE-NEXT: @non_overlapping_branches
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+
+// CHECK-LABEL: @non_overlapping_branches
func.func @non_overlapping_branches(%cond: i1) {
+ // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
%tile = scf.if %cond -> vector<[4]x[4]xf32> {
+ // ^bb1:
%zero = arm_sme.zero : vector<[4]x[4]xf32>
scf.yield %zero : vector<[4]x[4]xf32>
} else {
+ // ^bb2:
%undef = arm_sme.get_tile : vector<[4]x[4]xf32>
scf.yield %undef : vector<[4]x[4]xf32>
}
@@ -100,13 +141,15 @@ func.func @non_overlapping_branches(%cond: i1) {
// -----
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @constant_loop_init_with_multiple_users
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// <deliberately omitted>
+
+// CHECK-LABEL: @constant_loop_init_with_multiple_users
func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
+ // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -126,26 +169,46 @@ func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vecto
// -----
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @run_out_of_tiles_but_avoid_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE-NEXT: @run_out_of_tiles_but_avoid_spill
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+
+// Note in the live ranges (above) there are five tile values, but we only have four tiles.
+
+// CHECK-LABEL: @run_out_of_tiles_but_avoid_spill
func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
%init = arm_sme.zero : vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
+ // Live = %init
scf.for %i = %c0 to %c10 step %c1 {
+ // CHECK: arm_sme.zero {tile_id = 1 : i32}
+ // CHECK: arm_sme.zero {tile_id = 2 : i32}
+ // CHECK: arm_sme.zero {tile_id = 3 : i32}
+ // CHECK: arm_sme.zero {tile_id = 0 : i32}
%tile_a, %tile_b, %tile_c, %tile_d = scf.for %j = %c0 to %c10 step %c1
iter_args(%iter_a = %init, %iter_b = %init, %iter_c = %init, %iter_d = %init)
-> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32> , vector<[4]x[4]xf32> , vector<[4]x[4]xf32>) {
+ // ^bb2:
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 2 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 3 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_a = arm_sme.move_vector_to_tile_slice %a, %iter_a, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_b = arm_sme.move_vector_to_tile_slice %b, %iter_b, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_c = arm_sme.move_vector_to_tile_slice %c, %iter_c, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_d = arm_sme.move_vector_to_tile_slice %d, %iter_d, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
scf.yield %new_a, %new_b, %new_c, %new_d : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
}
+ // Live = %init, %tile_a, %tile_b, %tile_c, %tile_d (out of tiles!)
+ // This should be resolved by duplicating the arm_sme.zero (from folding
+ // arm_sme.copy_tile operations inserted by the tile allocator).
"test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
@@ -156,24 +219,47 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
// -----
-// Incorrect result! Everything other than zero assigned to tile 1 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @avoidable_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32}
+// We should be able to avoid spills like this, but the logic to handle this case
+// is not implemented yet. Note: a tile ID >= 16 means a spill/in-memory tile.
+
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE-NEXT: @avoidable_spill
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: || test.some_use
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E test.some_use
+// CHECK-LIVE-RANGE-NEXT: || arith.addi
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
+
+// Note in the live ranges (above) there are two constant live-ins (first two ranges),
+// which gives six overlapping live ranges. The allocator currently will spill the
+// first constant (which results in a real spill at its use); however, this could
+// be avoided by using the knowledge that at the first "test.some_use" there are
+// actually only two live ranges (so we can fix this by duplicating the constant).
+
+// CHECK-LABEL: @avoidable_spill
func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
+ // CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[4]x[4]xf32>
%zero = arm_sme.zero : vector<[4]x[4]xf32>
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
scf.for %i = %c0 to %c10 step %c1 {
+ // So %zero is spilled here (unnecessarily).
+ // The arm_sme.zero op could be moved into the loop to avoid this.
"test.some_use"(%zero) : (vector<[4]x[4]xf32>) -> ()
%tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_c = arm_sme.move_vector_to_tile_slice %c, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_d = arm_sme.move_vector_to_tile_slice %d, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // %zero is still live here (due to the backedge)
"test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 04412e4db1c5f3..ca339be5fb56f1 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,canonicalize))" | FileCheck %s
// This test verifies the tile mask operand of the zero intrinsic zeroes
// the correct tiles. Both integer and floating-point datatypes are checked.
@@ -9,6 +9,7 @@
func.func @zero_za_b() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
%zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
+ "test.some_use"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -16,10 +17,12 @@ func.func @zero_za_b() {
// CHECK-LABEL: zero_za_h
func.func @zero_za_h() {
- // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
+ "test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -27,14 +30,18 @@ func.func @zero_za_h() {
// CHECK-LABEL: zero_za_s
func.func @zero_za_s() {
- // CHECK: arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
- // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
- // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
- // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -42,21 +49,29 @@ func.func @zero_za_s() {
// CHECK-LABEL: zero_za_d
func.func @zero_za_d() {
- // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
%zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
%zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
%zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
%zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
%zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
%zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
+ "test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index 588b44a36c29f3..14d9712e971a86 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -13,26 +13,29 @@
/// performance (hence the warning).
func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b: memref<?x?xi16>, %c: memref<?x?xi16>) {
%c0 = arith.constant 0 : index
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
// expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
// CHECK-LABEL: tile_a:
// CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
vector.print str "tile_a:\n"
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
vector.print %tile_a : vector<[8]x[8]xi16>
// CHECK-LABEL: tile_b:
// CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
vector.print str "tile_b:\n"
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
vector.print %tile_b : vector<[8]x[8]xi16>
// CHECK-LABEL: tile_c:
// CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
vector.print str "tile_c:\n"
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
vector.print %tile_c : vector<[8]x[8]xi16>
// CHECK-LABEL: tile_d:
// CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
index 1794564a6a7244..0648e771b88913 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
@@ -1,5 +1,6 @@
// DEFINE: %{entry_point} = main
-// DEFINE: %{compile} = mlir-opt %s -convert-arm-sme-to-llvm -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm),test-lower-to-llvm)"
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
index e942c7b8ac058c..cdd8afe1414219 100644
--- a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_library(MLIRArmSMETestPasses
MLIRTransforms
MLIRVectorToArmSME
MLIRVectorToSCF
+ MLIRSCFToControlFlow
)
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index 48d4a5859f8a08..d3dabaf200fdc2 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -14,10 +14,12 @@
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -34,6 +36,10 @@ struct TestLowerToArmSMEOptions
llvm::cl::desc("Fuse outer product operations via "
"'-arm-sme-outer-product-fusion' pass"),
llvm::cl::init(true)};
+ PassOptions::Option<bool> dumpTileLiveRanges{
+ *this, "dump-tile-live-ranges",
+ llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
+ llvm::cl::init(false)};
};
void buildTestLowerToArmSME(OpPassManager &pm,
@@ -65,20 +71,17 @@ void buildTestLowerToArmSME(OpPassManager &pm,
pm.addPass(createConvertVectorToSCFPass(
VectorTransferToSCFOptions().enableFullUnroll()));
- // Allocate tiles for ArmSME operations.
- //
- // Later passes may create further ArmSME ops that implement the
- // ArmSMETileOpInterface, but tiles are allocated for root operations,
- // all of which should now exist.
- pm.addPass(arm_sme::createTileAllocationPass());
-
// Enable streaming-mode and ZA.
pm.addPass(arm_sme::createEnableArmStreamingPass(
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
/*onlyIfRequiredByOps=*/true));
+ // Convert SCF to CF (required for ArmSME tile allocation).
+ pm.addPass(createConvertSCFToCFPass());
+
// Convert ArmSME to LLVM.
- pm.addPass(createConvertArmSMEToLLVMPass());
+ pm.addNestedPass<func::FuncOp>(
+ createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
// Sprinkle some cleanups.
pm.addPass(createCanonicalizerPass());
>From b9a96c5075e269781b7cf9a820a655a644a67646 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 11:43:11 +0000
Subject: [PATCH 2/7] Review fixups
---
.../mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h | 4 ----
mlir/lib/Dialect/ArmSME/IR/Utils.cpp | 4 ++--
mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 10 ++++------
mlir/test/Dialect/ArmSME/roundtrip.mlir | 9 +++++++++
4 files changed, 15 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
index f31062d8c25ed7..0c26fa69c85874 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -5,10 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-//
-// This file declares the Target dialect for ArmSME in MLIR.
-//
-//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 00d764bf5caff9..6c0ebddb5a2dc3 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -153,10 +153,10 @@ OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
auto isTileOperandType = [](OpOperand &operand) {
return arm_sme::isValidSMETileVectorType(operand.get().getType());
};
- OpOperand *tileOperand =
- llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
"expected at most one tile operand");
+ OpOperand *tileOperand =
+ llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
if (tileOperand == tileOp->getOpOperands().end())
return nullptr;
return tileOperand;
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index e3cf52078a24eb..7b3381652e28aa 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -131,7 +131,7 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
class TileAllocator {
public:
- /// Allocates and returns a tile ID.
+ /// Allocates and returns a tile ID. Fails if there are no tiles left.
FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
auto masks = getMasks(tileType);
for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
@@ -198,7 +198,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
}
}
-/// Inserts tile copies `cf.br` operations.
+/// Inserts tile copies at `cf.br` operations.
void insertCopiesAtBranches(IRRewriter &rewriter,
FunctionOpInterface function) {
splitCondBranches(rewriter, function);
@@ -278,10 +278,8 @@ generateOperationNumbering(FunctionOpInterface function) {
// This is only correct if all ArmSME have been converted to CF.
#ifndef NDEBUG
op.walk([&](ArmSMETileOpInterface nestedOp) {
- if (&op != nestedOp.getOperation()) {
- assert(false &&
- "ArmSME tile allocation does not support nested regions");
- }
+ assert(&op == nestedOp.getOperation() &&
+ "ArmSME tile allocation does not support nested regions");
});
#endif
operationToIndexMap.try_emplace(&op, index++);
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ab46c7adca5966..6095fdc11ead8f 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1403,3 +1403,12 @@ func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
%reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
return %reuslt : vector<[2]x[2]xi64>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.copy_tile
+//===----------------------------------------------------------------------===//
+
+func.func @arm_sme_copy_tile(%vec: vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+ %result = arm_sme.copy_tile %vec : vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
>From 6f5105c90b5572181491498dd142aa43af1ef0ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 14:23:57 +0000
Subject: [PATCH 3/7] More review fixups
---
.../ArmSME/Transforms/TileAllocation.cpp | 6 +--
.../Conversion/ArmSMEToLLVM/unsupported.mlir | 5 +--
.../Dialect/ArmSME/basic-tile-allocation.mlir | 40 -------------------
mlir/test/Dialect/ArmSME/canonicalize.mlir | 4 +-
4 files changed, 6 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 7b3381652e28aa..bcc6deeabd16ad 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -414,7 +414,7 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
return std::move(coalescedLiveRanges);
}
-/// Greedily allocate tile IDs to live ranges spill using simple heuristics.
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
/// Note: This does not attempt to fill holes in live/allocated ranges.
void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
TileAllocator tileAllocator;
@@ -629,8 +629,8 @@ LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
return failure();
}
- /// 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
- /// users). This prevents the LLVM conversion needlessly inserting spills.
+ // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+ // users). This prevents the LLVM conversion needlessly inserting spills.
eraseTriviallyDeadTileOps(rewriter, function);
return success();
}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 1d7c266c84f0a9..a62ca080ab8d98 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" \
-// RUN: -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
@@ -10,6 +9,6 @@ func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs :
// expected-error at +2 {{failed to legalize operation 'arm_sme.outerproduct'}}
// expected-error at +1 {{unsupported type}}
%0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
- "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+ "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
}
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index 9c42cbfe5357cb..eb64a7b6aac58e 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -56,16 +56,6 @@ func.func @za_b_overlapping_za_q() {
// -----
-// CHECK-LABEL: za0_h
-func.func @za0_h() {
- // CHECK-NEXT: tile_id = 0
- %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
- return
-}
-
-// -----
-
// CHECK-LABEL: za_h
func.func @za_h() {
// CHECK-NEXT: tile_id = 0
@@ -167,16 +157,6 @@ func.func @za_h_overlapping_za_q() {
// -----
-// CHECK-LABEL: za0_s
-func.func @za0_s() {
- // CHECK-NEXT: tile_id = 0
- %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
- return
-}
-
-// -----
-
// CHECK-LABEL: za_s
func.func @za_s() {
// CHECK-NEXT: tile_id = 0
@@ -277,16 +257,6 @@ func.func @za_s_overlapping_za_q() {
// -----
-// CHECK-LABEL: za0_d
-func.func @za0_d() {
- // CHECK-NEXT: tile_id = 0
- %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
- return
-}
-
-// -----
-
// CHECK-LABEL: za_d
func.func @za_d() {
// CHECK-NEXT: tile_id = 0
@@ -378,16 +348,6 @@ func.func @za_d_overlapping_za_q() {
// -----
-// CHECK-LABEL: za0_q
-func.func @za0_q() {
- // CHECK-NEXT: tile_id = 0
- %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- "test.some_use"(%za0_q) : (vector<[1]x[1]xi128>) -> ()
- return
-}
-
-// -----
-
// CHECK-LABEL: za_q
func.func @za_q() {
// CHECK-NEXT: tile_id = 0
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index d9e3d66e370ef3..643dfd4a7cbd98 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,9 +1,7 @@
-// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
// This tests that dead tile values are removed from control flow.
-// -----
-
// CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
// CHECK-NOT: vector<[4]x[4]xf32>
func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
>From 54a8e7d05c882ec5da873d2b5bc125c39f03d05b Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 30 Apr 2024 17:04:15 +0000
Subject: [PATCH 4/7] More fixups
---
.../Dialect/ArmSME/IR/ArmSMEOpInterfaces.h | 3 +
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 6 +-
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 52 +++---
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 2 +-
.../Dialect/ArmSME/basic-tile-allocation.mlir | 2 +-
.../ArmSME/tile-allocation-invalid.mlir | 4 +-
.../ArmSME/tile-allocation-liveness.mlir | 160 ++++++++++--------
7 files changed, 123 insertions(+), 106 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
index 0c26fa69c85874..9153fbb57ea88e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -17,7 +17,10 @@ namespace detail {
LogicalResult verifyArmSMETileOpInterface(Operation *);
}
+// The first in-memory SME tile ID. This is set to 16 as that is the first tile
+// ID larger than any virtual tile ID supported by the SME ISA.
static constexpr unsigned kInMemoryTileIdBase = 16;
+
#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
} // namespace mlir::arm_sme
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index b9d74fec6756e3..8c129ea623b6f8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -126,20 +126,20 @@ def EnableArmStreaming
def TestTileAllocation
: Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
- let summary = "Tests SME tile allocation";
+ let summary = "Tests SME 'virtual tile' allocation";
let description = [{
This pass does tile allocation for SME "virtual tiles". It is run at the
'func.func' op level, and assigns tile IDs (via an attribute) to all ops
that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
to be used for testing, tile allocation is done as part of the ArmSME to
- LLVM conversion.
+ LLVM conversion (`convert-arm-sme-to-llvm`).
}];
let options = [
Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
"bool", /*default=*/"false",
"Dump the live ranges of SME tiles (for debugging)">
];
- let dependentDialects = ["func::FuncDialect"];
+ let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}
def OuterProductFusion
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 488f747af050f5..562c58b73e7e34 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -885,34 +885,38 @@ struct ConvertArmSMEToLLVMPass
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
- arm_sme::GetTileOp, arm_sme::CopyTileOp, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
- arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
- arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
- arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
- arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
- arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
- arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
- arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
- arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
- arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
+ arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
+ arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
+ arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
+ arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
+ arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+ arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
+ arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
+ arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
+ arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
+ arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
+ arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
+ arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
+ arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
+ arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
+ arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
+ arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
+ arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
+ arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
+ arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
+ arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
+ arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
+ arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
+ arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect,
/* The following are used to lower tile spills/fills */
vector::VectorDialect, scf::SCFDialect,
memref::MemRefDialect>();
- target.addLegalOp<UnrealizedConversionCastOp>();
+ // Pseudo operations. These cannot be code-generated but may exist in the
+ // input IR, or be generated during the conversion. They need to be eliminated
+ // before the final conversion to LLVM IR (and likely will be due to DCE).
+ target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
+ UnrealizedConversionCastOp>();
}
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 2922940460219e..14b1f323da3a28 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index eb64a7b6aac58e..8b46998d56b044 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 06be7bd974707b..b3112264cba9cb 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,6 +1,4 @@
-// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics
-
-// -----
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
// Select between tileA and tileB. This is currently unsupported as it would
// require inserting (runtime) tile moves.
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index dd2ee55d2afc8c..a63312ded1b93f 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -3,14 +3,14 @@
// This file tests some simple aspects of using liveness in the SME tile allocator.
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// CHECK-LIVE-RANGE-NEXT: @constant_with_multiple_users
-// CHECK-LIVE-RANGE: ^bb0:
-// CHECK-LIVE-RANGE: S arm_sme.zero
-// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: |E test.some_use
-// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-LABEL: @constant_with_multiple_users
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LABEL: @constant_with_multiple_users(
// CHECK-SAME: %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
@@ -29,13 +29,13 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
// -----
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// CHECK-LIVE-RANGE-NEXT: @value_with_multiple_users
-// CHECK-LIVE-RANGE: ^bb0:
-// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: |E test.some_use
-// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-LABEL: @value_with_multiple_users
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
// expected-error at below {{op failed to rectify tile operand with tile result (move required)}}
@@ -48,31 +48,31 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
// -----
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// CHECK-LIVE-RANGE-NEXT: @reuse_tiles_after_initial_use
-// CHECK-LIVE-RANGE: ^bb0:
-// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
-// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
-// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.get_tile
-// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.get_tile
-// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
-// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
-// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
-// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
-// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
-// CHECK-LIVE-RANGE-NEXT: E| test.some_use
-// CHECK-LIVE-RANGE-NEXT: E test.some_use
-// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
-// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
-// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.zero
-// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.zero
-// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
-// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
-// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
-// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
-// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
-// CHECK-LIVE-RANGE-NEXT: E| test.some_use
-// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-LABEL: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
// CHECK-LABEL: @reuse_tiles_after_initial_use
func.func @reuse_tiles_after_initial_use() {
@@ -111,16 +111,16 @@ func.func @reuse_tiles_after_initial_use() {
// -----
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// CHECK-LIVE-RANGE-NEXT: @non_overlapping_branches
-// CHECK-LIVE-RANGE: ^bb1:
-// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
-// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
-// CHECK-LIVE-RANGE-NEXT: E cf.br
-// CHECK-LIVE-RANGE-NEXT: ^bb2:
-// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
-// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
-// CHECK-LIVE-RANGE-NEXT: E cf.br
+// CHECK-LIVE-RANGE-LABEL: @non_overlapping_branches
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E cf.br
// CHECK-LABEL: @non_overlapping_branches
func.func @non_overlapping_branches(%cond: i1) {
@@ -141,8 +141,20 @@ func.func @non_overlapping_branches(%cond: i1) {
// -----
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// <deliberately omitted>
+// Here %vecA and %vecB are not merged into the same live range (as they are unknown values).
+// This means that %vecA and %vecB are both allocated to different tiles (which is not legal).
+func.func @overlapping_branches(%cond: i1, %vecA: vector<[4]x[4]xf32>, %vecB: vector<[4]x[4]xf32>) {
+ // expected-error at below {{op failed to rectify tile operand with tile result (move required)}}
+ %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+ scf.yield %vecA : vector<[4]x[4]xf32>
+ } else {
+ scf.yield %vecB : vector<[4]x[4]xf32>
+ }
+ "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
// CHECK-LABEL: @constant_loop_init_with_multiple_users
func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
@@ -169,14 +181,14 @@ func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vecto
// -----
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// CHECK-LIVE-RANGE-NEXT: @run_out_of_tiles_but_avoid_spill
-// CHECK-LIVE-RANGE: ^bb2:
-// CHECK-LIVE-RANGE-NEXT: |S arm_sme.copy_tile
-// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.copy_tile
-// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.copy_tile
-// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
-// CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+// CHECK-LIVE-RANGE-LABEL: @run_out_of_tiles_but_avoid_spill
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
// Note in the live ranges (above) there is five tile values, but we only have four tiles.
@@ -222,20 +234,20 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
// We should be able to avoid spills like this, but logic handling this case is
// not implemented yet. Note tile ID >= 16 means a spill/in-memory tile.
-// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
-// CHECK-LIVE-RANGE-NEXT: @avoidable_spill
-// CHECK-LIVE-RANGE: ^bb2:
-// CHECK-LIVE-RANGE-NEXT: || test.some_use
-// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
-// CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
-// CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
-// CHECK-LIVE-RANGE-NEXT: || E| test.some_use
-// CHECK-LIVE-RANGE-NEXT: || E test.some_use
-// CHECK-LIVE-RANGE-NEXT: || arith.addi
-// CHECK-LIVE-RANGE-NEXT: EE cf.br
+// CHECK-LIVE-RANGE-LABEL: @avoidable_spill
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: || test.some_use
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E test.some_use
+// CHECK-LIVE-RANGE-NEXT: || arith.addi
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
// Note in the live ranges (above) there is two constant live-ins (first two ranges),
// which gives six overlapping live ranges. The allocator currently will spill the
>From f4eb9ebc8a5d3784a472ca671cfee90315511f5d Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 2 May 2024 15:44:52 +0000
Subject: [PATCH 5/7] Add test for tile copies
---
.../mlir/Dialect/ArmSME/Transforms/Passes.td | 5 +-
.../ArmSME/Transforms/TileAllocation.cpp | 7 +-
.../ArmSME/tile-allocation-copies.mlir | 82 +++++++++++++++++++
3 files changed, 92 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 8c129ea623b6f8..81730166d3c01a 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -137,7 +137,10 @@ def TestTileAllocation
let options = [
Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
"bool", /*default=*/"false",
- "Dump the live ranges of SME tiles (for debugging)">
+ "Dump the live ranges of SME tiles (for debugging)">,
+ Option<"tileCopiesOnly", "tile-copies-only", "bool", /*default=*/"false",
+ "Only insert tile copies needed for tile allocation "
+ "(but do not allocate any tiles)">
];
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index bcc6deeabd16ad..ce8f341524751b 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -575,7 +575,12 @@ struct TestTileAllocationPass
: public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
using TestTileAllocationBase::TestTileAllocationBase;
void runOnOperation() override {
- if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
+ FunctionOpInterface function = getOperation();
+ if (tileCopiesOnly) {
+ IRRewriter rewriter(function);
+ return insertCopiesAtBranches(rewriter, function);
+ }
+ if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
new file mode 100644
index 00000000000000..5e43efe93b5566
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation=tile-copies-only -split-input-file | FileCheck %s
+
+// This file tests the insertion of tile copies for SME tile allocation. Copies
+// are inserted at `cf.br` ops (the branches that forward values to block
+// arguments). Conditional branches are split to prevent conflicts (see
+// cond_branch_with_backedge).
+
+// A tile passed as a `cf.br` operand is replaced by an `arm_sme.copy_tile` of
+// that tile, inserted directly before the branch.
+
+// CHECK-LABEL: func.func @simple_branch(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>)
+// CHECK: %[[COPY:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%[[COPY]] : vector<[4]x[4]xf32>)
+// CHECK: ^bb1(%[[BLOCK_ARG:.*]]: vector<[4]x[4]xf32>):
+
+func.func @simple_branch(%tile : vector<[4]x[4]xf32>) {
+ cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+ return
+}
+
+// -----
+
+// A tile forwarded to one successor of a `cf.cond_br` is copied in a new block
+// inserted on that edge, rather than directly before the conditional branch.
+//
+// Note: The ^POINTLESS_SHIM_FOR_BB2 block is added because the cond_br
+// splitting creates a new block for both successors without checking whether
+// each one actually needs a copy (there is no harm in the empty block though --
+// it will fold away later).
+
+// CHECK-LABEL: func.func @cond_branch(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+// CHECK: cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[POINTLESS_SHIM_FOR_BB2:[[:alnum:]]+]]
+// CHECK: ^[[POINTLESS_SHIM_FOR_BB2]]:
+// CHECK: cf.br ^[[BB2:.*]]
+// CHECK: ^[[BB1_COPIES]]:
+// CHECK: arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^[[BB1:.*]]
+func.func @cond_branch(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+ cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+ return
+^bb2:
+ return
+}
+
+// -----
+
+// Reduced from a real-world example that shows why we must split conditional
+// branches: inserting the copies directly before the `cf.cond_br` would need
+// more tiles than are available.
+
+// CHECK-LABEL: @cond_branch_with_backedge(
+// CHECK-SAME: %{{[[:alnum:]]+}}: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[TILEC:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILED:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+// CHECK: ^bb1(%[[CURRENT_INDEX:.*]]: index, %[[ITER_TILE:.*]]: vector<[4]x[4]xf32>):
+// CHECK: %[[CONTINUE_LOOP:.*]] = arith.cmpi
+// CHECK: cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
+// CHECK: ^[[BB3_COPIES]]:
+// CHECK-NEXT: arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^[[BB3:[[:alnum:]]+]]
+// CHECK: ^[[BB3]](%{{.*}}: vector<[4]x[4]xf32>):
+// CHECK-NEXT: return
+
+func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector<[4]x[4]xf32>, %tileC: vector<[4]x[4]xf32>, %tileD: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ // Live here: %tileA, %tileB, %tileC, %tileD
+ cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+ %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+ // Live here: %iterTile, %tileB, %tileC, %tileD
+ // Along the ^bb3 path these values are only forwarded as block arguments, but
+ // along the ^bb2 path (the loop backedge to ^bb1) all four remain live out.
+ // If we inserted the (four) `arm_sme.copy_tile` operations here, the copies
+ // would be live at the same time as the four originals and we would run out
+ // of tiles. However, note that the copies are only needed if we take the ^bb3
+ // path. So, if we add a new block along that path we can insert the copies
+ // without any conflicts.
+ cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+ // Loop body: updates one slice of %iterTile, then branches back to ^bb1.
+ // Live here: %iterTile, %tileB, %tileC, %tileD
+ %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+ %nextIndex = arith.addi %currentIndex, %c1 : index
+ cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+ // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+ return
+}
>From fff78c02921af262a9313a0b656cf4cb322598ed Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Thu, 2 May 2024 16:41:34 +0000
Subject: [PATCH 6/7] Wrap op modifications in `rewriter.modifyOpInPlace()`
---
.../ArmSME/Transforms/TileAllocation.cpp | 28 +++++++++++--------
1 file changed, 17 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index ce8f341524751b..af8079c8dee5c5 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -191,10 +191,12 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
condBranch.getTrueDestOperands());
insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
condBranch.getFalseDestOperands());
- condBranch.getFalseDestOperandsMutable().clear();
- condBranch.getTrueDestOperandsMutable().clear();
- condBranch.setSuccessor(newTrueBranch, 0);
- condBranch.setSuccessor(newFalseBranch, 1);
+ rewriter.modifyOpInPlace(condBranch, [&] {
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ });
}
}
@@ -211,7 +213,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
if (isValidSMETileVectorType(operand.get().getType())) {
auto copy =
rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
- operand.assign(copy);
+ rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
}
}
}
@@ -494,7 +496,8 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
// Ensure ArmSME ops that don't produce a value still get a tile ID.
if (!hasTileResult(tileOp))
- tileOp.setTileId(tileIdAttr);
+ rewriter.modifyOpInPlace(tileOp,
+ [&] { tileOp.setTileId(tileIdAttr); });
}
}
auto copyOp = value.getDefiningOp<CopyTileOp>();
@@ -502,7 +505,7 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
// Fold redundant copies.
rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
} else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
- tileOp.setTileId(tileIdAttr);
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
// Rectify operand tile IDs with result tile IDs.
OpOperand *tileOperand = getTileOpOperand(tileOp);
if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
@@ -515,12 +518,15 @@ LogicalResult assignTileIdsAndResolveTrivialConflicts(
// Cloning prevents a move/spill (though may require recomputation).
rewriter.setInsertionPoint(tileOp);
auto clonedOp = operandTileOp.clone();
- clonedOp.setTileId(tileOp.getTileId());
+ rewriter.modifyOpInPlace(
+ clonedOp, [&] { clonedOp.setTileId(tileOp.getTileId()); });
rewriter.insert(clonedOp);
- if (copyOp)
+ if (copyOp) {
rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
- else
- tileOperand->assign(clonedOp->getResult(0));
+ } else {
+ rewriter.modifyOpInPlace(
+ tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
+ }
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
// Validate block arguments.
bool tileMismatch = false;
>From b21f5dd830c8c4cb7925b679996a1e9118ef6e0a Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Fri, 3 May 2024 15:58:44 +0000
Subject: [PATCH 7/7] More docs + naming
---
.../Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp | 6 +-
.../ArmSME/Transforms/TileAllocation.cpp | 62 ++++++++++++++-----
2 files changed, 50 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 562c58b73e7e34..3dbc8e9916df60 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -406,11 +406,11 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
/// AFTER:
/// ```mlir
/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
-/// %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
+/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
-/// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
-/// once all ArmSME ops have been converted to LLVM intrinsics.
+/// The 'arm_sme.get_tile' (which models the return) will fold away once all
+/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index af8079c8dee5c5..bc41386f5903ef 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -164,9 +164,23 @@ class TileAllocator {
unsigned nextInMemoryTileId = kInMemoryTileIdBase;
};
-// Add new intermediate blocks for the true and false destinations of a
-// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
-// branches.
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+/// ^bb1_copy:
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ^bb2_copy:
+/// cf.br ^bb2
+/// ```
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
SmallVector<cf::CondBranchOp> worklist;
function.walk([&](cf::CondBranchOp condBranch) {
@@ -200,7 +214,18 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
}
}
-/// Inserts tile copies at `cf.br` operations.
+/// Splits conditional branches (see `splitCondBranches`), then inserts tile
+/// copies at `cf.br` operations.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
+/// ```
void insertCopiesAtBranches(IRRewriter &rewriter,
FunctionOpInterface function) {
splitCondBranches(rewriter, function);
@@ -219,7 +244,9 @@ void insertCopiesAtBranches(IRRewriter &rewriter,
}
}
-/// A range where a tile value is live. The range may contain holes.
+/// A live range for a (collection of) tile values. A live range is built up of
+/// intervals [start, end) which represent parts of the program where the value
+/// needs to be live (i.e. in an SME virtual tile).
struct LiveRange {
using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
llvm::IntervalMapHalfOpenInfo<unsigned>>;
@@ -296,33 +323,38 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
LiveRange::Allocator &liveRangeAllocator,
Liveness &liveness, FunctionOpInterface function) {
DenseMap<Value, LiveRange> liveRanges;
- auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
- LivenessBlockInfo const &livenessInfo,
- bool liveAtBlockEntry = false) {
+ /// Defines or updates a live range for an SME tile value. Live-ins may update
+ /// an existing live range (rather than define a new one).
+ auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
+ LivenessBlockInfo const &livenessInfo,
+ bool liveAtBlockEntry = false) {
if (!isValidSMETileVectorType(value.getType()))
return;
- auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+ // Find or create a live range for `value`.
+ auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
+ LiveRange &valueLiveRange = it->second;
auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+ // Add the interval [firstUseOrDef, lastUseInBlock) to the live range. Values
+ // live at block entry start one index earlier than `firstUseOrDef`.
unsigned start =
operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
unsigned end = operationToIndexMap.at(lastUseInBlock);
- it->second.insert(value, start, end);
+ valueLiveRange.insert(value, start, end);
};
for (Block &block : function.getBlocks()) {
LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
// Handle block arguments:
for (Value argument : block.getArguments())
- updateLiveRanges(argument, &block.front(), *livenessInfo,
- /*liveAtBlockEntry=*/true);
+ defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
+ /*liveAtBlockEntry=*/true);
// Handle live-ins:
for (Value liveIn : livenessInfo->in())
- updateLiveRanges(liveIn, &block.front(), *livenessInfo,
- /*liveAtBlockEntry=*/true);
+ defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
+ /*liveAtBlockEntry=*/true);
// Handle new definitions:
for (Operation &op : block) {
for (Value result : op.getResults())
- updateLiveRanges(result, &op, *livenessInfo);
+ defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
}
}
More information about the Mlir-commits
mailing list