[Mlir-commits] [mlir] 041baf2 - [mlir][ArmSME] Use liveness information in the tile allocator (#90448)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue May 14 06:59:05 PDT 2024
Author: Benjamin Maxwell
Date: 2024-05-14T14:59:01+01:00
New Revision: 041baf2f60ac3e107399641aea04c77019e7eab8
URL: https://github.com/llvm/llvm-project/commit/041baf2f60ac3e107399641aea04c77019e7eab8
DIFF: https://github.com/llvm/llvm-project/commit/041baf2f60ac3e107399641aea04c77019e7eab8.diff
LOG: [mlir][ArmSME] Use liveness information in the tile allocator (#90448)
This patch rewrites the ArmSME tile allocator to use liveness
information to make better tile allocation decisions and improve the
correctness of the ArmSME dialect. The algorithm used here is a linear
scan over live ranges, where live ranges are assigned to tiles as they
appear in the program (chronologically). Live ranges release their
assigned tile ID once the current program point has moved past their end.
This is a greedy algorithm, chosen mainly to keep the implementation
relatively straightforward and because it seems to be sufficient for
most kernels (e.g. matmuls) that use ArmSME. The general steps are
roughly taken from
https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf,
though a few simplifications and assumptions have been made for our
use case.
Hopefully, the only changes needed for a user of the ArmSME dialect are
that:
- `-allocate-arm-sme-tiles` will no longer be a standalone pass
- `-test-arm-sme-tile-allocation` is only for unit tests
- `-convert-arm-sme-to-llvm` must happen after `-convert-scf-to-cf`
- SME tile allocation is now part of the LLVM conversion
By integrating this into the `ArmSME -> LLVM` conversion we can allow
high-level (value-based) ArmSME operations to be side-effect-free, as we
can guarantee nothing will rearrange ArmSME operations before we emit
intrinsics (which could invalidate the tile allocation).
The hope is for ArmSME operations to have no hidden state/side effects
and allow easily lowering dialects such as `vector` and `arith` to SME,
without making assumptions about how the input IR looks, as the
semantics of the operations will be the same. That is, there are no (new)
side effects, and the IR follows the rules of SSA (a value will never change).
The aim is correctness, so we have a base for working on optimizations.
Added:
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir
Modified:
mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
mlir/lib/Dialect/ArmSME/IR/Utils.cpp
mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
mlir/test/Dialect/ArmSME/canonicalize.mlir
mlir/test/Dialect/ArmSME/roundtrip.mlir
mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
Removed:
mlir/test/Dialect/ArmSME/cse.mlir
################################################################################
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab49998..403f811a2569a 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
#include <memory>
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
namespace mlir {
class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab9..e6d678dc1b12b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
// ArmSMEToLLVM
//===----------------------------------------------------------------------===//
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
let summary = "Lower the operations from the ArmSME dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
"arm_sme::ArmSMEDialect",
"LLVM::LLVMDialect"
];
+ let options = [
+ Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+ "bool", /*default=*/"false",
+ "Dump the live ranges of SME tiles (for debugging)">
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a7..dac54712c7f47 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 0000000000000..9153fbb57ea88
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,27 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H
+#define MLIR_DIALECT_ARMSME_OPINTERFACES_H
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+namespace mlir::arm_sme {
+
+namespace detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *);
+}
+
+// The first in-memory SME tile ID. This is set to 16 as that is the first tile
+// ID larger than any virtual tile ID supported by the SME ISA.
+static constexpr unsigned kInMemoryTileIdBase = 16;
+
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
+} // namespace mlir::arm_sme
+
+#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 239c4beab10d2..9178655f010c9 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type",
def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
let description = [{
- An interface for operations that use or allocate Arm SME tiles. These
- operations need to be assigned a tile ID, an i32 attribute, which specifies
- which virtual tile within the ZA storage to use. The number of tiles
- available depends on the type of the tile. This is summarized below:
+ An interface for operations that use Arm SME tiles. These operations need to
+ be assigned a tile ID, an i32 attribute, which specifies which virtual tile
+ within the ZA storage to use. The number of tiles available depends on the
+ type of the tile. This is summarized below:
| Tile Vector Types | Possible Tile IDs |
|-------------------------------------------------------------------------|---------------------|
@@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
| `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) |
| `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) |
| `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) |
-
- Operations that allocate a new tile (such as arm_sme.get_tile), are used as
- the roots for tile allocation, with all operations that (transitively)
- depend on a root being assigned the same tile ID.
}];
let methods = [
InterfaceMethod<
@@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
return op->getAttrOfType<mlir::IntegerAttr>("tile_id");
}]
>,
- InterfaceMethod<
- [{
- The type of tile this operation allocates. Returns none (std::nullopt)
- if this operation does not allocate a tile.
- }],
- /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>",
- /*methodName=*/"getAllocatedTileType",
- /*arguments=*/(ins),
- /*methodBody=*/[{}],
- /*defaultImpl=*/ [{
- // This operation does not allocate a tile.
- return std::nullopt;
- }]
- >,
InterfaceMethod<
"Returns the VectorType of the tile used by this operation.",
/*returnType=*/"VectorType",
@@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> {
];
let extraSharedClassDeclaration = [{
- // A helper to create a new operation and propagate this operations tile ID.
- template<typename T, typename... Args>
- T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
- auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
- if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
- tileOp.setTileId($_op.getTileId());
- return op;
- }
-
- // A helper to replace this operation and forward its tile ID (if present).
- template<typename T, typename... Args>
- T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
- auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
- rewriter.replaceOp($_op, newOp);
- return newOp;
- }
-
bool isInMemoryTile() {
auto tileId = getTileId();
return tileId && tileId.getInt() >= kInMemoryTileIdBase;
}
}];
- let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+ let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
}
//===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSME_Dialect, TypeSize,
class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
Op<ArmSME_Dialect, mnemonic, traits> {}
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
- let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+ let summary = "Creates an undefined value of SME virtual tile type";
let description = [{
- Allocates a new SME "virtual tile" within a function. The contents of the
- tile returned from this operation are undefined.
+ Creates a new SME "virtual tile" value within a function. The contents of
+ the tile returned from this operation are undefined.
Example 1:
```mlir
- // Allocate an 8-bit element "virtual tile"
+ // Create an 8-bit element "virtual tile" value:
%za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
```
Example 2:
```mlir
- // Allocate two 16-bit element "virtual tiles"
+ // Create two 16-bit element "virtual tiles" values:
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
```
Example 3:
```mlir
- // Allocate an 128-bit element "virtual tile"
+ // Create an 128-bit element "virtual tile" value:
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
```
}];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
VectorType getTileType() {
return ::llvm::cast<VectorType>(getTile().getType());
}
-
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getTileType());
- }
- }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
- let summary = "SME tile placeholder";
- let description = [{
- A placeholder to preserve dataflow while lowering to SME intrinsics (which
- do not take or return SME virtual tile values). This operation is intended
- to be DCE'd once all ArmSME operations have been lowered.
-
- This operation is not intended to be used outside of the ArmSME -> LLVM
- conversion.
}];
- let results = (outs SMETile:$tile);
- let assemblyFormat = "attr-dict `:` type($tile)";
}
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
- let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+ let summary = "Creates a zero-initialized value of SME virtual tile type";
let results = (outs SMETile:$res);
let description = [{
- Initialise ZA with 0. This operation is convenient wrapper for the SME
- `zero` intrinsic and instruction.
+ Creates a new SME "virtual tile" value within a function. The contents of
+ the tile returned from this operation are zero-initialized.
Example 1: Zero an 8-bit element ZA tile.
@@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
}
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getVectorType());
- }
VectorType getTileType() {
return getVectorType();
}
@@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
let assemblyFormat = "attr-dict `:` type($res)";
}
+def CopyTileOp : ArmSME_Op<"copy_tile", [
+ Pure,
+ ArmSMETileOpInterface,
+ AllTypesMatch<["tile", "result"]>
+]> {
+ let summary = "Copies an SME tile value";
+ let arguments = (ins SMETile:$tile);
+ let results = (outs SMETile:$result);
+ let description = [{
+ Copies an SME "virtual tile" value to a new SSA value. This operation is
+ primarily intended to be used to normalize the IR prior to tile allocation.
+
+ Example:
+
+ ```mlir
+ %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+ ```
+ }];
+ let extraClassDeclaration = [{
+ VectorType getTileType() {
+ return ::llvm::cast<VectorType>(getResult().getType());
+ }
+ }];
+ let assemblyFormat = "$tile attr-dict `:` type($result)";
+}
+
def TileLoadOp : ArmSME_Op<"tile_load", [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
@@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getResult().getType());
}
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- return arm_sme::getSMETileType(getVectorType());
- }
VectorType getTileType() {
return getVectorType();
}
@@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [
```
}];
let arguments = (ins
- Arg<AnyMemRef, "the reference to load from">:$base, SVEPredicate:$mask,
+ Arg<AnyMemRef, "the reference to load from", [MemRead]>:$base, SVEPredicate:$mask,
SMETile:$tile, Variadic<Index>:$indices, Index:$tile_slice_index,
ArmSME_TileSliceLayoutAttr:$layout
);
@@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [
}
def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
- ArmSMETileOpInterface,
+ ArmSMETileOpInterface, Pure,
AllTypesMatch<["tile", "result"]>,
TypesMatchWith<
"type of 'vector' matches type of 'tile' slice",
@@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [
}
def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [
- ArmSMETileOpInterface,
+ ArmSMETileOpInterface, Pure,
TypesMatchWith<
"type of 'result' matches type of 'tile' slice",
"tile", "result",
@@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint<string operand> :
def OuterProductOp :
ArmSME_Op<"outerproduct", [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -802,12 +766,6 @@ let arguments = (ins
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- // The outerproduct op allocates a new tile if no accumulator is passed.
- if (!getAcc())
- return arm_sme::getSMETileType(getResultType());
- return std::nullopt;
- }
VectorType getTileType() {
return getResultType();
}
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
list<Type> allowedResultVectorTypes,
int numOuterProducts> :
ArmSME_Op<mnemonic, [
+ Pure,
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
- std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
- // The outerproduct op allocates a new tile if no accumulator is passed.
- if (!getAcc())
- return arm_sme::getSMETileType(getResultType());
- return std::nullopt;
- }
VectorType getTileType() {
return getResultType();
}
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874e..156744ba57e7b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
const ArmStreamingMode = ArmStreamingMode::Streaming,
const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e8926..869a031d6cae8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,17 +124,25 @@ def EnableArmStreaming
let dependentDialects = ["func::FuncDialect"];
}
-def TileAllocation
- : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
- let summary = "Allocate SME tiles";
+def TestTileAllocation
+ : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+ let summary = "Tests SME 'virtual tile' allocation";
let description = [{
This pass does tile allocation for SME "virtual tiles". It is run at the
'func.func' op level, and assigns tile IDs (via an attribute) to all ops
- that implement the `ArmSMETileOpInterface`. An error will be emitted when
- there's no tiles left.
+ that implement the `ArmSMETileOpInterface`. Note: This pass is only intended
+ to be used for testing, tile allocation is done as part of the ArmSME to
+ LLVM conversion (`convert-arm-sme-to-llvm`).
}];
- let constructor = "mlir::arm_sme::createTileAllocationPass()";
- let dependentDialects = ["func::FuncDialect"];
+ let options = [
+ Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+ "bool", /*default=*/"false",
+ "Dump the live ranges of SME tiles (for debugging)">,
+ Option<"preprocessOnly", "preprocess-only", "bool", /*default=*/"false",
+ "Only preprocess IR so it is ready for tile allocation "
+ "(but do not allocate any tiles)">
+ ];
+ let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
}
def OuterProductFusion
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e6999..a25b844f01eaa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
namespace mlir {
class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
class RewritePatternSet;
namespace arm_sme {
+
void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+ bool dumpRanges = false);
+
} // namespace arm_sme
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f..1f40eb6fc693c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
#define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
#include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
#include <optional>
namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);
+inline bool isValidSMETileVectorType(Type type) {
+ auto vType = dyn_cast<VectorType>(type);
+ return vType && isValidSMETileVectorType(vType);
+}
+
/// Returns the type of SME tile this vector type corresponds to, or none if the
/// vector type does not fit within an SME tile.
std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,31 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
/// Creates a vector type for the SME tile of `elementType`.
VectorType getSMETileTypeForElement(Type elementType);
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+ FunctionOpInterface function);
+
+/// Returns true if `tileOp` is trivially cloneable. A tile operation is
+/// trivially cloneable if:
+///
+/// 1. It has no operands (and only a single tile result)
+/// 2. It is 'Pure'
+///
+/// This ensures that the cloned operation will not share any dependencies with
+/// the original operation (which could also need to be considered), and that
+/// inserting the cloned operation at a different point in the program won't
+/// change the semantics of the program (as it has no side effects).
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `tileOp` produces a tile result.
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns the tile `OpOperand` for this `tileOp` (or null).
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `typeA` is >= `typeB` (in terms of bytes).
+bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB);
+
} // namespace mlir::arm_sme
#endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 1ba1b88fc1234..3dbc8e9916df6 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,6 +16,7 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -245,6 +246,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
if (!tileOp.isInMemoryTile())
return failure();
+ tileOp->emitWarning(
+ "failed to allocate SME virtual tile to operation, all tile "
+ "operations will go through memory, expect degraded performance");
+
// Step 1. Create an alloca for the tile at the top of the function (if one
// does not already exist).
auto loc = tileOp.getLoc();
@@ -391,20 +396,6 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
(addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
}
-struct GetTileConversion
- : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp,
- RequiresSpillsAndFills::No> {
- using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
- ConversionPatternRewriter &rewriter) const override {
- rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
- getTile, getTile.getTileType());
- return success();
- }
-};
-
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
@@ -415,11 +406,11 @@ struct GetTileConversion
/// AFTER:
/// ```mlir
/// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
-/// %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
+/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
/// ```
///
-/// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
-/// once all ArmSME ops have been converted to LLVM intrinsics.
+/// The 'arm_sme.get_tile' (which models the return) will fold away once all
+/// ArmSME ops have been converted to LLVM intrinsics.
struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
@@ -436,7 +427,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
// The base mask is just the mask to zero the first tile (of a size).
// These masks are derived from:
// https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
- arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
+ arm_sme::ArmSMETileType tileType =
+ *arm_sme::getSMETileType(zero.getTileType());
auto baseMaskForSize = [&] {
switch (tileType) {
case arm_sme::ArmSMETileType::ZAB:
@@ -488,8 +480,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
loc, rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
- rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
- zero, zero.getVectorType());
+ rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
return success();
}
@@ -746,10 +737,12 @@ struct OuterProductOpConversion
auto loc = outerProductOp.getLoc();
Value acc = outerProductOp.getAcc();
- if (!acc)
+ if (!acc) {
// Initalize accumulator with zero.
- acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, loc, resultVectorType);
+ auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+ zero.setTileId(tileId);
+ acc = zero;
+ }
Value lhsMask = outerProductOp.getLhsMask();
Value rhsMask = outerProductOp.getRhsMask();
@@ -791,25 +784,27 @@ struct OuterProductWideningOpConversion
if (!tileId)
return failure();
+ auto loc = op.getLoc();
Value acc = op.getAcc();
- if (!acc)
+ if (!acc) {
// Initalize accumulator with zero.
- acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, op.getLoc(), op.getResultType());
+ auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+ zero.setTileId(tileId);
+ acc = zero;
+ }
Value lhsMask = op.getLhsMask();
Value rhsMask = op.getRhsMask();
if (!lhsMask || !rhsMask) {
auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
Value allActiveMask = rewriter.create<arith::ConstantOp>(
- op.getLoc(), DenseElementsAttr::get(predTy, true));
+ loc, DenseElementsAttr::get(predTy, true));
lhsMask = allActiveMask;
rhsMask = allActiveMask;
}
- rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
- rhsMask, adaptor.getLhs(),
- adaptor.getRhs());
+ rewriter.create<OuterProductWideningIntrOp>(
+ loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
// The outerproduct intrinsics have no result, replace
// 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -865,15 +860,22 @@ namespace {
struct ConvertArmSMEToLLVMPass
: public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+ ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+ this->dumpTileLiveRanges = dumpTileLiveRanges;
+ }
void runOnOperation() override {
+ auto function = getOperation();
+
+ if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+ return signalPassFailure();
+
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
LLVMTypeConverter converter(&getContext());
configureArmSMEToLLVMConversionLegality(target);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
- if (failed(applyPartialConversion(getOperation(), target,
- std::move(patterns))))
+ if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
}
};
@@ -883,34 +885,38 @@ struct ConvertArmSMEToLLVMPass
+/// Configures `target` for the ArmSME -> LLVM conversion: the ArmSME dialect is
+/// illegal except for the `aarch64_sme_*` intrinsic ops listed below; the arith,
+/// vector, scf and memref dialects (used to lower tile spills/fills) are legal,
+/// as are the `GetTileOp`/`CopyTileOp` pseudo ops and
+/// `UnrealizedConversionCastOp`.
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
- arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
- arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
- arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
- arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
- arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
- arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
- arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
- arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
- arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
- arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
- arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
- arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
- arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
- arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa,
- arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide,
- arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide,
- arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide,
- arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32,
- arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32,
- arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide,
- arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide,
- arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh,
- arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>();
+ arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str,
+ arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz,
+ arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz,
+ arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz,
+ arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz,
+ arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz,
+ arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert,
+ arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert,
+ arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert,
+ arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert,
+ arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert,
+ arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert,
+ arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert,
+ arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide,
+ arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide,
+ arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide,
+ arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32,
+ arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32,
+ arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide,
+ arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide,
+ arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb,
+ arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw,
+ arm_sme::aarch64_sme_cntsd>();
target.addLegalDialect<arith::ArithDialect,
/* The following are used to lower tile spills/fills */
vector::VectorDialect, scf::SCFDialect,
memref::MemRefDialect>();
- target.addLegalOp<UnrealizedConversionCastOp>();
+ // Pseudo operations. These cannot be code-generated but may exist in the
+ // input IR, or be generated during the conversion. They need to be eliminated
+ // before the final conversion to LLVM IR (and likely will be due to DCE).
+ target.addLegalOp<arm_sme::GetTileOp, arm_sme::CopyTileOp,
+ UnrealizedConversionCastOp>();
}
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -955,9 +961,10 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
arm_sme::aarch64_sme_usmopa_wide>,
OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
arm_sme::aarch64_sme_usmops_wide>,
- ZeroOpConversion, GetTileConversion>(patterns, converter);
+ ZeroOpConversion>(patterns, converter);
}
-std::unique_ptr<Pass> mlir::createConvertArmSMEToLLVMPass() {
- return std::make_unique<ConvertArmSMEToLLVMPass>();
+/// Creates the ArmSME -> LLVM conversion pass, which also performs SME tile
+/// allocation. `dumpTileLiveRanges` is stored on the pass and forwarded to the
+/// tile allocator.
+std::unique_ptr<Pass>
+mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+ return std::make_unique<ConvertArmSMEToLLVMPass>(dumpTileLiveRanges);
}
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 16b61c282749c..9f55932c33af6 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -196,12 +196,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
// Initialize tile with zero to satisfy padding. Inactive cols will be
// zeroed anyway since the loads use zeroing predication. For inactive
// rows however, no load will occur so these need to be zeroed.
- initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
- rewriter, loc, tileType);
+ initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
} else {
- // Allocate a new SME tile.
- initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
}
// Create a loop to load the active tile slices from memory.
@@ -212,10 +209,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
Value currentTile) -> Value {
// Create 'arm_sme.load_tile_slice' to load tile slice from memory
// into tile.
- return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
- rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
- currentTile, memrefIndices, tileSliceIndex,
- tileLoadOp.getLayout());
+ return rewriter.create<arm_sme::LoadTileSliceOp>(
+ loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
+ memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
});
if (failed(forOp))
@@ -292,9 +288,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
loc, rewriter.getI32Type(), numCols);
- // Allocate a new SME tile.
- auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
- rewriter, loc, tileType);
+ auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
// Create a loop that loads each ZA tile slice from memory.
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -339,10 +333,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
/*passthru=*/pad1DOp);
// Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
- auto moveSlice =
- tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
- rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
- tileSliceIndex, tileLoadOp.getLayout());
+ auto moveSlice = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+ loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
+ tileLoadOp.getLayout());
rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
rewriter.setInsertionPointAfter(forOp);
@@ -386,8 +379,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
tileStoreOp.getMask(),
[&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
- tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
- rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+ rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+ tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
predicate, tileStoreOp.getBase(), memrefIndices,
tileStoreOp.getLayout());
});
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 29fa9085a0a96..cb3a665844872 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -20,6 +20,12 @@
using namespace mlir;
using namespace mlir::arm_sme;
+namespace mlir::arm_sme::detail {
+/// Verifies an op implementing `ArmSMETileOpInterface`; the only check is that
+/// the op carries a valid tile ID (see `verifyOperationHasValidTileId`).
+LogicalResult verifyArmSMETileOpInterface(Operation *op) {
+  if (failed(verifyOperationHasValidTileId(op)))
+    return failure();
+  return success();
+}
+} // namespace mlir::arm_sme::detail
+
//===----------------------------------------------------------------------===//
// Tablegen Definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6a9e022182226..1f7305a5f8141 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -116,4 +116,57 @@ VectorType getSMETileTypeForElement(Type elementType) {
return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
}
+/// Erases the ArmSME tile ops within `function` that are trivially dead (see
+/// `isOpTriviallyDead`). Erasing an op may leave the ArmSME tile ops that
+/// define its operands dead as well, so those are revisited until no more
+/// dead tile ops remain.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function) {
+  SmallVector<Operation *> worklist;
+  function->walk([&](Operation *op) {
+    auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
+    if (armSMEOp && isOpTriviallyDead(armSMEOp))
+      worklist.push_back(armSMEOp);
+  });
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    for (Value value : op->getOperands()) {
+      auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>();
+      // Enqueue each defining op at most once. If `op` has several operands
+      // defined by the same op, pushing it repeatedly would make the loop pop
+      // (and dereference) it again after it has already been erased.
+      if (armSMEOp && !llvm::is_contained(worklist, armSMEOp.getOperation()))
+        worklist.push_back(armSMEOp);
+    }
+    rewriter.eraseOp(op);
+  }
+}
+
+/// Returns true if `tileOp` is non-null, has exactly one result, takes no
+/// operands and is pure, i.e. it can be trivially cloned.
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
+  if (!tileOp)
+    return false;
+  bool singleResultNoOperands =
+      tileOp->getNumResults() == 1 && tileOp->getNumOperands() == 0;
+  return singleResultNoOperands && isPure(tileOp);
+}
+
+/// Returns true if at least one result of `tileOp` has a valid SME tile vector
+/// type.
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
+  return llvm::any_of(tileOp->getResultTypes(), [](Type resultType) {
+    return arm_sme::isValidSMETileVectorType(resultType);
+  });
+}
+
+/// Returns the operand of `tileOp` whose type is a valid SME tile vector type,
+/// or nullptr if `tileOp` is null or has no such operand. Asserts that there is
+/// at most one tile operand.
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
+  if (!tileOp)
+    return nullptr;
+  auto opOperands = tileOp->getOpOperands();
+  auto hasTileType = [](OpOperand &candidate) {
+    return arm_sme::isValidSMETileVectorType(candidate.get().getType());
+  };
+  assert(llvm::count_if(opOperands, hasTileType) <= 1 &&
+         "expected at most one tile operand");
+  auto tileOperandIt = llvm::find_if(opOperands, hasTileType);
+  return tileOperandIt == opOperands.end() ? nullptr : &*tileOperandIt;
+}
+
+/// Returns true if tile type `typeA` is greater than or equal to `typeB`.
+bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
+  // The comparison on the underlying values is inverted because of how tile
+  // types are numbered in ArmSMEOps.td (presumably larger tiles get smaller
+  // values; confirm against the enum definition).
+  auto valueA = static_cast<unsigned>(typeA);
+  auto valueB = static_cast<unsigned>(typeB);
+  return valueB >= valueA;
+}
+
} // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4acb2a8fb7b53..1e1e0e569124d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,12 +6,18 @@
//
//===----------------------------------------------------------------------===//
//
-// This pass allocates SME tiles at the 'func.func' op level for ArmSME
-// operations. It does this using a 16-bit tile mask that has a bit for each
-// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
+// This transform allocates SME tiles at the 'func.func' op level for ArmSME
+// operations. It roughly implements a linear scan register allocator, similar
+// to the one outlined in [1], but with simplifications and assumptions made for
+// our use case. Note that this is a greedy allocator (so it may not always find
+// the most optimal allocation of tiles).
+//
+// The allocator operates at the CF dialect level. It is the responsibility of
+// users to ensure the IR has been lowered to CF before invoking the tile
+// allocator.
//
// The 128-bit tiles overlap with other element tiles as follows (see section
-// B2.3.2 of SME spec [1]):
+// B2.3.2 of SME spec [2]):
//
// Tile Overlaps
// ---------------------------------------------------------------------------
@@ -32,39 +38,34 @@
// ZA6.D ZA6.Q, ZA14.Q
// ZA7.D ZA7.Q, ZA15.Q
//
-// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initalized during the first tile allocation within a function and
-// updated on each subsequent allocation.
-//
-// [1] https://developer.arm.com/documentation/ddi0616/aa
+// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register
+// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer)
+// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf
+// [2] https://developer.arm.com/documentation/ddi0616/aa
//
//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/Liveness.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
-#define DEBUG_TYPE "allocate-arm-sme-tiles"
-
-namespace mlir {
-namespace arm_sme {
-#define GEN_PASS_DEF_TILEALLOCATION
+namespace mlir::arm_sme {
+#define GEN_PASS_DEF_TESTTILEALLOCATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
-} // namespace arm_sme
-} // namespace mlir
+} // namespace mlir::arm_sme
using namespace mlir;
using namespace mlir::arm_sme;
namespace {
-static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use");
-static constexpr StringLiteral
- kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id");
-
enum class TileMask : unsigned {
// clang-format off
kZA0B = 0xffff, // 1111 1111 1111 1111
@@ -137,172 +138,640 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
}
}
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
- TileMask &tilesInUse) {
- auto masks = getMasks(tileType);
- for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
- if ((tilesInUse & tileMask) == TileMask::kNone) {
- tilesInUse |= tileMask;
- return tileId;
+/// Tracks the SME virtual tiles in use (as a `TileMask`, one bit per 128-bit
+/// tile granule) and hands out in-memory tile IDs for spilled tiles.
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
}
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    // Every bit of the tile's mask must currently be set. If only some were,
+    // the XOR below would mark the clear bits as in use instead of freeing.
+    assert((tilesInUse & tileMask) == tileMask &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
}
- return failure();
-}
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
- SetVector<Operation *> &dependantOps) {
- auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
- for (auto [idx, value] : llvm::enumerate(inputValues)) {
- if (value == rootValue)
- findDependantOps(exitValues[idx], dependantOps);
+  /// Allocates an in-memory tile ID.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
+  }
+
+private:
+  /// Bits of the 128-bit tile granules currently allocated.
+  TileMask tilesInUse = TileMask::kNone;
+  /// Next in-memory tile ID to hand out (IDs start at `kInMemoryTileIdBase`).
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+/// ^bb1_copy:
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ^bb2_copy:
+/// cf.br ^bb2
+/// ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+ // Collect the cond_brs that have at least one SME tile operand.
+ SmallVector<cf::CondBranchOp> worklist;
+ function.walk([&](cf::CondBranchOp condBranch) {
+ if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+ return isValidSMETileVectorType(value.getType());
+ })) {
+ worklist.push_back(condBranch);
}
+ });
+
+ // Appends an unconditional branch `source` -> `dest` forwarding `args`.
+ auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+ rewriter.setInsertionPointToEnd(source);
+ rewriter.create<cf::BranchOp>(loc, dest, args);
};
- for (Operation *user : rootValue.getUsers()) {
- if (dependantOps.contains(user))
+
+ // Splitting at `block->end()` creates new, empty blocks. Each one receives a
+ // cf.br to the original destination with the original operands, and the
+ // cond_br is retargeted to them with no operands. Both successors are
+ // rerouted, even if only one of them receives a tile.
+ for (auto condBranch : worklist) {
+ auto loc = condBranch.getLoc();
+ Block *block = condBranch->getBlock();
+ auto newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+ condBranch.getTrueDestOperands());
+ insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+ condBranch.getFalseDestOperands());
+ rewriter.modifyOpInPlace(condBranch, [&] {
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ });
+ }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
+/// ```
+void insertCopiesAtBranches(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ // Only blocks terminated by an unconditional cf.br are handled; other
+ // terminators (including cf.cond_br) are skipped.
+ for (Block &block : function.getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ if (!isa<cf::BranchOp>(terminator))
continue;
- dependantOps.insert(user);
- TypeSwitch<Operation *>(user)
- .Case<cf::BranchOp>([&](auto branchOp) {
- // (CF) Follow branch.
- traverseCorrespondingValues(branchOp.getDestOperands(),
- branchOp.getDest()->getArguments());
- })
- .Case<cf::CondBranchOp>([&](auto condBranchOp) {
- // (CF) Follow true branch.
- traverseCorrespondingValues(
- condBranchOp.getTrueOperands(),
- condBranchOp.getTrueDest()->getArguments());
- // (CF) Follow false branch.
- traverseCorrespondingValues(
- condBranchOp.getFalseOperands(),
- condBranchOp.getFalseDest()->getArguments());
- })
- .Case<LoopLikeOpInterface>([&](auto loopOp) {
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
- traverseCorrespondingValues(loopOp.getInits(),
- loopOp.getRegionIterArgs());
- })
- .Case<scf::YieldOp>([&](auto yieldOp) {
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
- auto parent = user->getParentOp();
- traverseCorrespondingValues(user->getOperands(),
- parent->getResults());
+ // Copy each tile operand just before the branch and forward the copy.
+ rewriter.setInsertionPoint(terminator);
+ for (OpOperand &operand : terminator->getOpOperands()) {
+ if (isValidSMETileVectorType(operand.get().getType())) {
+ auto copy =
+ rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
+ }
+ }
+ }
+}
+
+/// Prepares the IR for tile allocation. It does this by first 'splitting'
+/// conditional branches (see `splitCondBranches`), then inserting tile copies
+/// at branch operations. The conditional branches are split to prevent the
+/// copies needed for them overlapping between the true and false paths of the
+/// branch (see `tile-allocation-copies.mlir` and
+/// `tile-allocation-liveness.mlir` for examples). The copies break up live
+/// ranges and ensure when moving out of SSA the semantics of the program are
+/// preserved.
+void preprocessForTileAllocation(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ // Order matters: splitting first moves the tile operands of cf.cond_brs onto
+ // new cf.brs, which are the only branches insertCopiesAtBranches handles.
+ splitCondBranches(rewriter, function);
+ insertCopiesAtBranches(rewriter, function);
+}
+
+/// A live range for a (collection of) tile values. A live range is built up of
+/// non-overlapping intervals [start, end) which represent parts of the program
+/// where a value in the range needs to be live (i.e. in an SME virtual tile).
+/// Note that as the intervals are non-overlapping all values within a live
+/// range can be allocated to the same SME virtual tile.
+struct LiveRange {
+  // Keys are `unsigned` operation indices (see `generateOperationNumbering`).
+  // The map's key type, its interval traits and the accessors below all use
+  // that same type, so no implicit narrowing happens anywhere.
+  using RangeSet = llvm::IntervalMap<unsigned, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  explicit LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range. An empty
+  /// interval (start == end) only records `value`.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  unsigned start() const { return ranges->start(); }
+  unsigned end() const { return ranges->stop(); }
+  /// Orders live ranges by their start point.
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  /// Returns the SME tile type of this range, taken from its first value (so
+  /// the range must contain at least one value).
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  /// The values contained in this live range.
+  SetVector<Value> values;
+
+  /// A set of (non-overlapping) intervals that mark where any value in `values`
+  /// is live.
+  std::unique_ptr<RangeSet> ranges;
+
+  /// The tile ID (or none) assigned to this live range.
+  std::optional<unsigned> tileId;
+};
+
+/// Assigns every top-level operation of `function` an index, for use when
+/// computing live ranges. Blocks are visited in topological order (following
+/// forward edges) and operations are numbered consecutively within each block.
+/// Each block also reserves the index just before its first operation, so that
+/// block arguments get a number of their own. This is only correct once all
+/// ArmSME operations have been lowered to CF: in debug builds it asserts that
+/// no ArmSME tile op is nested inside another operation's regions.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  DenseMap<Operation *, unsigned> operationToIndexMap;
+  unsigned nextIndex = 0;
+  for (Block *block : getTopologicallySortedBlocks(function.getFunctionBody())) {
+    // Reserve an index for the block itself (used for its arguments).
+    ++nextIndex;
+    for (Operation &op : *block) {
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
+      });
+#endif
+      operationToIndexMap.try_emplace(&op, nextIndex);
+      ++nextIndex;
+    }
+  }
+  return operationToIndexMap;
+}
+
+/// Builds a live range for every SME tile value in `function` from the MLIR
+/// liveness analysis and the numbering produced by `generateOperationNumbering`.
+/// Within each block, a tile value gets the interval [start, lastUseInBlock),
+/// where `start` is its defining op, or the block's own index for block
+/// arguments and live-ins.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  assert(!operationToIndexMap.empty() && "expected operation numbering");
+  DenseMap<Value, LiveRange> liveRanges;
+
+  // Creates (or extends) the live range of `value`. Live-ins may extend an
+  // existing live range rather than create a new one. When `liveAtBlockEntry`
+  // is set, `firstUseOrDef` must be the first operation in the block.
+  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
+                                          LivenessBlockInfo const &livenessInfo,
+                                          bool liveAtBlockEntry = false) {
+    // Only SME tile values get live ranges.
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    LiveRange &valueLiveRange =
+        liveRanges.try_emplace(value, liveRangeAllocator).first->second;
+    Operation *lastUseInBlock =
+        livenessInfo.getEndOperation(value, firstUseOrDef);
+    unsigned startOpIdx = operationToIndexMap.at(firstUseOrDef);
+    // Values live at block entry start at the block's own index, which is the
+    // one just before its first operation.
+    if (liveAtBlockEntry)
+      --startOpIdx;
+    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
+    valueLiveRange.insert(value, startOpIdx, endOpIdx);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const &livenessInfo = *liveness.getLiveness(&block);
+    // Block arguments are live from the start of the block.
+    for (Value argument : block.getArguments())
+      defineOrUpdateValueLiveRange(argument, &block.front(), livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
+    // So are values flowing into the block.
+    for (Value liveIn : livenessInfo.in())
+      defineOrUpdateValueLiveRange(liveIn, &block.front(), livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
+    // Values defined in the block are live from their defining op.
+    for (Operation &op : block)
+      for (Value result : op.getResults())
+        defineOrUpdateValueLiveRange(result, &op, livenessInfo);
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+ function_ref<void(Value)> callback) {
+ Block *block = blockArg.getOwner();
+ unsigned argNumber = blockArg.getArgNumber();
+ for (Block *pred : block->getPredecessors()) {
+ // Only cf.br and cf.cond_br terminators are handled; any other terminator
+ // is silently skipped (there is no Default case).
+ TypeSwitch<Operation *>(pred->getTerminator())
+ .Case<cf::BranchOp>([&](auto branch) {
+ Value predecessorOperand = branch.getDestOperands()[argNumber];
+ callback(predecessorOperand);
})
- .Default([&](auto) {
- // Otherwise, assume users of _any_ result are dependant.
- for (Value result : user->getResults())
- findDependantOps(result, dependantOps);
+ // A cond_br may reach `block` through either (or both) of its successors;
+ // the operand of each matching edge is reported.
+ .Case<cf::CondBranchOp>([&](auto condBranch) {
+ if (condBranch.getFalseDest() == block) {
+ Value predecessorOperand =
+ condBranch.getFalseDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
+ if (condBranch.getTrueDest() == block) {
+ Value predecessorOperand =
+ condBranch.getTrueDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
});
}
}
-struct AssignTileIDsPattern
- : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
- PatternRewriter &rewriter) const override {
- if (tileOp.getTileId())
- return failure();
-
- auto func = tileOp->getParentOfType<FunctionOpInterface>();
- auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
- if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
- func->getDiscardableAttr(name)))
- return unsigned(attr.getInt());
- return defaultVal;
- };
- auto setDiscardableIntAttr = [&](StringRef name, auto value) {
- rewriter.modifyOpInPlace(tileOp, [&] {
- func->setDiscardableAttr(name,
- rewriter.getI32IntegerAttr((unsigned)value));
- });
- };
- std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
- if (!tileType)
- return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
- TileMask tilesInUse =
- static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
- auto tileId = allocateTileId(*tileType, tilesInUse);
- bool tileIsInMemory = failed(tileId);
- if (tileIsInMemory) {
- // If we could not find a real tile ID, use an in-memory tile ID (ID >=
- // 16). A later pass will insert the necessary spills and reloads.
- tileId =
- getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
- tileOp->emitWarning(
- "failed to allocate SME virtual tile to operation, all tile "
- "operations will go through memory, expect degraded performance");
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+ // Start with each value pointing at its own live range from
+ // `initialLiveRanges`; merging below redirects values to a shared range.
+ DenseMap<Value, LiveRange *> liveRanges;
+ for (auto &[value, liveRange] : initialLiveRanges) {
+ liveRanges.insert({value, &liveRange});
+ }
+
+ // Merge the live ranges of values `a` and `b` into one (if they do not
+ // overlap). After this, the values `a` and `b` will both point to the same
+ // live range (which will contain multiple values).
+ auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+ LiveRange *aLiveRange = liveRanges.at(a);
+ LiveRange *bLiveRange = liveRanges.at(b);
+ if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+ aLiveRange->unionWith(*bLiveRange);
+ for (Value value : bLiveRange->values)
+ liveRanges[value] = aLiveRange;
+ }
+ };
+
+ // Merge the live ranges of new definitions with their tile operands.
+ auto unifyDefinitionsWithOperands = [&](Value value) {
+ auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ if (!armSMEOp)
+ return;
+ for (auto operand : armSMEOp->getOperands()) {
+ if (isValidSMETileVectorType(operand.getType()))
+ mergeValuesIfNonOverlapping(value, operand);
}
+ };
+
+ // Merge the live ranges of block arguments with their predecessors.
+ auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+ auto blockArg = dyn_cast<BlockArgument>(value);
+ if (!blockArg)
+ return;
+ forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+ mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+ });
+ };
+
+ // Applies `rule` to every value that has an initial live range.
+ auto applyRule = [&](auto rule) {
+ llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+ };
+
+ // Unify as many live ranges as we can. This prevents unnecessary moves.
+ applyRule(unifyBlockArgumentsWithPredecessors);
+ applyRule(unifyDefinitionsWithOperands);
+
+ // Remove duplicate live range entries.
+ SetVector<LiveRange *> uniqueLiveRanges;
+ for (auto [_, liveRange] : liveRanges) {
+ if (!liveRange->empty())
+ uniqueLiveRanges.insert(liveRange);
+ }
+
+ // Sort the new live ranges by starting point (ready for tile allocation).
+ // NOTE(review): ranges sharing a start point (e.g. several tile block
+ // arguments of one block) keep no defined relative order: std::sort is not
+ // stable and the input order follows DenseMap iteration. Confirm that the
+ // resulting tile assignment does not vary between runs.
+ auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+ std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+ [](LiveRange *a, LiveRange *b) { return *a < *b; });
+ return std::move(coalescedLiveRanges);
+}
+
+/// Choose a live range to spill (via some heuristics). This picks either an
+/// active live range from `activeRanges` or the new live range `newRange`.
+LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
+                                      LiveRange *newRange) {
+  // Heuristic: Spill trivially copyable operations (usually free).
+  auto isTrivialSpill = [&](LiveRange *allocatedRange) {
+    return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
+                                    newRange->getTileType()) &&
+           allocatedRange->values.size() == 1 &&
+           isTriviallyCloneableTileOp(
+               allocatedRange->values[0]
+                   .getDefiningOp<ArmSMETileOpInterface>());
+  };
+  if (isTrivialSpill(newRange))
+    return newRange;
+  auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
+  if (trivialSpill != activeRanges.end())
+    return *trivialSpill;
+
+  // With no active ranges, `newRange` is the only candidate (this also avoids
+  // dereferencing an empty range below).
+  if (activeRanges.empty())
+    return newRange;
+
+  // Heuristic: Spill the range that ends last (with a compatible tile type).
+  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
+    return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
+           a->end() < b->end();
+  };
+  // This predicate is not a strict weak ordering (it can be non-transitive),
+  // which `std::max_element` formally requires. A plain linear scan keeps the
+  // same selection order (first element, replaced whenever the predicate says
+  // the current pick is "smaller") without relying on that precondition.
+  LiveRange *lastActiveLiveRange = activeRanges.front();
+  for (LiveRange *activeRange : activeRanges.drop_front()) {
+    if (isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, activeRange))
+      lastActiveLiveRange = activeRange;
+  }
+  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
+    return lastActiveLiveRange;
+  return newRange;
+}
+
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
+/// Note: This does not attempt to fill holes in active live ranges.
+///
+/// `liveRangesSortedByStartPoint` is processed in order (a linear scan). On
+/// return every range has `tileId` set: either a tile ID from the
+/// TileAllocator, or an in-memory tile ID for ranges that were spilled.
+void allocateTilesToLiveRanges(
+ ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
+ TileAllocator tileAllocator;
+ // Ranges that currently hold a (non in-memory) tile ID.
+ SetVector<LiveRange *> activeRanges;
+ for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
+ // Release tile IDs from live ranges that have ended.
+ activeRanges.remove_if([&](LiveRange *activeRange) {
+ if (activeRange->end() <= nextRange->start()) {
+ tileAllocator.releaseTileId(activeRange->getTileType(),
+ *activeRange->tileId);
+ return true;
+ }
+ return false;
+ });
- // Set all operations dependent on `tileOp` to use the same tile ID.
- // This is a naive tile allocation scheme, but works for common cases. For
- // example, as this only allocates tile IDs to existing ops, it can't solve
- // cases like this (%tileA and %tileB come from different root operations):
- //
- // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
- // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
- // } else {
- // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
- // }
- //
- // This case would require allocating a new tile for the result of the
- // scf.if, and moving the contents of %tileA or %tileB to result tile (based
- // on the %some_cond).
- // Find all the ops that (transitively) depend on this tile.
- SetVector<Operation *> dependantOps;
- findDependantOps(tileOp->getResult(0), dependantOps);
- auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
- for (auto *op : dependantOps) {
- if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- auto currentTileId = dependantTileOp.getTileId();
- if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
- return dependantTileOp.emitOpError(
- "already assigned different SME virtual tile!");
+ // Allocate a tile ID to `nextRange`.
+ auto rangeTileType = nextRange->getTileType();
+ auto tileId = tileAllocator.allocateTileId(rangeTileType);
+ if (succeeded(tileId)) {
+ nextRange->tileId = *tileId;
+ } else {
+ LiveRange *rangeToSpill =
+ chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange);
+ if (rangeToSpill != nextRange) {
+ // Spill an active live range (so release its tile ID first).
+ tileAllocator.releaseTileId(rangeToSpill->getTileType(),
+ *rangeToSpill->tileId);
+ activeRanges.remove(rangeToSpill);
+ // This will always succeed after a spill (of an active live range).
+ nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
}
+ // Whichever range was chosen (possibly `nextRange`) goes through memory.
+ rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
+ }
+
+ // Insert the live range into the active ranges. Ranges with an in-memory
+ // tile ID hold no tile, so there is nothing to release when they end.
+ if (nextRange->tileId < kInMemoryTileIdBase)
+ activeRanges.insert(nextRange);
+ }
+}
+
+/// Assigns a tile ID to an MLIR value.
+///
+/// Sets `tileIdAttr` on the ArmSME tile op that defines `value` (if any), and
+/// on each ArmSME tile op that uses `value` but has no tile result of its own.
+/// Users that do have a tile result are not modified here. All IR changes go
+/// through `rewriter.modifyOpInPlace`.
+void assignTileIdToValue(IRRewriter &rewriter, Value value,
+ IntegerAttr tileIdAttr) {
+ if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>())
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
+ for (Operation *user : value.getUsers()) {
+ if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+ // Ensure ArmSME ops that don't produce a value still get a tile ID.
+ if (!hasTileResult(tileOp))
+ rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); });
}
+ }
+}
+
+/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
+///
+/// For each value of each live range in `allocatedLiveRanges`:
+///  1. The range's tile ID is assigned to the value (assignTileIdToValue).
+///  2. If the value is defined by a CopyTileOp whose operand is already on the
+///     same tile, uses of the copy are redirected to the operand (the copy is
+///     not erased here) and the remaining steps are skipped.
+///  3. For a block argument, every predecessor value must be on the same tile.
+///  4. If the value's defining op has a tile operand on a different tile, the
+///     operand's defining op is cloned onto this tile when it is trivially
+///     cloneable; otherwise an error is emitted (a tile move would be needed).
+///
+/// Returns failure (after emitting an error) if step 3 or 4 fails.
+/// NOTE(review): `function` is not used by this body.
+LogicalResult assignTileIdsAndResolveTrivialConflicts(
+ IRRewriter &rewriter, FunctionOpInterface function,
+ ArrayRef<LiveRange *> allocatedLiveRanges) {
+ for (LiveRange const *liveRange : allocatedLiveRanges) {
+ auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
+ auto isAllocatedToSameTile = [&](Value value) {
+ if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ tileOp && tileOp.getTileId() == tileIdAttr)
+ return true;
+ return liveRange->values.contains(value);
+ };
+
+ /// Eliminates copies where the operand has the same tile ID.
+ auto foldRedundantCopies = [&](Value value) -> LogicalResult {
+ auto copyOp = value.getDefiningOp<CopyTileOp>();
+ if (!copyOp || !isAllocatedToSameTile(copyOp.getTile()))
+ return failure();
+ rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
+ return success();
+ };
+
+ /// Validates each predecessor to a tile block argument has been assigned
+ /// the same tile ID.
+ auto validateBlockArguments = [&](Value value) {
+ auto blockArg = dyn_cast<BlockArgument>(value);
+ if (!blockArg) {
+ // Not a block argument (nothing to validate).
+ return success();
+ }
+ bool tileMismatch = false;
+ forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+ if (tileMismatch)
+ return;
+ if (!isAllocatedToSameTile(predecessorTile)) {
+ blockArg.getOwner()->getParentOp()->emitOpError(
+ "block argument not allocated to the same SME virtial tile as "
+ "predecessors");
+ tileMismatch = true;
+ }
+ });
+ return success(/*isSuccess=*/!tileMismatch);
+ };
- // Rewrite IR.
- if (!tileIsInMemory)
- setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
- else
- setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
- rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
- for (auto *op : dependantOps) {
- if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+ /// Attempts to resolve (trivial) tile ID conflicts.
+ auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
+ auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ OpOperand *tileOperand = getTileOpOperand(tileOp);
+ if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
+ // Operand already allocated to the correct tile.
+ // No conflict to resolve.
+ return success();
+ }
+ auto operandTileOp =
+ tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+ if (!isTriviallyCloneableTileOp(operandTileOp)) {
+ // NOTE(review): "virtial" is a typo for "virtual" here and in the
+ // block-argument diagnostic above; the tests match the misspelling, so
+ // the messages and tests must be updated together.
+ auto error =
+ tileOp.emitOpError("tile operand allocated to different SME "
+ "virtial tile (move required)");
+ error.attachNote(tileOperand->get().getLoc())
+ << "tile operand is: " << tileOperand->get();
+ return error;
+ }
+ // Cloning prevents a move/spill (though may require recomputation).
+ rewriter.setInsertionPoint(tileOp);
+ auto clonedOp = operandTileOp.clone();
+ rewriter.modifyOpInPlace(clonedOp,
+ [&] { clonedOp.setTileId(tileOp.getTileId()); });
+ rewriter.insert(clonedOp);
+ if (isa<CopyTileOp>(tileOp)) {
+ // A copy of the clone is redundant: use the clone directly.
+ rewriter.replaceAllUsesWith(tileOp->getResult(0),
+ clonedOp->getResult(0));
+ } else {
rewriter.modifyOpInPlace(
- dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
+ tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
}
+ return success();
+ };
+
+ for (Value value : liveRange->values) {
+ // 1. Assign the tile ID to the value.
+ assignTileIdToValue(rewriter, value, tileIdAttr);
+
+ // 2. Attempt to eliminate redundant tile copies.
+ if (succeeded(foldRedundantCopies(value)))
+ continue;
+
+ // 3. Validate tile block arguments.
+ if (failed(validateBlockArguments(value)))
+ return failure();
+
+ // 4. Attempt to resolve (trivial) tile ID conflicts.
+ if (failed(resolveTrivialTileConflicts(value)))
+ return failure();
}
+ }
+ return success();
+}
- return success();
+/// Prints live ranges alongside operation names for debugging.
+///
+/// Output goes to llvm::errs(): a header naming `function` and a key, then for
+/// each block of `function` one row per operation directly in that block
+/// (nested regions are not visited). Each row has one column per entry of
+/// `liveRanges`, then the operation name. A column shows 'S' where a sub-range
+/// starts, 'E' where one ends, '|' where the range is live (or where one
+/// sub-range ends and another starts), and ' ' otherwise.
+/// Every printed operation must have an entry in `operationToIndexMap` (it is
+/// looked up with `at`).
+void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+ ArrayRef<LiveRange const *> liveRanges,
+ FunctionOpInterface function) {
+ llvm::errs() << "SME Tile Liveness: @" << function.getName()
+ << "\nKey:\nS - Start\nE - End\n| - Live\n";
+ for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
+ llvm::errs() << "^bb" << blockIdx << ":\n";
+ for (Operation &op : block.getOperations()) {
+ unsigned operationIndex = operationToIndexMap.at(&op);
+ for (LiveRange const *range : liveRanges) {
+ char liveness = ' ';
+ // Scan every sub-interval of this live range for the current op.
+ for (auto it = range->ranges->begin(); it != range->ranges->end();
+ ++it) {
+ if (it.start() == operationIndex)
+ liveness = (liveness == 'E' ? '|' : 'S');
+ else if (it.stop() == operationIndex)
+ liveness = (liveness == 'S' ? '|' : 'E');
+ else if (operationIndex >= it.start() && operationIndex < it.stop())
+ liveness = '|';
+ }
+ llvm::errs() << liveness;
+ }
+ llvm::errs() << ' ' << op.getName() << '\n';
+ }
}
-};
+ llvm::errs() << "==========\n";
+}
-struct TileAllocationPass
- : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+/// Test pass that runs SME tile allocation on a single function.
+///
+/// With the `preprocessOnly` option it only runs preprocessForTileAllocation
+/// and returns; otherwise it runs arm_sme::allocateSMETiles (forwarding the
+/// `dumpTileLiveRanges` option) and signals pass failure if that fails.
+struct TestTileAllocationPass
+ : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
+ using TestTileAllocationBase::TestTileAllocationBase;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- patterns.add<AssignTileIDsPattern>(patterns.getContext());
- GreedyRewriteConfig config;
- // Setting useTopDownTraversal ensures tiles are allocated in program
- // order.
- config.useTopDownTraversal = true;
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(patterns), config))) {
- signalPassFailure();
+ FunctionOpInterface function = getOperation();
+ // Stop after preprocessing (e.g. to inspect the inserted tile copies).
+ if (preprocessOnly) {
+ IRRewriter rewriter(function);
+ return preprocessForTileAllocation(rewriter, function);
}
+ if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+ signalPassFailure();
}
};
} // namespace
-std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
- return std::make_unique<TileAllocationPass>();
+/// Allocates SME tile IDs to the ArmSME tile operations in `function`.
+///
+/// The steps are numbered in the body: preprocess the IR, gather tile live
+/// ranges, coalesce non-overlapping live ranges, greedily allocate tile IDs
+/// (spilling to in-memory tile IDs when tiles run out), write the tile IDs back
+/// to the IR (resolving trivial conflicts), and erase trivially dead tile ops.
+/// If `dumpRanges` is true, the initial and coalesced live ranges are printed
+/// to stderr.
+///
+/// Returns success early for an empty function or when no tile live ranges
+/// are found. Returns failure if a tile ID conflict could not be resolved; an
+/// error has then been emitted on the offending operation.
+LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
+ bool dumpRanges) {
+ if (function.empty()) {
+ // TODO: Also return early if the function contains no ArmSME ops?
+ return success();
+ }
+
+ LiveRange::Allocator liveRangeAllocator;
+ IRRewriter rewriter(function.getContext());
+
+ // 1. Preprocess the IR for tile allocation.
+ preprocessForTileAllocation(rewriter, function);
+
+ // 2. Gather live ranges for each ArmSME tile within the function.
+ Liveness liveness(function);
+ auto operationToIndexMap = generateOperationNumbering(function);
+ auto initialLiveRanges = gatherTileLiveRanges(
+ operationToIndexMap, liveRangeAllocator, liveness, function);
+ if (initialLiveRanges.empty())
+ return success();
+
+ if (dumpRanges) {
+ // Wrangle initial live ranges into a form suitable for printing.
+ auto nonEmpty = llvm::make_filter_range(
+ llvm::make_second_range(initialLiveRanges),
+ [&](LiveRange const &liveRange) { return !liveRange.empty(); });
+ auto initialRanges = llvm::to_vector(llvm::map_range(
+ nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
+ std::sort(initialRanges.begin(), initialRanges.end(),
+ [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
+ llvm::errs() << "\n========== Initial Live Ranges:\n";
+ dumpLiveRanges(operationToIndexMap, initialRanges, function);
+ }
+
+ // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
+ // for tile allocation. E.g. Unify the result of an operation with its
+ // operands.
+ auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
+
+ if (dumpRanges) {
+ llvm::errs() << "\n========== Coalesced Live Ranges:\n";
+ dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
+ }
+
+ // 4. Allocate tile IDs to live ranges.
+ allocateTilesToLiveRanges(coalescedLiveRanges);
+
+ // 5. Assign the tile IDs back to the ArmSME operations.
+ if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
+ coalescedLiveRanges))) {
+ return failure();
+ }
+
+ // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+ // users). This prevents the LLVM conversion needlessly inserting spills.
+ eraseTriviallyDeadTileOps(rewriter, function);
+ return success();
}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index f48046a8d7995..14b1f323da3a2 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s
-
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file | FileCheck %s
// Test conversion of ArmSME ops to LLVM intrinsics.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index a9c1a65a296f4..2c3868d7f25cb 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \
// RUN: FileCheck %s --check-prefix=AFTER-TILE-ALLOC
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \
// RUN: -split-input-file -verify-diagnostics | \
// RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING
@@ -56,6 +56,9 @@ func.func @use_too_many_tiles() {
%1 = arm_sme.zero : vector<[4]x[4]xi32>
// expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%2 = arm_sme.zero : vector<[8]x[8]xi16>
+ "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> ()
return
}
// AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles
@@ -131,18 +134,16 @@ func.func @use_too_many_tiles() {
/// Note: In this example an entire tile swap is inserted before/after the
/// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a
/// single tile slice (and can omit the initial load, like in the previous example).
+// `%useAllTiles` (by its name, a tile type covering every SME tile) stays live
+// until its `test.some_use` below, so no tile is free for the
+// `arm_sme.load_tile_slice`, which should get an in-memory tile ID (16 below).
-func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
- %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
%c0 = arith.constant 0 : index
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%mask = vector.constant_mask [4] : vector<[4]xi1>
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
- "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
+ "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
+ return %loadSlice : vector<[4]x[4]xf32>
}
// AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
-// AFTER-TILE-ALLOC: arm_sme.get_tile
-// AFTER-TILE-ALLOC-SAME: tile_id = 0
// AFTER-TILE-ALLOC: arm_sme.load_tile_slice
// AFTER-TILE-ALLOC-SAME: tile_id = 16
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 15767ff1dec3f..a62ca080ab8d9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
//===----------------------------------------------------------------------===//
// arm_sme.outerproduct
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index e144bac970a7d..8b46998d56b04 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
// -----
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index b7ba3f728c705..643dfd4a7cbd9 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,18 +1,14 @@
-// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
-// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
-// once it becomes unused, after lowering to control flow.
+// This tests that dead tile values are removed from control flow.
-// -----
-
-// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
-// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
// CHECK-NOT: vector<[4]x[4]xf32>
-func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
+func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xi32>) {
%c10 = arith.constant 10 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
- %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
^bb1(%1: index, %2: vector<[4]x[4]xf32>): // 2 preds: ^bb0, ^bb2
%3 = arith.cmpi slt, %1, %c10 : index
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
deleted file mode 100644
index 74e7293eaeca5..0000000000000
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
-// duplicates.
-
-// CHECK-LABEL: @zero_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @zero_tile() {
- %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
- %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
- "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
- "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
- return
-}
-
-// CHECK-LABEL: @get_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @get_tile() {
- %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
- %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
- "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
- "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
- return
-}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ab46c7adca596..6095fdc11ead8 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1403,3 +1403,12 @@ func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
%reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
return %reuslt : vector<[2]x[2]xi64>
}
+
+//===----------------------------------------------------------------------===//
+// arm_sme.copy_tile
+//===----------------------------------------------------------------------===//
+
+// Round-trips arm_sme.copy_tile through the printer and parser.
+func.func @arm_sme_copy_tile(%vec: vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+ // CHECK: arm_sme.copy_tile {{.*}} : vector<[4]x[4]xf32>
+ %result = arm_sme.copy_tile %vec : vector<[4]x[4]xf32>
+ return %result : vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
new file mode 100644
index 0000000000000..6d9cbf36a162f
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation=preprocess-only -split-input-file | FileCheck %s
+
+// This file tests the inserting copies for the SME tile allocation. Copies are
+// inserted at `cf.br` ops (the predecessors to block arguments). Conditional
+// branches are split to prevent conflicts (see cond_br_with_backedge).
+
+// CHECK-LABEL: func.func @simple_branch(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>)
+// CHECK: %[[COPY:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%[[COPY]] : vector<[4]x[4]xf32>)
+// CHECK: ^bb1(%[[BLOCK_ARG:.*]]: vector<[4]x[4]xf32>):
+
+// A tile passed to a block argument through an unconditional branch is
+// copied right before the cf.br.
+func.func @simple_branch(%tile : vector<[4]x[4]xf32>) {
+ cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+ return
+}
+
+// -----
+
+// Note: The ^POINTLESS_SHIM_FOR_BB2 block is added as the cond_br splitting does
+// not check if it needs to insert a copy or not (there is no harm in the empty
+// block though -- it will fold away later).
+
+// CHECK-LABEL: func.func @cond_branch(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+// CHECK: cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[POINTLESS_SHIM_FOR_BB2:[[:alnum:]]+]]
+// CHECK: ^[[POINTLESS_SHIM_FOR_BB2]]:
+// CHECK: cf.br ^[[BB2:.*]]
+// CHECK: ^[[BB1_COPIES]]:
+// CHECK: arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^[[BB1:.*]]
+// Only the true successor (^bb1) receives a tile, so a copy is only expected
+// on the path to ^bb1.
+func.func @cond_branch(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+ cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+ return
+^bb2:
+ return
+}
+
+// -----
+
+// Reduction of a real world example that shows why we must split conditional branches.
+// The four copies for ^bb3 are placed in a new block on the exit edge, while
+// the back-edge from ^bb2 to ^bb1 only copies the single tile it passes.
+
+// CHECK-LABEL: @cond_branch_with_backedge(
+// CHECK-SAME: %[[TILEA:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[TILEC:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILED:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+// CHECK: %[[BB1_COPY_0:.*]] = arm_sme.copy_tile %[[TILEA]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%{{[[:alnum:]]+}}, %[[BB1_COPY_0]]
+// CHECK: ^bb1(%[[CURRENT_INDEX:.*]]: index, %[[ITER_TILE:.*]]: vector<[4]x[4]xf32>):
+// CHECK: %[[CONTINUE_LOOP:.*]] = arith.cmpi
+// CHECK: cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2_COPIES:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
+// CHECK: ^[[BB3_COPIES]]:
+// CHECK-NEXT: %[[BB3_COPY_0:.*]] = arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[BB3_COPY_1:.*]] = arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[BB3_COPY_2:.*]] = arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[BB3_COPY_3:.*]] = arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^[[BB3:[[:alnum:]]+]](%[[BB3_COPY_0]], %[[BB3_COPY_1]], %[[BB3_COPY_2]], %[[BB3_COPY_3]]
+// CHECK: ^[[BB2_COPIES]]:
+// CHECK-NEXT: cf.br ^[[BB2:[[:alnum:]]+]]
+// CHECK: ^[[BB2]]:
+// CHECK-NEXT: %[[NEXT_TILE:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}}, %[[ITER_TILE]]
+// CHECK: %[[BB1_COPY_1:.*]] = arm_sme.copy_tile %[[NEXT_TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%{{[[:alnum:]]+}}, %[[BB1_COPY_1]]
+// CHECK: ^[[BB3]](%{{.*}}: vector<[4]x[4]xf32>):
+// CHECK-NEXT: return
+func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector<[4]x[4]xf32>, %tileC: vector<[4]x[4]xf32>, %tileD: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ // Live here: %tileA, %tileB, %tileC, %tileD
+ cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+ %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+ // Live here: %iterTile, %tileB, %tileC, %tileD
+ // %iterTile, %tileB, %tileC, %tileD are live out (in the ^bb2 case). If we
+ // inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles.
+ // However, note that the copies are only needed if we take the ^bb3 path. So, if we add
+ // a new block along that path we can insert the copies without any conflicts.
+ cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+ // Live here: %iterTile, %tileB, %tileC, %tileD
+ %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+ %nextIndex = arith.addi %currentIndex, %c1 : index
+ cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+ // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+ return
+}
+
+// -----
+
+// No branch in this function passes a tile as a block argument, so no copies
+// should be inserted (even though several blocks here are unreachable).
+// CHECK-LABEL: @tile_dominance
+// CHECK-NOT: arm_sme.copy_tile
+func.func @tile_dominance(%arg0: vector<[4]x[4]xf32>) {
+ cf.br ^bb1
+^bb1: // 2 preds: ^bb0, ^bb4
+ "test.some_use"(%arg0) : (vector<[4]x[4]xf32>) -> ()
+ return
+^bb2: // no predecessors
+ %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+ cf.br ^bb3
+^bb3: // pred: ^bb2
+ "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
+ return
+^bb4: // no predecessors
+ cf.br ^bb1
+^bb5: // no predecessors
+ return
+}
+
+// -----
+
+// The same tile is passed to both successors; each successor is reached
+// through its own new block holding a separate copy.
+// CHECK-LABEL: func.func @cond_branch_true_and_false_tile_args(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[BB2_COPIES:[[:alnum:]]+]]
+// CHECK: ^[[BB2_COPIES]]:
+// CHECK-NEXT: %[[COPY_0:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^[[BB2:[[:alnum:]]+]](%[[COPY_0]]
+// CHECK: ^[[BB1_COPIES]]:
+// CHECK-NEXT: %[[COPY_1:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^[[BB1:[[:alnum:]]+]](%[[COPY_1]]
+// CHECK: ^[[BB1]]{{.*}}:
+// CHECK-NEXT: return
+// CHECK: ^[[BB2]]{{.*}}:
+// CHECK-NEXT: return
+func.func @cond_branch_true_and_false_tile_args(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+ cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg0: vector<[4]x[4]xf32>):
+ return
+^bb2(%blockArg1: vector<[4]x[4]xf32>):
+ return
+}
+
+// -----
+
+// ^bb3 has two predecessors; each one copies the tile it passes right before
+// its cf.br.
+// CHECK-LABEL: @multiple_predecessors
+// CHECK: ^bb1:
+// CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[COPY_0:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^bb3(%[[COPY_0]] : vector<[4]x[4]xf32>)
+// CHECK: ^bb2:
+// CHECK-NEXT: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[COPY_1:.*]] = arm_sme.copy_tile %[[ZERO]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^bb3(%[[COPY_1]] : vector<[4]x[4]xf32>)
+// CHECK: ^bb3({{.*}}):
+// CHECK-NEXT: return
+func.func @multiple_predecessors(%cond: i1)
+{
+ cf.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
+ cf.br ^bb3(%tile : vector<[4]x[4]xf32>)
+^bb2:
+ %zero = arm_sme.zero : vector<[4]x[4]xf32>
+ cf.br ^bb3(%zero : vector<[4]x[4]xf32>)
+^bb3(%blockArg: vector<[4]x[4]xf32>): // pred: ^bb1, ^bb2
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
index 39d9ab6491e3b..6b5e44365bf58 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,19 +1,17 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
-// -----
+// Select between tileA and tileB. This is currently unsupported as it would
+// require inserting (runtime) tile moves.
-func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %cond: i1) {
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xi32>'}}
+func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %tileA : vector<[4]x[4]xi32>, %tileB : vector<[4]x[4]xi32>, %cond: i1) {
%c0 = arith.constant 0 : index
- %tileA = arm_sme.get_tile : vector<[4]x[4]xi32>
- %tileB = arm_sme.get_tile : vector<[4]x[4]xi32>
- // Select between tileA and tileB. This is currently unsupported as it would
- // require inserting tile move operations during tile allocation.
+ // expected-error@below {{op tile operand allocated to different SME virtial tile (move required)}}
%tile = scf.if %cond -> vector<[4]x[4]xi32> {
scf.yield %tileA : vector<[4]x[4]xi32>
} else {
scf.yield %tileB : vector<[4]x[4]xi32>
}
- // expected-error@+1 {{op already assigned different SME virtual tile!}}
arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 2dedcb2fbc24e..88fc8a8923d34 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -1,18 +1,26 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK-BAD
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation=dump-tile-live-ranges -mlir-disable-threading -split-input-file -verify-diagnostics 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-LIVE-RANGE
-// This file tests some aspects of liveness issues in the SME tile allocator.
-// These tests were designed with a new liveness-based tile allocator in mind
-// (where the names of test cases make more sense), with the current tile
-// allocator these tests all give incorrect results (which is documented by
-// `CHECK-BAD`).
+// This file tests some simple aspects of using liveness in the SME tile allocator.
+// Note: We use -convert-scf-to-cf first as the tile allocator expects CF, but
+// some of these tests are written in SCF (to make things easier to follow).
-// Incorrect result! The second `move_vector_to_tile_slice` overwrites the first (which is still live).
-//
-// CHECK-BAD-LABEL: @constant_with_multiple_users
-// CHECK-BAD: %[[ZERO_TILE:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE-LABEL: @constant_with_multiple_users
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @constant_with_multiple_users(
+// CHECK-SAME: %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
+ // CHECK-NEXT: %[[ZERO_TILE_0:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[ZERO_TILE_1:.*]] = arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_A]], %[[ZERO_TILE_1]], %{{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_B]], %[[ZERO_TILE_0]], %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
%zero = arm_sme.zero : vector<[4]x[4]xf32>
%tile_a = arm_sme.move_vector_to_tile_slice %a, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_b = arm_sme.move_vector_to_tile_slice %b, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
@@ -23,12 +31,17 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
// -----
-// (No tile IDs -- the current tile allocator ignores this case)
+// CHECK-LIVE-RANGE-LABEL: @value_with_multiple_users
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
-// CHECK-BAD-LABEL: @value_with_multiple_users
-// CHECK-BAD-NOT: tile_id
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xf32>'}}
func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
- // A future allocator should error here (as `%tile` would need to be copied).
+ // expected-error@below {{op tile operand allocated to different SME virtial tile (move required)}}
%tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
"test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
@@ -38,12 +51,38 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
// -----
-// CHECK-BAD-LABEL: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE-LABEL: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @reuse_tiles_after_initial_use
func.func @reuse_tiles_after_initial_use() {
- // CHECK-BAD: arm_sme.get_tile {tile_id = 0 : i32}
- // CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
- // CHECK-BAD: arm_sme.get_tile {tile_id = 2 : i32}
- // CHECK-BAD: arm_sme.get_tile {tile_id = 3 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 2 : i32}
+ // CHECK: arm_sme.get_tile {tile_id = 3 : i32}
%tile_a = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_b = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_c = arm_sme.get_tile : vector<[4]x[4]xf32>
@@ -55,19 +94,13 @@ func.func @reuse_tiles_after_initial_use() {
"test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_d) : (vector<[4]x[4]xf32>) -> ()
- // -> Spills after the fourth tile (unnecessary):
- // CHECK-BAD: arm_sme.zero {tile_id = 16 : i32}
- // CHECK-BAD: arm_sme.zero {tile_id = 17 : i32}
- // CHECK-BAD: arm_sme.zero {tile_id = 18 : i32}
- // CHECK-BAD: arm_sme.zero {tile_id = 19 : i32}
- // Unnecessary spills:
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // CHECK: arm_sme.zero {tile_id = 0 : i32}
+ // CHECK: arm_sme.zero {tile_id = 1 : i32}
+ // CHECK: arm_sme.zero {tile_id = 2 : i32}
+ // CHECK: arm_sme.zero {tile_id = 3 : i32}
%tile_1 = arm_sme.zero : vector<[4]x[4]xf32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_3 = arm_sme.zero : vector<[4]x[4]xf32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_4 = arm_sme.zero : vector<[4]x[4]xf32>
"test.dummy"(): () -> ()
"test.dummy"(): () -> ()
@@ -81,16 +114,123 @@ func.func @reuse_tiles_after_initial_use() {
// -----
-// Incorrect result! Both branches should yield the result via the same tile.
+// CHECK-LIVE-RANGE-LABEL: @tile_live_ins
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: || test.dummy
+// CHECK-LIVE-RANGE-NEXT: || test.dummy
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: || test.dummy
+// CHECK-LIVE-RANGE-NEXT: || test.dummy
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb3:
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @tile_live_ins
+func.func @tile_live_ins() // %tile_1 and %tile_2 are defined in ^bb0, carried unused through ^bb1 and ^bb2, and only used in ^bb3.
+{
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+ %tile_1 = arm_sme.get_tile : vector<[4]x[4]xf32>
+ %tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
+ cf.br ^bb1
+^bb1:
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ cf.br ^bb2
+^bb2:
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ cf.br ^bb3
+^bb3:
+ "test.some_use"(%tile_1) : (vector<[4]x[4]xf32>) -> () // Last use of %tile_1: its live range ends here.
+ "test.some_use"(%tile_2) : (vector<[4]x[4]xf32>) -> () // Last use of %tile_2: its live range ends here.
+ return
+}
+
+// -----
+
+// This is basically the same test as tile_live_ins but shows that the order of
+// the blocks within the source does not relate to the liveness, which is based
+// on successors and predecessors (not textual order).
+//
+// So %tile_1 is live on the path bb0 -> bb2 -> bb1 (and dies in bb1). The
+// 'hole' when looking at the live range dump comes from the textual order
+// (and would disappear if bb1 was moved before bb2 in the source).
//
-// CHECK-BAD-LABEL: @non_overlapping_branches
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// When looking at the live range dump (outside of straight-line code) it
+// normally makes more sense to consider blocks in isolation (and how they
+// relate to the CFG).
+
+// CHECK-LIVE-RANGE-LABEL: @non_sequential_live_ins
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: | test.dummy
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: | test.dummy
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: || test.dummy
+// CHECK-LIVE-RANGE-NEXT: EE cf.cond_br
+// CHECK-LIVE-RANGE-NEXT: ^bb3:
+// CHECK-LIVE-RANGE-NEXT: | test.dummy
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-NEXT: func.return
+
+// CHECK-LABEL: @non_sequential_live_ins
+func.func @non_sequential_live_ins(%cond: i1) { // CFG order differs from text order: ^bb0 -> ^bb2 -> {^bb1, ^bb3}; ^bb1 -> ^bb3.
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+ %tile_1 = arm_sme.get_tile : vector<[4]x[4]xf32>
+ "test.dummy"(): () -> ()
+ cf.br ^bb2
+^bb1:
+ "test.some_use"(%tile_1) : (vector<[4]x[4]xf32>) -> () // Only use of %tile_1; this block is only reached from ^bb2.
+ "test.dummy"(): () -> ()
+ cf.br ^bb3
+^bb2:
+ %tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.dummy"(): () -> ()
+ cf.cond_br %cond, ^bb1, ^bb3
+^bb3:
+ "test.dummy"(): () -> ()
+ "test.some_use"(%tile_2) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
+
+// CHECK-LIVE-RANGE-LABEL: @non_overlapping_branches
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+// CHECK-LIVE-RANGE-NEXT: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E cf.br
+
+// CHECK-LABEL: @non_overlapping_branches
func.func @non_overlapping_branches(%cond: i1) {
+ // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
%tile = scf.if %cond -> vector<[4]x[4]xf32> {
+ // ^bb1:
%zero = arm_sme.zero : vector<[4]x[4]xf32>
scf.yield %zero : vector<[4]x[4]xf32>
} else {
+ // ^bb2:
%undef = arm_sme.get_tile : vector<[4]x[4]xf32>
scf.yield %undef : vector<[4]x[4]xf32>
}
@@ -100,52 +240,65 @@ func.func @non_overlapping_branches(%cond: i1) {
// -----
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @constant_loop_init_with_multiple_users
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c10 = arith.constant 10 : index
- %init = arm_sme.zero : vector<[4]x[4]xf32>
- %tile_a = scf.for %i = %c0 to %c10 step %c1 iter_args(%iter = %init) -> vector<[4]x[4]xf32> {
- %new_tile = arm_sme.move_vector_to_tile_slice %a, %iter, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
- scf.yield %new_tile : vector<[4]x[4]xf32>
- }
- %tile_b = scf.for %i = %c0 to %c10 step %c1 iter_args(%iter = %init) -> vector<[4]x[4]xf32> {
- %new_tile = arm_sme.move_vector_to_tile_slice %a, %iter, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
- scf.yield %new_tile : vector<[4]x[4]xf32>
+// Here %vecA and %vecB are not merged into the same live range (as they are unknown values).
+// This means that %vecA and %vecB are both allocated to different tiles (which is not legal).
+
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xf32>'}}
+func.func @overlapping_branches(%cond: i1, %vecA: vector<[4]x[4]xf32>, %vecB: vector<[4]x[4]xf32>) { // Each branch yields a different block-argument tile.
+ // expected-error@below {{op tile operand allocated to different SME virtial tile (move required)}}
+ %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+ scf.yield %vecA : vector<[4]x[4]xf32>
+ } else {
+ scf.yield %vecB : vector<[4]x[4]xf32>
}
- "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
- "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
+ "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
return
}
// -----
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @run_out_of_tiles_but_avoid_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE-LABEL: @run_out_of_tiles_but_avoid_spill
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+
+// Note in the live ranges (above) there are five tile values, but we only have four tiles.
+// There is no 'real' spill as we spill the `arm_sme.zero` but are then able to clone it
+// at each of its uses.
+
+// CHECK-LABEL: @run_out_of_tiles_but_avoid_spill
func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
%init = arm_sme.zero : vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
+ // Live = %init
scf.for %i = %c0 to %c10 step %c1 {
+ // CHECK: arm_sme.zero {tile_id = 1 : i32}
+ // CHECK: arm_sme.zero {tile_id = 2 : i32}
+ // CHECK: arm_sme.zero {tile_id = 3 : i32}
+ // CHECK: arm_sme.zero {tile_id = 0 : i32}
%tile_a, %tile_b, %tile_c, %tile_d = scf.for %j = %c0 to %c10 step %c1
iter_args(%iter_a = %init, %iter_b = %init, %iter_c = %init, %iter_d = %init)
-> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32> , vector<[4]x[4]xf32> , vector<[4]x[4]xf32>) {
+ // ^bb2:
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 2 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 3 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_a = arm_sme.move_vector_to_tile_slice %a, %iter_a, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_b = arm_sme.move_vector_to_tile_slice %b, %iter_b, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_c = arm_sme.move_vector_to_tile_slice %c, %iter_c, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
%new_d = arm_sme.move_vector_to_tile_slice %d, %iter_d, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
scf.yield %new_a, %new_b, %new_c, %new_d : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
}
+ // Live = %init, %tile_a, %tile_b, %tile_c, %tile_d (out of tiles!)
+ // This should be resolved by duplicating the arm_sme.zero (from folding
+ // arm_sme.copy_tile operations inserted by the tile allocator).
"test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
@@ -156,24 +309,48 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x
// -----
-// Incorrect result! Everything other than zero assigned to tile 1 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @avoidable_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32}
+// We should be able to avoid spills like this, but logic handling this case is
+// not implemented yet. Note tile ID >= 16 means a spill/in-memory tile.
+
+// CHECK-LIVE-RANGE-LABEL: @avoidable_spill
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: || test.some_use
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E test.some_use
+// CHECK-LIVE-RANGE-NEXT: || arith.addi
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
+
+// Note in the live ranges (above) there are two constant live-ins (first two ranges),
+// which gives six overlapping live ranges (at the point where %tile_d is defined).
+// The allocator currently will spill the first constant (which results in a real
+// spill at its use); however, this could be avoided by using the knowledge that
+// at the first "test.some_use" there are actually only two live ranges (so we can
+// fix this by duplicating the constant).
+
+// CHECK-LABEL: @avoidable_spill
func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
+ // CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[4]x[4]xf32>
%zero = arm_sme.zero : vector<[4]x[4]xf32>
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
scf.for %i = %c0 to %c10 step %c1 {
+ // So spilled here (unnecessarily).
+ // The arm_sme.zero op could be moved into the loop to avoid this.
"test.some_use"(%zero) : (vector<[4]x[4]xf32>) -> ()
%tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_c = arm_sme.move_vector_to_tile_slice %c, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
%tile_d = arm_sme.move_vector_to_tile_slice %d, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // %zero is still live here (due to the backedge)
"test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
"test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
@@ -181,3 +358,75 @@ func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<
}
return
}
+
+// -----
+
+// This test is a follow-up to the test of the same name in `tile-allocation-copies.mlir`.
+// It shows the live ranges (which are why the conditional branch needs to be split).
+
+// CHECK-LIVE-RANGE-LABEL: @cond_branch_with_backedge
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: ||| | arith.cmpi
+// CHECK-LIVE-RANGE-NEXT: EEE E cf.cond_br
+//
+// CHECK-LIVE-RANGE-NEXT: ^[[BB3_COPIES:[[:alnum:]]+]]:
+// CHECK-LIVE-RANGE-NEXT: ||| ES arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E|| |S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E| ||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E |||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEE cf.br
+//
+// It is important to note that the first three live ranges in ^bb1 do not end
+// at the `cf.cond_br`; they are live-out via the backedge bb1 -> bb2 -> bb1.
+// This means that if we placed the `arm_sme.copy_tile` ops before the `cf.cond_br`
+// then those live ranges would not end at the copies, resulting in unwanted
+// overlapping live ranges (and hence tile spills).
+//
+// With the conditional branch split and the copies placed in the BB3_COPIES
+// block, the first three live ranges end at the copy operations (as the
+// BB3_COPIES block is on the path out of the loop and has no backedge). This
+// means there are no overlaps and the live ranges all merge, as shown below.
+//
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: |||| arith.cmpi
+// CHECK-LIVE-RANGE-NEXT: EEEE cf.cond_br
+//
+// CHECK-LIVE-RANGE-NEXT: ^[[BB3_COPIES]]:
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEE cf.br
+
+// CHECK-LABEL: @cond_branch_with_backedge
+// CHECK-NOT: tile_id = 16
+// CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 2 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 3 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-NOT: tile_id = 16
+func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
+ %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+ %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+ %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+ %tileD = arm_sme.get_tile : vector<[4]x[4]xf32>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c10 = arith.constant 10 : index
+ // Live here: %tileA, %tileB, %tileC, %tileD
+ cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+ %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+ // Live here: %iterTile, %tileB, %tileC, %tileD
+ cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+ // Live here: %iterTile, %tileB, %tileC, %tileD
+ %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+ %nextIndex = arith.addi %currentIndex, %c1 : index
+ cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+ // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+ return
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir
new file mode 100644
index 0000000000000..27757e29c1e2f
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir
@@ -0,0 +1,38 @@
+
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @always_spill_larger_or_equal_tile_type
+// CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.zero {tile_id = 2 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.zero {tile_id = 3 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 16 : i32} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @always_spill_larger_or_equal_tile_type(%memref: memref<?x?xf16>) -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[8]x[8]xf16>) {
+ %c0 = arith.constant 0 : index
+ %0 = arm_sme.zero : vector<[4]x[4]xf32>
+ %1 = arm_sme.zero : vector<[4]x[4]xf32>
+ %2 = arm_sme.zero : vector<[4]x[4]xf32>
+ %3 = arm_sme.zero : vector<[4]x[4]xf32>
+ // The f16 load is the one spilled (even though the zeros would be 'trivial' spills), as the space of a single `f32` tile is not enough to hold the `f16` tile.
+ %load = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+ return %0, %1, %2, %3, %load : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[8]x[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @spill_larger_tile_type
+// CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[16]x[16]xi8>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 0 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 1 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 2 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 3 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @spill_larger_tile_type(%memref: memref<?x?xf32>) -> (vector<[16]x[16]xi8>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>) {
+ %c0 = arith.constant 0 : index
+ // Spilling the `arm_sme.zero` (the only `i8` tile) should free up space for all four `f32` tiles loaded below.
+ %0 = arm_sme.zero : vector<[16]x[16]xi8>
+ %1 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %2 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %3 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ %4 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+ return %0, %1, %2, %3, %4 : vector<[16]x[16]xi8>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index cac2dcc24d104..ca339be5fb56f 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,canonicalize))" | FileCheck %s
// This test verifies the tile mask operand of the zero intrinsic zeroes
// the correct tiles. Both integer and floating-point datatypes are checked.
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index 588b44a36c29f..14d9712e971a8 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -13,26 +13,29 @@
/// performance (hence the warning).
func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b: memref<?x?xi16>, %c: memref<?x?xi16>) {
%c0 = arith.constant 0 : index
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
// expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
%tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
// CHECK-LABEL: tile_a:
// CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
vector.print str "tile_a:\n"
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
vector.print %tile_a : vector<[8]x[8]xi16>
// CHECK-LABEL: tile_b:
// CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
vector.print str "tile_b:\n"
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
vector.print %tile_b : vector<[8]x[8]xi16>
// CHECK-LABEL: tile_c:
// CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
vector.print str "tile_c:\n"
+ // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
vector.print %tile_c : vector<[8]x[8]xi16>
// CHECK-LABEL: tile_d:
// CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
index 1794564a6a724..0648e771b8891 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
@@ -1,5 +1,6 @@
// DEFINE: %{entry_point} = main
-// DEFINE: %{compile} = mlir-opt %s -convert-arm-sme-to-llvm -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm),test-lower-to-llvm)"
// DEFINE: %{run} = %mcr_aarch64_cmd \
// DEFINE: -march=aarch64 -mattr=+sve,+sme \
// DEFINE: -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
index e942c7b8ac058..cdd8afe141421 100644
--- a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_library(MLIRArmSMETestPasses
MLIRTransforms
MLIRVectorToArmSME
MLIRVectorToSCF
+ MLIRSCFToControlFlow
)
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index 48d4a5859f8a0..d3dabaf200fdc 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -14,10 +14,12 @@
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
@@ -34,6 +36,10 @@ struct TestLowerToArmSMEOptions
llvm::cl::desc("Fuse outer product operations via "
"'-arm-sme-outer-product-fusion' pass"),
llvm::cl::init(true)};
+ PassOptions::Option<bool> dumpTileLiveRanges{
+ *this, "dump-tile-live-ranges",
+ llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
+ llvm::cl::init(false)};
};
void buildTestLowerToArmSME(OpPassManager &pm,
@@ -65,20 +71,17 @@ void buildTestLowerToArmSME(OpPassManager &pm,
pm.addPass(createConvertVectorToSCFPass(
VectorTransferToSCFOptions().enableFullUnroll()));
- // Allocate tiles for ArmSME operations.
- //
- // Later passes may create further ArmSME ops that implement the
- // ArmSMETileOpInterface, but tiles are allocated for root operations,
- // all of which should now exist.
- pm.addPass(arm_sme::createTileAllocationPass());
-
// Enable streaming-mode and ZA.
pm.addPass(arm_sme::createEnableArmStreamingPass(
arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
/*onlyIfRequiredByOps=*/true));
+ // Convert SCF to CF (required for ArmSME tile allocation).
+ pm.addPass(createConvertSCFToCFPass());
+
// Convert ArmSME to LLVM.
- pm.addPass(createConvertArmSMEToLLVMPass());
+ pm.addNestedPass<func::FuncOp>(
+ createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
// Sprinkle some cleanups.
pm.addPass(createCanonicalizerPass());
More information about the Mlir-commits
mailing list