[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri May 3 09:07:54 PDT 2024
================
@@ -137,172 +129,508 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
}
}
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
- TileMask &tilesInUse) {
- auto masks = getMasks(tileType);
- for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
- if ((tilesInUse & tileMask) == TileMask::kNone) {
- tilesInUse |= tileMask;
- return tileId;
+class TileAllocator {
+public:
+ /// Allocates and returns a tile ID. Fails if there are no tiles left.
+ FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+ if ((tilesInUse & tileMask) == TileMask::kNone) {
+ tilesInUse |= tileMask;
+ return tileId;
+ }
}
+ return failure();
+ }
+
+ /// Releases a previously allocated tile ID.
+ void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+ TileMask tileMask = getMasks(tileType)[tileId];
+ assert((tilesInUse & tileMask) != TileMask::kNone &&
+ "cannot release unallocated tile!");
+ tilesInUse ^= tileMask;
+ }
+
+ /// Allocates an in-memory tile ID.
+ unsigned allocateInMemoryTileId() {
+ // Note: We never release in-memory tile IDs. We could, which may allow
+ // reusing an allocation, but as we _never_ want to spill an SME tile this
+ // is not optimized.
+ return nextInMemoryTileId++;
}
- return failure();
-}
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
- SetVector<Operation *> &dependantOps) {
- auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
- for (auto [idx, value] : llvm::enumerate(inputValues)) {
- if (value == rootValue)
- findDependantOps(exitValues[idx], dependantOps);
+private:
+ TileMask tilesInUse = TileMask::kNone;
+ unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+// Add new intermediate blocks for the true and false destinations of a
+// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
+// branches.
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+ SmallVector<cf::CondBranchOp> worklist;
+ function.walk([&](cf::CondBranchOp condBranch) {
+ if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+ return isValidSMETileVectorType(value.getType());
+ })) {
+ worklist.push_back(condBranch);
}
+ });
+
+ auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+ rewriter.setInsertionPointToEnd(source);
+ rewriter.create<cf::BranchOp>(loc, dest, args);
};
- for (Operation *user : rootValue.getUsers()) {
- if (dependantOps.contains(user))
+
+ for (auto condBranch : worklist) {
+ auto loc = condBranch.getLoc();
+ Block *block = condBranch->getBlock();
+ auto newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+ condBranch.getTrueDestOperands());
+ insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+ condBranch.getFalseDestOperands());
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+void insertCopiesAtBranches(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ splitCondBranches(rewriter, function);
+ for (Block &block : function.getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ if (!isa<cf::BranchOp>(terminator))
continue;
- dependantOps.insert(user);
- TypeSwitch<Operation *>(user)
- .Case<cf::BranchOp>([&](auto branchOp) {
- // (CF) Follow branch.
- traverseCorrespondingValues(branchOp.getDestOperands(),
- branchOp.getDest()->getArguments());
- })
- .Case<cf::CondBranchOp>([&](auto condBranchOp) {
- // (CF) Follow true branch.
- traverseCorrespondingValues(
- condBranchOp.getTrueOperands(),
- condBranchOp.getTrueDest()->getArguments());
- // (CF) Follow false branch.
- traverseCorrespondingValues(
- condBranchOp.getFalseOperands(),
- condBranchOp.getFalseDest()->getArguments());
- })
- .Case<LoopLikeOpInterface>([&](auto loopOp) {
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
- traverseCorrespondingValues(loopOp.getInits(),
- loopOp.getRegionIterArgs());
- })
- .Case<scf::YieldOp>([&](auto yieldOp) {
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
- auto parent = user->getParentOp();
- traverseCorrespondingValues(user->getOperands(),
- parent->getResults());
+ rewriter.setInsertionPoint(terminator);
+ for (OpOperand &operand : terminator->getOpOperands()) {
+ if (isValidSMETileVectorType(operand.get().getType())) {
+ auto copy =
+ rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ operand.assign(copy);
+ }
+ }
+ }
+}
+
+/// A range where a tile value is live. The range may contain holes.
+struct LiveRange {
+ using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+ llvm::IntervalMapHalfOpenInfo<unsigned>>;
+ using Allocator = RangeSet::Allocator;
+ static constexpr uint8_t kValidLiveRange = 0xff;
+
+ LiveRange(Allocator &allocator)
+ : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+ /// Returns true if this range overlaps with `otherRange`.
+ bool overlaps(LiveRange const &otherRange) const {
+ return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+ *otherRange.ranges)
+ .valid();
+ }
+
+ /// Unions this live range with `otherRange`; aborts if the ranges overlap.
+ void unionWith(LiveRange const &otherRange) {
+ for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+ ++it)
+ ranges->insert(it.start(), it.stop(), kValidLiveRange);
+ values.set_union(otherRange.values);
+ }
+
+ /// Inserts an interval [start, end) for `value` into this range.
+ void insert(Value value, unsigned start, unsigned end) {
+ values.insert(value);
+ if (start != end)
+ ranges->insert(start, end, kValidLiveRange);
+ }
+
+ bool empty() const { return ranges->empty(); }
+ unsigned start() const { return ranges->start(); }
+ unsigned end() const { return ranges->stop(); }
+ bool operator<(LiveRange const &other) const {
+ return start() < other.start();
+ }
+
+ ArmSMETileType getTileType() const {
+ return *getSMETileType(cast<VectorType>(values[0].getType()));
+ }
+
+ std::unique_ptr<RangeSet> ranges;
+ SetVector<Value> values;
+ std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+ unsigned index = 0;
+ SetVector<Block *> blocks =
+ getTopologicallySortedBlocks(function.getFunctionBody());
+ DenseMap<Operation *, unsigned> operationToIndexMap;
+ for (Block *block : blocks) {
+ index++; // We want block args to have their own number.
+ for (Operation &op : block->getOperations()) {
+ // This is only correct if no ArmSME op is nested in a region (i.e. after lowering to CF).
+#ifndef NDEBUG
+ op.walk([&](ArmSMETileOpInterface nestedOp) {
+ assert(&op == nestedOp.getOperation() &&
+ "ArmSME tile allocation does not support nested regions");
+ });
+#endif
+ operationToIndexMap.try_emplace(&op, index++);
+ }
+ }
+ return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
----------------
MacDue wrote:
This is documented on the `LiveRange` class too now
https://github.com/llvm/llvm-project/pull/90448
More information about the Mlir-commits
mailing list