[Mlir-commits] [mlir] [mlir][ArmSME] Support filling liveness 'holes' in the tile allocator (PR #98350)

Benjamin Maxwell llvmlistbot at llvm.org
Mon Jul 15 03:47:21 PDT 2024


https://github.com/MacDue updated https://github.com/llvm/llvm-project/pull/98350

>From d4a0f1d476a018e4bd301659b8cfa0438232ddfa Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 9 Jul 2024 16:22:40 +0000
Subject: [PATCH 1/3] [mlir][ArmSME] Support filling liveness 'holes' in the
 tile allocator

Holes in a live range are points where the corresponding value does not
need to be in a tile/register. If the tile allocator keeps track of
these holes it can reuse tiles for more values (avoiding spills).

Take this simple example:

```mlir
func.func @example(%cond: i1) {
  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
  cf.cond_br %cond, ^bb2, ^bb1
^bb1:
  // If we end up here, we never use %tileA again!
  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb2:
  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
  cf.br ^bb3
^bb3:
  return
}
```

If you were to calculate the liveness of %tileA and %tileB, you'd see
there is a hole in the liveness of %tileA in bb1:

```
      %tileA  %tileB
^bb0:  Live
^bb1:          Live
^bb2:  Live
```

The tile allocator can make use of that hole and reuse the tile ID it
assigned to %tileA for %tileB.
---
 .../ArmSME/Transforms/TileAllocation.cpp      | 113 +++++++++++----
 .../ArmSME/tile-allocation-liveness.mlir      | 130 ++++++++++++++++++
 2 files changed, 218 insertions(+), 25 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 733e758b43907..6023871c5affe 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -153,10 +153,18 @@ class TileAllocator {
     return failure();
   }
 
+  /// Acquires a specific tile ID. Asserts the tile is initially free.
+  void acquireTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) == TileMask::kNone &&
+           "cannot acquire allocated tile!");
+    tilesInUse |= tileMask;
+  }
+
   /// Releases a previously allocated tile ID.
   void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
     TileMask tileMask = getMasks(tileType)[tileId];
-    assert((tilesInUse & tileMask) != TileMask::kNone &&
+    assert((tilesInUse & tileMask) == tileMask &&
            "cannot release unallocated tile!");
     tilesInUse ^= tileMask;
   }
@@ -289,6 +297,11 @@ struct LiveRange {
         .valid();
   }
 
+  /// Returns true if this range overlaps with `point`.
+  bool overlaps(uint64_t point) const {
+    return ranges->lookup(point) == kValidLiveRange;
+  }
+
   /// Unions this live range with `otherRange`, aborts if the ranges overlap.
   void unionWith(LiveRange const &otherRange) {
     for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
@@ -488,69 +501,113 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
   return std::move(coalescedLiveRanges);
 }
 
-/// Choose a live range to spill (via some heuristics). This picks either an
-/// active live range from `activeRanges` or the new live range `newRange`.
+/// Choose a live range to spill (via some heuristics). This picks either a live
+/// range from `activeRanges`, `inactiveRanges`, or the new live range
+/// `newRange`. Note: All live ranges in `activeRanges` and `inactiveRanges` are
+/// assumed to overlap with `newRange`.
 LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
+                                      ArrayRef<LiveRange *> inactiveRanges,
                                       LiveRange *newRange) {
+  auto allOverlappingRanges =
+      llvm::concat<LiveRange>(llvm::make_pointee_range(activeRanges),
+                              llvm::make_pointee_range(inactiveRanges));
+
   // Heuristic: Spill trivially copyable operations (usually free).
-  auto isTrivialSpill = [&](LiveRange *allocatedRange) {
-    return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
+  auto isTrivialSpill = [&](LiveRange &allocatedRange) {
+    return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
                                     newRange->getTileType()) &&
-           allocatedRange->values.size() == 1 &&
+           allocatedRange.values.size() == 1 &&
            isTriviallyCloneableTileOp(
-               allocatedRange->values[0]
-                   .getDefiningOp<ArmSMETileOpInterface>());
+               allocatedRange.values[0].getDefiningOp<ArmSMETileOpInterface>());
   };
