[Mlir-commits] [mlir] [mlir][ArmSME] Support filling liveness 'holes' in the tile allocator (PR #98350)
Benjamin Maxwell
llvmlistbot at llvm.org
Wed Jul 10 09:57:51 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/98350
Holes in a live range are points where the corresponding value does not need to be in a tile/register. If the tile allocator keeps track of these holes it can reuse tiles for more values (avoiding spills).
Take this simple example:
```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```
If you were to calculate the liveness of %tileA and %tileB, you'd see there is a hole in the liveness of %tileA in ^bb1 (where only %tileB is live):
```
       %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live
```
The tile allocator can make use of that hole and reuse the tile ID it assigned to %tileA for %tileB.
>From d4a0f1d476a018e4bd301659b8cfa0438232ddfa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 9 Jul 2024 16:22:40 +0000
Subject: [PATCH] [mlir][ArmSME] Support filling liveness 'holes' in the tile
allocator
Holes in a live range are points where the corresponding value does not
need to be in a tile/register. If the tile allocator keeps track of
these holes it can reuse tiles for more values (avoiding spills).
Take this simple example:
```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here we never use %tileA again!
  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```
If you were to calculate the liveness of %tileA and %tileB, you'd see
there is a hole in the liveness of %tileA in ^bb1 (where only %tileB is
live):
```
       %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live
```
The tile allocator can make use of that hole and reuse the tile ID it
assigned to %tileA for %tileB.
---
.../ArmSME/Transforms/TileAllocation.cpp | 113 +++++++++++----
.../ArmSME/tile-allocation-liveness.mlir | 130 ++++++++++++++++++
2 files changed, 218 insertions(+), 25 deletions(-)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 733e758b43907..6023871c5affe 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -153,10 +153,18 @@ class TileAllocator {
return failure();
}
+ /// Acquires a specific tile ID. Asserts the tile is initially free.
+ void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
+ TileMask tileMask = getMasks(tileType)[tileId];
+ assert((tilesInUse & tileMask) == TileMask::kNone &&
+ "cannot acquire allocated tile!");
+ tilesInUse |= tileMask;
+ }
+
/// Releases a previously allocated tile ID.
void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
TileMask tileMask = getMasks(tileType)[tileId];
- assert((tilesInUse & tileMask) != TileMask::kNone &&
+ assert((tilesInUse & tileMask) == tileMask &&
"cannot release unallocated tile!");
tilesInUse ^= tileMask;
}
@@ -289,6 +297,11 @@ struct LiveRange {
.valid();
}
+ /// Returns true if this range overlaps with `point`.
+ bool overlaps(uint64_t point) const {
+ return ranges->lookup(point) == kValidLiveRange;
+ }
+
/// Unions this live range with `otherRange`, aborts if the ranges overlap.
void unionWith(LiveRange const &otherRange) {
for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
@@ -488,69 +501,113 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
return std::move(coalescedLiveRanges);
}
-/// Choose a live range to spill (via some heuristics). This picks either an
-/// active live range from `activeRanges` or the new live range `newRange`.
+/// Choose a live range to spill (via some heuristics). This picks either a live
+/// range from `activeRanges`, `inactiveRanges`, or the new live range
+/// `newRange`. Note: All live ranges in `activeRanges` and `inactiveRanges` are
+/// assumed to overlap with `newRange`.
LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
+ ArrayRef<LiveRange *> inactiveRanges,
LiveRange *newRange) {
+ auto allOverlappingRanges =
+ llvm::concat<LiveRange>(llvm::make_pointee_range(activeRanges),
+ llvm::make_pointee_range(inactiveRanges));
+
// Heuristic: Spill trivially copyable operations (usually free).
- auto isTrivialSpill = [&](LiveRange *allocatedRange) {
- return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
+ auto isTrivialSpill = [&](LiveRange &allocatedRange) {
+ return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
newRange->getTileType()) &&
- allocatedRange->values.size() == 1 &&
+ allocatedRange.values.size() == 1 &&
isTriviallyCloneableTileOp(
- allocatedRange->values[0]
- .getDefiningOp<ArmSMETileOpInterface>());
+ allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
};
- if (isTrivialSpill(newRange))
+ if (isTrivialSpill(*newRange))
return newRange;
- auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
- if (trivialSpill != activeRanges.end())
- return *trivialSpill;
+ auto trivialSpill = llvm::find_if(allOverlappingRanges, isTrivialSpill);
+ if (trivialSpill != allOverlappingRanges.end())
+ return &*trivialSpill;
// Heuristic: Spill the range that ends last (with a compatible tile type).
- auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
- return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
- a->end() < b->end();
+ auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
+ return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
+ a.end() < b.end();
};
- LiveRange *lastActiveLiveRange = *std::max_element(
- activeRanges.begin(), activeRanges.end(), isSmallerTileTypeOrEndsEarlier);
- if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
- return lastActiveLiveRange;
+ LiveRange &lastActiveLiveRange = *std::max_element(
+ allOverlappingRanges.begin(), allOverlappingRanges.end(),
+ isSmallerTileTypeOrEndsEarlier);
+ if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, *newRange))
+ return &lastActiveLiveRange;
return newRange;
}
/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
-/// Note: This does not attempt to fill holes in active live ranges.
void allocateTilesToLiveRanges(
ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
TileAllocator tileAllocator;
SetVector<LiveRange *> activeRanges;
+ SetVector<LiveRange *> inactiveRanges;
for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
- // Release tile IDs from live ranges that have ended.
activeRanges.remove_if([&](LiveRange *activeRange) {
+ // Check for live ranges that have expired.
if (activeRange->end() <= nextRange->start()) {
tileAllocator.releaseTileId(activeRange->getTileType(),
*activeRange->tileId);
return true;
}
+ // Check for live ranges that have become inactive.
+ if (!activeRange->overlaps(nextRange->start())) {
+ tileAllocator.releaseTileId(activeRange->getTileType(),
+ *activeRange->tileId);
+ inactiveRanges.insert(activeRange);
+ return true;
+ }
+ return false;
+ });
+ inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
+ // Check for live ranges that have expired.
+ if (inactiveRange->end() <= nextRange->start()) {
+ return true;
+ }
+ // Check for live ranges that have become active.
+ if (inactiveRange->overlaps(nextRange->start())) {
+ tileAllocator.acquireTileId(inactiveRange->getTileType(),
+ *inactiveRange->tileId);
+ activeRanges.insert(inactiveRange);
+ return true;
+ }
return false;
});
+ // Collect inactive live ranges that overlap with the current new live
+ // range. We need to acquire the tile IDs of overlapping inactive ranges to
+ // prevent two (overlapping) live ranges from getting the same tile ID.
+ SmallVector<LiveRange *> overlappingInactiveRanges;
+ for (LiveRange *inactiveRange : inactiveRanges) {
+ if (inactiveRange->overlaps(*nextRange)) {
+ tileAllocator.acquireTileId(inactiveRange->getTileType(),
+ *inactiveRange->tileId);
+ overlappingInactiveRanges.push_back(inactiveRange);
+ }
+ }
+
// Allocate a tile ID to `nextRange`.
auto rangeTileType = nextRange->getTileType();
auto tileId = tileAllocator.allocateTileId(rangeTileType);
if (succeeded(tileId)) {
nextRange->tileId = *tileId;
} else {
- LiveRange *rangeToSpill =
- chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange);
+ LiveRange *rangeToSpill = chooseSpillUsingHeuristics(
+ activeRanges.getArrayRef(), overlappingInactiveRanges, nextRange);
if (rangeToSpill != nextRange) {
- // Spill an active live range (so release its tile ID first).
+ // Spill an (in)active live range (so release its tile ID first).
tileAllocator.releaseTileId(rangeToSpill->getTileType(),
*rangeToSpill->tileId);
- activeRanges.remove(rangeToSpill);
// This will always succeed after a spill (of an active live range).
nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
+ // Remove the live range from the active/inactive sets.
+ if (!activeRanges.remove(rangeToSpill)) {
+ bool removed = inactiveRanges.remove(rangeToSpill);
+ assert(removed && "expected a range to be removed!");
+ }
}
rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
}
@@ -558,6 +615,12 @@ void allocateTilesToLiveRanges(
// Insert the live range into the active ranges.
if (nextRange->tileId < kInMemoryTileIdBase)
activeRanges.insert(nextRange);
+
+ // Release tiles reserved for inactive live ranges.
+ for (LiveRange *range : overlappingInactiveRanges) {
+ if (*range->tileId < kInMemoryTileIdBase)
+ tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
+ }
}
}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 9c22b29ac22e7..59afa654778e5 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -430,3 +430,133 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
// Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
return
}
+
+// -----
+
+// CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE: S arm_sme.get_tile
+// CHECK-LIVE-RANGE: E cf.cond_br
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE: S arm_sme.get_tile
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: E test.some_use
+// CHECK-LIVE-RANGE: cf.br
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: E test.some_use
+// CHECK-LIVE-RANGE: cf.br
+
+// Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
+// can reuse the tile ID (0) assigned to %tileA.
+
+// CHECK-LABEL: @fill_holes_in_tile_liveness
+func.func @fill_holes_in_tile_liveness(%cond: i1) {
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+ cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+ "test.dummy"(): () -> ()
+ "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
+ cf.br ^bb3
+^bb2:
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+ cf.br ^bb3
+^bb3:
+ return
+}
+
+// -----
+
+// CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE: S arm_sme.get_tile
+// CHECK-LIVE-RANGE: E cf.cond_br
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE: S arm_sme.get_tile
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: | test.some_use
+// CHECK-LIVE-RANGE: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE: E cf.br
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: | test.dummy
+// CHECK-LIVE-RANGE: |S arm_sme.get_tile
+// CHECK-LIVE-RANGE: E| test.some_use
+// CHECK-LIVE-RANGE: | arm_sme.copy_tile
+// CHECK-LIVE-RANGE: E cf.br
+// CHECK-LIVE-RANGE: ^bb3:
+// CHECK-LIVE-RANGE: E test.some_use
+// CHECK-LIVE-RANGE: func.return
+
+// This tests an edge case in inactive live ranges. The first live range is
+// inactive at the start of ^bb1. If the tile allocator did not check if the
+// second live range overlapped the first it would wrongly re-use tile ID 0
+// (as the first live range is inactive so tile ID 0 is free). This would mean
+// in ^bb2 two overlapping live ranges would have the same tile ID (bad!).
+
+// CHECK-LABEL: @holes_in_tile_liveness_inactive_overlaps
+func.func @holes_in_tile_liveness_inactive_overlaps(%cond: i1) {
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+ cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+ // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+ %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+ "test.dummy"(): () -> ()
+ "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
+ cf.br ^bb3(%tileB: vector<[4]x[4]xf32>)
+^bb2:
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+ %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+ "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+ cf.br ^bb3(%tileC: vector<[4]x[4]xf32>)
+^bb3(%tile: vector<[4]x[4]xf32>):
+ "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
+
+// This is the same as the previous example, but changes the tile types to
+// vector<[16]x[16]xi8>. This means in bb1 the allocator will need to spill the
+// first live range (which is inactive).
+
+// Note: The live ranges are the same as the previous example (so are not checked).
+
+// CHECK-LABEL: @spill_inactive_live_range
+func.func @spill_inactive_live_range(%cond: i1) {
+ // CHECK: arm_sme.get_tile {tile_id = 16 : i32}
+ %tileA = arm_sme.get_tile : vector<[16]x[16]xi8>
+ cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ %tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
+ "test.dummy"(): () -> ()
+ "test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
+ cf.br ^bb3(%tileB: vector<[16]x[16]xi8>)
+^bb2:
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ "test.dummy"(): () -> ()
+ // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+ %tileC = arm_sme.get_tile : vector<[16]x[16]xi8>
+ "test.some_use"(%tileA) : (vector<[16]x[16]xi8>) -> ()
+ cf.br ^bb3(%tileC: vector<[16]x[16]xi8>)
+^bb3(%tile: vector<[16]x[16]xi8>):
+ "test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
+ return
+}
More information about the Mlir-commits
mailing list