[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

Andrzej Warzyński llvmlistbot at llvm.org
Tue May 7 08:29:50 PDT 2024


================
@@ -137,172 +129,551 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
-                                          TileMask &tilesInUse) {
-  auto masks = getMasks(tileType);
-  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
-    if ((tilesInUse & tileMask) == TileMask::kNone) {
-      tilesInUse |= tileMask;
-      return tileId;
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
     }
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) != TileMask::kNone &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
+  }
+
+  /// Allocates an in-memory tile ID.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
   }
-  return failure();
-}
 
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
-                             SetVector<Operation *> &dependantOps) {
-  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
-    for (auto [idx, value] : llvm::enumerate(inputValues)) {
-      if (value == rootValue)
-        findDependantOps(exitValues[idx], dependantOps);
+private:
+  TileMask tilesInUse = TileMask::kNone;
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+///  ^bb1_copy:
+///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ^bb2_copy:
+///    cf.br ^bb2
+///  ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
     }
+  });
+
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
   };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    rewriter.modifyOpInPlace(condBranch, [&] {
+      condBranch.getFalseDestOperandsMutable().clear();
+      condBranch.getTrueDestOperandsMutable().clear();
+      condBranch.setSuccessor(newTrueBranch, 0);
+      condBranch.setSuccessor(newFalseBranch, 1);
+    });
+  }
+}
+
+/// Splits conditional branches (see `splitCondBranches`), then inserts tile
+/// copies at `cf.br` operations.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+///  ```
----------------
banach-space wrote:

This example is confusing - it's referring to everything excluding `splitCondBranches`. It will be easier to document if you split things:
```cpp
// Leave as is
void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {}
// Only insert copies at `cf.br`
void insertCopiesAtBranches(IRRewriter &rewriter, FunctionOpInterface function) {}

// New function!
/// Prepares IR for liveness analysis:
///  * replaces cf.cond_br with cf.br (to avoid "phantom" live ranges and to simplify live range calculation)
///  * inserts copies at cf.br (to help create non-overlapping live ranges)
///
/// Note that this method operates at the CF level. Also, it doesn't get rid of conditional branches.
/// Instead, it makes sure that there are no conditional branches taking SME tiles as block arguments.
/// To this end, "artificial" branch Ops are inserted.
void preProcessForLivenessAnalysis(IRRewriter &rewriter, FunctionOpInterface function) {
  splitCondBranches(rewriter, function);
  insertCopiesAtBranches(rewriter, function);
}
```

https://github.com/llvm/llvm-project/pull/90448


More information about the Mlir-commits mailing list