-  if (isTrivialSpill(newRange))
+  if (isTrivialSpill(*newRange))
     return newRange;
-  auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
-  if (trivialSpill != activeRanges.end())
-    return *trivialSpill;
+  auto trivialSpill = llvm::find_if(allOverlappingRanges, isTrivialSpill);
+  if (trivialSpill != allOverlappingRanges.end())
+    return &*trivialSpill;
 
   // Heuristic: Spill the range that ends last (with a compatible tile type).
-  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
-    return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
-           a->end() < b->end();
+  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange &a, LiveRange &b) {
+    return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
+           a.end() < b.end();
   };
-  LiveRange *lastActiveLiveRange = *std::max_element(
-      activeRanges.begin(), activeRanges.end(), isSmallerTileTypeOrEndsEarlier);
-  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
-    return lastActiveLiveRange;
+  LiveRange &lastActiveLiveRange = *std::max_element(
+      allOverlappingRanges.begin(), allOverlappingRanges.end(),
+      isSmallerTileTypeOrEndsEarlier);
+  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, *newRange))
+    return &lastActiveLiveRange;
   return newRange;
 }
 
 /// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
-/// Note: This does not attempt to fill holes in active live ranges.
 void allocateTilesToLiveRanges(
     ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
   TileAllocator tileAllocator;
   SetVector<LiveRange *> activeRanges;
+  SetVector<LiveRange *> inactiveRanges;
   for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
-    // Release tile IDs from live ranges that have ended.
     activeRanges.remove_if([&](LiveRange *activeRange) {
+      // Check for live ranges that have expired.
       if (activeRange->end() <= nextRange->start()) {
         tileAllocator.releaseTileId(activeRange->getTileType(),
                                     *activeRange->tileId);
         return true;
       }
+      // Check for live ranges that have become inactive.
+      if (!activeRange->overlaps(nextRange->start())) {
+        tileAllocator.releaseTileId(activeRange->getTileType(),
+                                    *activeRange->tileId);
+        inactiveRanges.insert(activeRange);
+        return true;
+      }
+      return false;
+    });
+    inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
+      // Check for live ranges that have expired.
+      if (inactiveRange->end() <= nextRange->start()) {
+        return true;
+      }
+      // Check for live ranges that have become active.
+      if (inactiveRange->overlaps(nextRange->start())) {
+        tileAllocator.acquireTileId(inactiveRange->getTileType(),
+                                    *inactiveRange->tileId);
+        activeRanges.insert(inactiveRange);
+        return true;
+      }
       return false;
     });
 
+    // Collect inactive live ranges that overlap with the current new live
+    // range. We need to acquire the tile IDs of overlapping inactive ranges to
+    // prevent two (overlapping) live ranges from getting the same tile ID.
+    SmallVector<LiveRange *> overlappingInactiveRanges;
+    for (LiveRange *inactiveRange : inactiveRanges) {
+      if (inactiveRange->overlaps(*nextRange)) {
+        tileAllocator.acquireTileId(inactiveRange->getTileType(),
+                                    *inactiveRange->tileId);
+        overlappingInactiveRanges.push_back(inactiveRange);
+      }
+    }
+
     // Allocate a tile ID to `nextRange`.
     auto rangeTileType = nextRange->getTileType();
     auto tileId = tileAllocator.allocateTileId(rangeTileType);
     if (succeeded(tileId)) {
       nextRange->tileId = *tileId;
     } else {
-      LiveRange *rangeToSpill =
-          chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange);
+      LiveRange *rangeToSpill = chooseSpillUsingHeuristics(
+          activeRanges.getArrayRef(), overlappingInactiveRanges, nextRange);
       if (rangeToSpill != nextRange) {
-        // Spill an active live range (so release its tile ID first).
+        // Spill an (in)active live range (so release its tile ID first).
         tileAllocator.releaseTileId(rangeToSpill->getTileType(),
                                     *rangeToSpill->tileId);
-        activeRanges.remove(rangeToSpill);
         // This will always succeed after a spill (of an active live range).
         nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType);
