[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri May 10 01:24:10 PDT 2024
================
@@ -137,172 +138,612 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
}
}
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
- TileMask &tilesInUse) {
- auto masks = getMasks(tileType);
- for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
- if ((tilesInUse & tileMask) == TileMask::kNone) {
- tilesInUse |= tileMask;
- return tileId;
+class TileAllocator {
+public:
+ /// Allocates and returns a tile ID. Fails if there are no tiles left.
+ FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+ if ((tilesInUse & tileMask) == TileMask::kNone) {
+ tilesInUse |= tileMask;
+ return tileId;
+ }
}
+ return failure();
}
- return failure();
-}
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
- SetVector<Operation *> &dependantOps) {
- auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
- for (auto [idx, value] : llvm::enumerate(inputValues)) {
- if (value == rootValue)
- findDependantOps(exitValues[idx], dependantOps);
+ /// Releases a previously allocated tile ID.
+ void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+ TileMask tileMask = getMasks(tileType)[tileId];
+ assert((tilesInUse & tileMask) != TileMask::kNone &&
+ "cannot release unallocated tile!");
+ tilesInUse ^= tileMask;
+ }
+
+ /// Allocates an in-memory tile ID.
+ unsigned allocateInMemoryTileId() {
+ // Note: We never release in-memory tile IDs. We could, which might allow
+ // an in-memory allocation to be reused, but since we _never_ want to spill
+ // an SME tile, this case is not optimized.
+ return nextInMemoryTileId++;
+ }
+
+private:
+ TileMask tilesInUse = TileMask::kNone;
+ unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
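A minimal usage sketch (illustrative only, not part of this patch): callers would first try to grab a real ZA tile and only fall back to an in-memory tile ID once the requested tile type is exhausted. `ArmSMETileType::ZAS` is just an example input here.

```cpp
// Illustrative use of TileAllocator (assumes the definitions above are in
// scope).
TileAllocator allocator;
unsigned tileId;
if (auto allocated = allocator.allocateTileId(ArmSMETileType::ZAS);
    succeeded(allocated)) {
  // A free 32-bit tile existed; its mask is now marked as in use.
  tileId = *allocated;
} else {
  // All 32-bit tiles are taken: fall back to a (never reused) in-memory ID.
  tileId = allocator.allocateInMemoryTileId();
}
```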
+
+/// Adds new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+/// ^bb1_copy:
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ^bb2_copy:
+/// cf.br ^bb2
+/// ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+ SmallVector<cf::CondBranchOp> worklist;
+ function.walk([&](cf::CondBranchOp condBranch) {
+ if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+ return isValidSMETileVectorType(value.getType());
+ })) {
+ worklist.push_back(condBranch);
}
+ });
+
+ auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+ rewriter.setInsertionPointToEnd(source);
+ rewriter.create<cf::BranchOp>(loc, dest, args);
};
- for (Operation *user : rootValue.getUsers()) {
- if (dependantOps.contains(user))
+
+ for (auto condBranch : worklist) {
+ auto loc = condBranch.getLoc();
+ Block *block = condBranch->getBlock();
+ auto newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+ condBranch.getTrueDestOperands());
+ insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+ condBranch.getFalseDestOperands());
+ rewriter.modifyOpInPlace(condBranch, [&] {
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ });
+ }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
+/// ```
+void insertCopiesAtBranches(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ for (Block &block : function.getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ if (!isa<cf::BranchOp>(terminator))
continue;
- dependantOps.insert(user);
- TypeSwitch<Operation *>(user)
- .Case<cf::BranchOp>([&](auto branchOp) {
- // (CF) Follow branch.
- traverseCorrespondingValues(branchOp.getDestOperands(),
- branchOp.getDest()->getArguments());
- })
- .Case<cf::CondBranchOp>([&](auto condBranchOp) {
- // (CF) Follow true branch.
- traverseCorrespondingValues(
- condBranchOp.getTrueOperands(),
- condBranchOp.getTrueDest()->getArguments());
- // (CF) Follow false branch.
- traverseCorrespondingValues(
- condBranchOp.getFalseOperands(),
- condBranchOp.getFalseDest()->getArguments());
- })
- .Case<LoopLikeOpInterface>([&](auto loopOp) {
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
- traverseCorrespondingValues(loopOp.getInits(),
- loopOp.getRegionIterArgs());
- })
- .Case<scf::YieldOp>([&](auto yieldOp) {
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
- auto parent = user->getParentOp();
- traverseCorrespondingValues(user->getOperands(),
- parent->getResults());
+ rewriter.setInsertionPoint(terminator);
+ for (OpOperand &operand : terminator->getOpOperands()) {
+ if (isValidSMETileVectorType(operand.get().getType())) {
+ auto copy =
+ rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
+ }
+ }
+ }
+}
+
+/// Prepares the IR for tile allocation. It does this by first 'splitting'
+/// conditional branches (see `splitCondBranches`), then inserting tile copies
+/// at branch operations. The conditional branches are split to prevent the
+/// copies needed for them from overlapping between the true and false paths
+/// of the branch (see `tile-allocation-copies.mlir` and
+/// `tile-allocation-liveness.mlir` for examples). The copies break up live
+/// ranges and ensure that, when moving out of SSA, the semantics of the
+/// program are preserved.
+void preprocessForTileAllocation(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ splitCondBranches(rewriter, function);
+ insertCopiesAtBranches(rewriter, function);
+}
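For orientation, a sketch of how this preprocessing step might be invoked (hypothetical wiring; the actual pass body is outside this hunk):

```cpp
// Hypothetical driver: preprocess, then run liveness + tile assignment.
void allocateTiles(FunctionOpInterface function) {
  IRRewriter rewriter(function->getContext());
  preprocessForTileAllocation(rewriter, function);
  // Live range construction and tile ID assignment would follow here.
}
```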
+
+/// A live range for a (collection of) tile values. A live range is built up of
+/// intervals [start, end) which represent parts of the program where the value
+/// needs to be live (i.e. in an SME virtual tile).
+struct LiveRange {
+ using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+ llvm::IntervalMapHalfOpenInfo<unsigned>>;
+ using Allocator = RangeSet::Allocator;
+ static constexpr uint8_t kValidLiveRange = 0xff;
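The quoted hunk cuts off mid-struct here. As a rough sketch of how these pieces fit together (the member names below are assumptions, not quoted from the patch), the half-open `IntervalMap` intervals make both operations the allocator needs cheap:

```cpp
// Sketch only: plausible members for the truncated LiveRange struct.
LiveRange(Allocator &allocator) : ranges(allocator) {}

/// Marks the half-open interval [start, end) as live.
void insert(uint64_t start, uint64_t end) {
  ranges.insert(start, end, kValidLiveRange);
}

/// Returns true if this live range intersects `other` anywhere.
bool overlaps(LiveRange const &other) const {
  return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(ranges, other.ranges)
      .valid();
}

RangeSet ranges;
```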
----------------
banach-space wrote:
Makes sense, thanks! I've made a few suggestions for more comments ^^^.
https://github.com/llvm/llvm-project/pull/90448