+        // Remove the live range from the active/inactive sets.
+        if (!activeRanges.remove(rangeToSpill)) {
+          bool removed = inactiveRanges.remove(rangeToSpill);
+          assert(removed && "expected a range to be removed!");
+        }
       }
       rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId();
     }
@@ -558,6 +615,12 @@ void allocateTilesToLiveRanges(
     // Insert the live range into the active ranges.
     if (nextRange->tileId < kInMemoryTileIdBase)
       activeRanges.insert(nextRange);
+
+    // Release tiles reserved for inactive live ranges.
+    for (LiveRange *range : overlappingInactiveRanges) {
+      if (*range->tileId < kInMemoryTileIdBase)
+        tileAllocator.releaseTileId(range->getTileType(), *range->tileId);
+    }
   }
 }
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 9c22b29ac22e7..59afa654778e5 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -430,3 +430,133 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
   // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
   return
 }
+
+// -----
+
+//  CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//        CHECK-LIVE-RANGE: S  arm_sme.get_tile
+//        CHECK-LIVE-RANGE: E  cf.cond_br
+//        CHECK-LIVE-RANGE: ^bb1:
+//        CHECK-LIVE-RANGE:  S arm_sme.get_tile
+//        CHECK-LIVE-RANGE:  | test.dummy
+//        CHECK-LIVE-RANGE:  E test.some_use
+//        CHECK-LIVE-RANGE:    cf.br
+//        CHECK-LIVE-RANGE: ^bb2:
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: E  test.some_use
+//        CHECK-LIVE-RANGE:    cf.br
+
+// Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
+// can reuse the tile ID (0) assigned to %tileA.
+
+// CHECK-LABEL: @fill_holes_in_tile_liveness
+func.func @fill_holes_in_tile_liveness(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3
+^bb2:
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3
+^bb3:
+  return
+}
+
+// -----
+
+//  CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//        CHECK-LIVE-RANGE: S  arm_sme.get_tile
+//        CHECK-LIVE-RANGE: E  cf.cond_br
+//        CHECK-LIVE-RANGE: ^bb1:
+//        CHECK-LIVE-RANGE:  S arm_sme.get_tile
+//        CHECK-LIVE-RANGE:  | test.dummy
+//        CHECK-LIVE-RANGE:  | test.some_use
+//        CHECK-LIVE-RANGE:  | arm_sme.copy_tile
+//        CHECK-LIVE-RANGE:  E cf.br
+//        CHECK-LIVE-RANGE: ^bb2:
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |  test.dummy
+//        CHECK-LIVE-RANGE: |S arm_sme.get_tile
+//        CHECK-LIVE-RANGE: E| test.some_use
+//        CHECK-LIVE-RANGE:  | arm_sme.copy_tile
+//        CHECK-LIVE-RANGE:  E cf.br
+//        CHECK-LIVE-RANGE: ^bb3:
+//        CHECK-LIVE-RANGE:  E test.some_use
+//        CHECK-LIVE-RANGE:    func.return
+
+// This tests an edge case in inactive live ranges. The first live range is
+// inactive at the start of ^bb1. If the tile allocator did not check if the
+// second live range overlapped the first it would wrongly re-use tile ID 0
+// (as the first live range is inactive so tile ID 0 is free). This would mean
+// in ^bb2 two overlapping live ranges would have the same tile ID (bad!).
+
+// CHECK-LABEL: @holes_in_tile_liveness_inactive_overlaps
+func.func @holes_in_tile_liveness_inactive_overlaps(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3(%tileB: vector<[4]x[4]xf32>)
+^bb2:
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3(%tileC: vector<[4]x[4]xf32>)
+^bb3(%tile: vector<[4]x[4]xf32>):
+  "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
+  return
+}
+
+// -----
+
+// This is the same as the previous example, but changes the tile types to
+// vector<[16]x[16]xi8>. This means in bb1 the allocator will need to spill the
+// first live range (which is inactive).
+
+// Note: The live ranges are the same as the previous example (so are not checked).
+
+// CHECK-LABEL: @spill_inactive_live_range
+func.func @spill_inactive_live_range(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 16 : i32}
+  %tileA = arm_sme.get_tile : vector<[16]x[16]xi8>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
+  cf.br ^bb3(%tileB: vector<[16]x[16]xi8>)
+^bb2:
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileC = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.some_use"(%tileA) : (vector<[16]x[16]xi8>) -> ()
+  cf.br ^bb3(%tileC: vector<[16]x[16]xi8>)
+^bb3(%tile: vector<[16]x[16]xi8>):
+  "test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
+  return
+}

>From 55fe7ebe8165a4d43ae91fa24195ad626f534ab4 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 15 Jul 2024 10:32:17 +0000
Subject: [PATCH 2/3] Fixups

---
 .../ArmSME/Transforms/TileAllocation.cpp      |  31 +++--
 .../ArmSME/tile-allocation-liveness.mlir      | 118 ++++++++++++------
 2 files changed, 103 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 6023871c5affe..57ef4eecfb3d9 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -297,7 +297,7 @@ struct LiveRange {
         .valid();
   }
 
-  /// Returns true if this range overlaps with `point`.
+  /// Returns true if this range is active at `point` in the program.
   bool overlaps(uint64_t point) const {
     return ranges->lookup(point) == kValidLiveRange;
   }
@@ -531,11 +531,11 @@ LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
     return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
            a.end() < b.end();
   };
-  LiveRange &lastActiveLiveRange = *std::max_element(
+  LiveRange &latestEndingLiveRange = *std::max_element(
       allOverlappingRanges.begin(), allOverlappingRanges.end(),
       isSmallerTileTypeOrEndsEarlier);
-  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, *newRange))
-    return &lastActiveLiveRange;
+  if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
+    return &latestEndingLiveRange;
   return newRange;
 }
 
@@ -543,17 +543,25 @@ LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
 void allocateTilesToLiveRanges(
     ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
   TileAllocator tileAllocator;
+  // `activeRanges` = Live ranges that need to be in a tile at the current point
+  // in the program.
   SetVector<LiveRange *> activeRanges;
+  // `inactiveRanges` = Live ranges that _do not_ need to be in a tile
+  // at the current point in the program but could become active again later.
+  // An inactive section of a live range can be seen as a 'hole' in the live
+  // range, where it is possible to re-use the live range's tile ID _before_
+  // it has ended. This allows tiles to be reused more (avoiding spills).
   SetVector<LiveRange *> inactiveRanges;
   for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
+    // Update the `activeRanges` at `nextRange->start()`.
     activeRanges.remove_if([&](LiveRange *activeRange) {
-      // Check for live ranges that have expired.
+      // 1. Check for live ranges that have expired.
       if (activeRange->end() <= nextRange->start()) {
         tileAllocator.releaseTileId(activeRange->getTileType(),
                                     *activeRange->tileId);
         return true;
       }
-      // Check for live ranges that have become inactive.
+      // 2. Check for live ranges that have become inactive.
       if (!activeRange->overlaps(nextRange->start())) {
         tileAllocator.releaseTileId(activeRange->getTileType(),
                                     *activeRange->tileId);
@@ -562,12 +570,13 @@ void allocateTilesToLiveRanges(
       }
       return false;
     });
+    // Update the `inactiveRanges` at `nextRange->start()`.
     inactiveRanges.remove_if([&](LiveRange *inactiveRange) {
-      // Check for live ranges that have expired.
+      // 1. Check for live ranges that have expired.
       if (inactiveRange->end() <= nextRange->start()) {
         return true;
       }
-      // Check for live ranges that have become active.
+      // 2. Check for live ranges that have become active.
       if (inactiveRange->overlaps(nextRange->start())) {
         tileAllocator.acquireTileId(inactiveRange->getTileType(),
                                     *inactiveRange->tileId);
@@ -577,9 +586,9 @@ void allocateTilesToLiveRanges(
       return false;
     });
 
-    // Collect inactive live ranges that overlap with the current new live
-    // range. We need to acquire the tile IDs of overlapping inactive ranges to
-    // prevent two (overlapping) live ranges from getting the same tile ID.
+    // Collect inactive live ranges that overlap with the new live range. We
+    // need to reserve the tile IDs of overlapping inactive ranges to prevent
+    // two (overlapping) live ranges from getting the same tile ID.
     SmallVector<LiveRange *> overlappingInactiveRanges;
     for (LiveRange *inactiveRange : inactiveRanges) {
       if (inactiveRange->overlaps(*nextRange)) {
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 59afa654778e5..2e1f3d1ee10a9 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -436,30 +436,32 @@ func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
 //  CHECK-LIVE-RANGE-LABEL: @fill_holes_in_tile_liveness
 //        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
 //        CHECK-LIVE-RANGE: ^bb0:
-//        CHECK-LIVE-RANGE: S  arm_sme.get_tile
-//        CHECK-LIVE-RANGE: E  cf.cond_br
-//        CHECK-LIVE-RANGE: ^bb1:
-//        CHECK-LIVE-RANGE:  S arm_sme.get_tile
-//        CHECK-LIVE-RANGE:  | test.dummy
-//        CHECK-LIVE-RANGE:  E test.some_use
-//        CHECK-LIVE-RANGE:    cf.br
-//        CHECK-LIVE-RANGE: ^bb2:
-//        CHECK-LIVE-RANGE: |  test.dummy
-//        CHECK-LIVE-RANGE: |  test.dummy
-//        CHECK-LIVE-RANGE: |  test.dummy
-//        CHECK-LIVE-RANGE: E  test.some_use
-//        CHECK-LIVE-RANGE:    cf.br
+//   CHECK-LIVE-RANGE-NEXT: S  arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: E  cf.cond_br
+//   CHECK-LIVE-RANGE-NEXT: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT:  S arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT:  | test.dummy
+//   CHECK-LIVE-RANGE-NEXT:  E test.some_use
+//   CHECK-LIVE-RANGE-NEXT:    cf.br
+//   CHECK-LIVE-RANGE-NEXT: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT: |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT: |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT: E  test.some_use
+//   CHECK-LIVE-RANGE-NEXT:    cf.br
 
 // Here there's a 'hole' in the liveness of %tileA (in bb1) where another value
-// can reuse the tile ID (0) assigned to %tileA.
+// can reuse the tile ID assigned to %tileA. The liveness for %tileB is
+// entirely within the 'hole' in %tileA's live range, so %tileB should get the
+// same tile ID as %tileA.
 
 // CHECK-LABEL: @fill_holes_in_tile_liveness
 func.func @fill_holes_in_tile_liveness(%cond: i1) {
-  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A:.*]] : i32}
   %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
   cf.cond_br %cond, ^bb2, ^bb1
 ^bb1:
-  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = [[TILE_ID_A]] : i32}
   %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
   "test.dummy"(): () -> ()
   "test.some_use"(%tileB) : (vector<[4]x[4]xf32>) -> ()
@@ -479,25 +481,25 @@ func.func @fill_holes_in_tile_liveness(%cond: i1) {
 //  CHECK-LIVE-RANGE-LABEL: @holes_in_tile_liveness_inactive_overlaps
 //        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
 //        CHECK-LIVE-RANGE: ^bb0:
-//        CHECK-LIVE-RANGE: S  arm_sme.get_tile
-//        CHECK-LIVE-RANGE: E  cf.cond_br
-//        CHECK-LIVE-RANGE: ^bb1:
-//        CHECK-LIVE-RANGE:  S arm_sme.get_tile
-//        CHECK-LIVE-RANGE:  | test.dummy
-//        CHECK-LIVE-RANGE:  | test.some_use
-//        CHECK-LIVE-RANGE:  | arm_sme.copy_tile
-//        CHECK-LIVE-RANGE:  E cf.br
-//        CHECK-LIVE-RANGE: ^bb2:
-//        CHECK-LIVE-RANGE: |  test.dummy
-//        CHECK-LIVE-RANGE: |  test.dummy
-//        CHECK-LIVE-RANGE: |  test.dummy
-//        CHECK-LIVE-RANGE: |S arm_sme.get_tile
-//        CHECK-LIVE-RANGE: E| test.some_use
-//        CHECK-LIVE-RANGE:  | arm_sme.copy_tile
-//        CHECK-LIVE-RANGE:  E cf.br
-//        CHECK-LIVE-RANGE: ^bb3:
-//        CHECK-LIVE-RANGE:  E test.some_use
-//        CHECK-LIVE-RANGE:    func.return
+//   CHECK-LIVE-RANGE-NEXT: S  arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: E  cf.cond_br
+//   CHECK-LIVE-RANGE-NEXT: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT:  S arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT:  | test.dummy
+//   CHECK-LIVE-RANGE-NEXT:  | test.some_use
+//   CHECK-LIVE-RANGE-NEXT:  | arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT:  E cf.br
+//   CHECK-LIVE-RANGE-NEXT: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT: |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT: |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: E| test.some_use
+//   CHECK-LIVE-RANGE-NEXT:  | arm_sme.copy_tile
+//   CHECK-LIVE-RANGE-NEXT:  E cf.br
+//   CHECK-LIVE-RANGE-NEXT: ^bb3:
+//   CHECK-LIVE-RANGE-NEXT:  E test.some_use
+//   CHECK-LIVE-RANGE-NEXT:    func.return
 
 // This tests an edge case in inactive live ranges. The first live range is
 // inactive at the start of ^bb1. If the tile allocator did not check if the
@@ -560,3 +562,49 @@ func.func @spill_inactive_live_range(%cond: i1) {
   "test.some_use"(%tile) : (vector<[16]x[16]xi8>) -> ()
   return
 }
+
+// -----
+
+//  CHECK-LIVE-RANGE-LABEL: @reactivate_inactive_live_range
+//        CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+//        CHECK-LIVE-RANGE: ^bb0:
+//   CHECK-LIVE-RANGE-NEXT: S   arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: E   cf.cond_br
+//   CHECK-LIVE-RANGE-NEXT: ^bb1:
+//   CHECK-LIVE-RANGE-NEXT:  S  arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT:  |  test.dummy
+//   CHECK-LIVE-RANGE-NEXT:  E  test.some_use
+//   CHECK-LIVE-RANGE-NEXT:     cf.br
+//   CHECK-LIVE-RANGE-NEXT: ^bb2:
+//   CHECK-LIVE-RANGE-NEXT: | S arm_sme.get_tile
+//   CHECK-LIVE-RANGE-NEXT: | | test.dummy
+//   CHECK-LIVE-RANGE-NEXT: | | test.dummy
+//   CHECK-LIVE-RANGE-NEXT: | E test.some_use
+//   CHECK-LIVE-RANGE-NEXT: E   test.some_use
+//   CHECK-LIVE-RANGE-NEXT:     cf.br
+
+// Here the live range for %tileA becomes inactive in bb1 (so %tileB gets tile
+// ID 0 too). Then in bb2 the live range for %tileA is reactivated, as it
+// overlaps with the start of %tileC's live range (so %tileC gets tile ID 1).
+// CHECK-LABEL: @reactivate_inactive_live_range
+func.func @reactivate_inactive_live_range(%cond: i1) {
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  cf.cond_br %cond, ^bb2, ^bb1
+^bb1:
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  %tileB = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileB) : (vector<[16]x[16]xi8>) -> ()
+  cf.br ^bb3
+^bb2:
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+  "test.dummy"(): () -> ()
+  "test.dummy"(): () -> ()
+  "test.some_use"(%tileC) : (vector<[4]x[4]xf32>) -> ()
+  "test.some_use"(%tileA) : (vector<[4]x[4]xf32>) -> ()
+  cf.br ^bb3
+^bb3:
+  return
+}

>From 9be39215a8182b03e2fce1e68b1576b90427f602 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Mon, 15 Jul 2024 10:41:47 +0000
Subject: [PATCH 3/3] Templates to the rescue

---
 .../ArmSME/Transforms/TileAllocation.cpp      | 34 +++++++++----------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 57ef4eecfb3d9..68085762bab27 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -502,16 +502,11 @@ coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
 }
 
 /// Choose a live range to spill (via some heuristics). This picks either a live
-/// range from `activeRanges`, `inactiveRanges`, or the new live range
-/// `newRange`. Note: All live ranges in `activeRanges` and `inactiveRanges` are
-/// assumed to overlap with `newRange`.
-LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
-                                      ArrayRef<LiveRange *> inactiveRanges,
-                                      LiveRange *newRange) {
-  auto allOverlappingRanges =
-      llvm::concat<LiveRange>(llvm::make_pointee_range(activeRanges),
-                              llvm::make_pointee_range(inactiveRanges));
-
+/// range from `overlappingRanges` (each overlaps `newRange`), or `newRange`.
+template <typename OverlappingRangesIterator>
+LiveRange *
+chooseSpillUsingHeuristics(OverlappingRangesIterator overlappingRanges,
+                           LiveRange *newRange) {
   // Heuristic: Spill trivially copyable operations (usually free).
   auto isTrivialSpill = [&](LiveRange &allocatedRange) {
     return isTileTypeGreaterOrEqual(allocatedRange.getTileType(),
@@ -522,8 +517,8 @@ LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
   };
   if (isTrivialSpill(*newRange))
     return newRange;
-  auto trivialSpill = llvm::find_if(allOverlappingRanges, isTrivialSpill);
-  if (trivialSpill != allOverlappingRanges.end())
+  auto trivialSpill = llvm::find_if(overlappingRanges, isTrivialSpill);
+  if (trivialSpill != overlappingRanges.end())
     return &*trivialSpill;
 
   // Heuristic: Spill the range that ends last (with a compatible tile type).
@@ -531,9 +526,9 @@ LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
     return !isTileTypeGreaterOrEqual(a.getTileType(), b.getTileType()) ||
            a.end() < b.end();
   };
-  LiveRange &latestEndingLiveRange = *std::max_element(
-      allOverlappingRanges.begin(), allOverlappingRanges.end(),
-      isSmallerTileTypeOrEndsEarlier);
+  LiveRange &latestEndingLiveRange =
+      *std::max_element(overlappingRanges.begin(), overlappingRanges.end(),
+                        isSmallerTileTypeOrEndsEarlier);
   if (!isSmallerTileTypeOrEndsEarlier(latestEndingLiveRange, *newRange))
     return &latestEndingLiveRange;
   return newRange;
@@ -604,8 +599,13 @@ void allocateTilesToLiveRanges(
     if (succeeded(tileId)) {
       nextRange->tileId = *tileId;
     } else {
-      LiveRange *rangeToSpill = chooseSpillUsingHeuristics(
-          activeRanges.getArrayRef(), overlappingInactiveRanges, nextRange);
+      // Create an iterator over all overlapping live ranges.
+      auto allOverlappingRanges = llvm::concat<LiveRange>(
+          llvm::make_pointee_range(activeRanges.getArrayRef()),
+          llvm::make_pointee_range(overlappingInactiveRanges));
+      // Choose an overlapping live range to spill.
+      LiveRange *rangeToSpill =
+          chooseSpillUsingHeuristics(allOverlappingRanges, nextRange);
       if (rangeToSpill != nextRange) {
         // Spill an (in)active live range (so release its tile ID first).
         tileAllocator.releaseTileId(rangeToSpill->getTileType(),



More information about the Mlir-commits mailing list