[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri May 10 03:18:37 PDT 2024
================
@@ -137,172 +138,629 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
}
}
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
- TileMask &tilesInUse) {
- auto masks = getMasks(tileType);
- for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
- if ((tilesInUse & tileMask) == TileMask::kNone) {
- tilesInUse |= tileMask;
- return tileId;
+/// Tracks which SME virtual tiles are currently in use (as a `TileMask`) and
+/// hands out sequential in-memory tile IDs starting at `kInMemoryTileIdBase`.
+class TileAllocator {
+public:
+ /// Allocates and returns a tile ID. Fails if there are no tiles left.
+ ///
+ /// The candidate IDs for `tileType` are tried in order; the first one whose
+ /// mask has no bit in common with `tilesInUse` is claimed (its mask is OR-ed
+ /// into `tilesInUse`) and its index is returned.
+ FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+ // Only claim this tile if no bit of its mask is already in use.
+ if ((tilesInUse & tileMask) == TileMask::kNone) {
+ tilesInUse |= tileMask;
+ return tileId;
+ }
}
+ return failure();
+ }
+
+  /// Releases a previously allocated tile ID.
+  ///
+  /// Every bit of the mask for (`tileType`, `tileId`) must currently be in
+  /// use. The mask is cleared with XOR, so releasing a tile that is only
+  /// partially allocated would *set* its unallocated bits instead of clearing
+  /// them, silently corrupting `tilesInUse`.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    // Require the whole mask to be allocated, not merely some bit of it.
+    assert((tilesInUse & tileMask) == tileMask &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
}
- return failure();
-}
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
- SetVector<Operation *> &dependantOps) {
- auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
- for (auto [idx, value] : llvm::enumerate(inputValues)) {
- if (value == rootValue)
- findDependantOps(exitValues[idx], dependantOps);
+  /// Hands out a fresh in-memory tile ID. IDs are issued sequentially (from
+  /// the initial value of `nextInMemoryTileId`) and are never recycled.
+  unsigned allocateInMemoryTileId() {
+    // In-memory tile IDs are intentionally never released. Recycling them
+    // could let an allocation be reused, but spilling an SME tile is always
+    // something we want to avoid, so this path is not worth optimizing.
+    unsigned issuedId = nextInMemoryTileId;
+    ++nextInMemoryTileId;
+    return issuedId;
+  }
+
+private:
+ /// Bitmask of the tiles currently allocated by `allocateTileId` (and not yet
+ /// released by `releaseTileId`).
+ TileMask tilesInUse = TileMask::kNone;
+ /// The next ID returned by `allocateInMemoryTileId` (incremented per call).
+ unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+/// Each new block holds only a `cf.br` that forwards the original successor
+/// operands; the `cf.cond_br` is retargeted to the new blocks and is left with
+/// no successor operands of its own.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+/// ^bb1_copy:
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ^bb2_copy:
+/// cf.br ^bb2
+/// ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+ // Collect the branches first so that the IR is not modified during the walk.
+ SmallVector<cf::CondBranchOp> worklist;
+ function.walk([&](cf::CondBranchOp condBranch) {
+ if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+ return isValidSMETileVectorType(value.getType());
+ })) {
+ worklist.push_back(condBranch);
}
+ });
+
+ // Appends `cf.br dest(args)` to the end of `source`.
+ auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+ rewriter.setInsertionPointToEnd(source);
+ rewriter.create<cf::BranchOp>(loc, dest, args);
};
- for (Operation *user : rootValue.getUsers()) {
- if (dependantOps.contains(user))
+
+ for (auto condBranch : worklist) {
+ auto loc = condBranch.getLoc();
+ Block *block = condBranch->getBlock();
+ // Splitting at `block->end()` moves no operations: both new blocks start
+ // out empty.
+ auto newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+ condBranch.getTrueDestOperands());
+ insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+ condBranch.getFalseDestOperands());
+ // The successor operands are now carried by the new `cf.br`s, so drop them
+ // here and point the `cf.cond_br` at the new blocks.
+ rewriter.modifyOpInPlace(condBranch, [&] {
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ });
+ }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+///
+/// Only blocks terminated by a `cf.br` are rewritten (any other terminator,
+/// including `cf.cond_br`, is skipped). For each tile-typed operand of the
+/// branch, an `arm_sme.copy_tile` is created immediately before the branch and
+/// the branch is updated to forward the copy instead of the original value.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
+/// ```
+void insertCopiesAtBranches(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ for (Block &block : function.getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ if (!isa<cf::BranchOp>(terminator))
continue;
- dependantOps.insert(user);
- TypeSwitch<Operation *>(user)
- .Case<cf::BranchOp>([&](auto branchOp) {
- // (CF) Follow branch.
- traverseCorrespondingValues(branchOp.getDestOperands(),
- branchOp.getDest()->getArguments());
- })
- .Case<cf::CondBranchOp>([&](auto condBranchOp) {
- // (CF) Follow true branch.
- traverseCorrespondingValues(
- condBranchOp.getTrueOperands(),
- condBranchOp.getTrueDest()->getArguments());
- // (CF) Follow false branch.
- traverseCorrespondingValues(
- condBranchOp.getFalseOperands(),
- condBranchOp.getFalseDest()->getArguments());
- })
- .Case<LoopLikeOpInterface>([&](auto loopOp) {
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
- traverseCorrespondingValues(loopOp.getInits(),
- loopOp.getRegionIterArgs());
- })
- .Case<scf::YieldOp>([&](auto yieldOp) {
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
- auto parent = user->getParentOp();
- traverseCorrespondingValues(user->getOperands(),
- parent->getResults());
+ rewriter.setInsertionPoint(terminator);
+ for (OpOperand &operand : terminator->getOpOperands()) {
+ // Copy each tile operand just before the branch and forward the copy.
+ if (isValidSMETileVectorType(operand.get().getType())) {
+ auto copy =
+ rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
+ }
+ }
+ }
+}
+
+/// Prepares `function` for tile allocation in two steps:
+///
+///   1. Conditional branches with tile operands are split (see
+///      `splitCondBranches`). This stops the copies needed for the true and
+///      false paths of a branch from overlapping (see
+///      `tile-allocation-copies.mlir` and `tile-allocation-liveness.mlir` for
+///      examples).
+///   2. Tile copies are inserted at branch operations (see
+///      `insertCopiesAtBranches`). The copies break up live ranges and ensure
+///      that the semantics of the program are preserved when moving out of
+///      SSA.
+void preprocessForTileAllocation(IRRewriter &rewriter,
+                                 FunctionOpInterface function) {
+  // Step 1 has to run first: copies are only inserted at `cf.br`s, and the
+  // split creates a `cf.br` for every successor of a split `cf.cond_br`.
+  splitCondBranches(rewriter, function);
+  // Step 2: copy the tile operands of every `cf.br`.
+  insertCopiesAtBranches(rewriter, function);
+}
+
+/// A live range for a (collection of) tile values. A live range is built up of
+/// non-overlapping intervals [start, end) which represent parts of the program
+/// where a value in the range needs to be live (i.e. in an SME virtual tile).
+/// Note that as the intervals are non-overlapping all values within a live
+/// range can be allocated to the same SME virtual tile.
+///
+/// Interval endpoints are operation indices of type `unsigned`; the interval
+/// map uses `unsigned` keys so that its key type agrees with its half-open
+/// traits and with the `insert`/`start()`/`end()` signatures below.
+struct LiveRange {
+  using RangeSet = llvm::IntervalMap<unsigned, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  /// Placeholder value stored for every interval; only the interval bounds
+  /// carry information.
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange` (both its intervals and its
+  /// values). The ranges must not overlap, which is asserted: since every
+  /// interval maps to the same placeholder value, overlapping intervals would
+  /// otherwise be merged silently, breaking the invariant that all values in
+  /// a live range can share one tile.
+  void unionWith(LiveRange const &otherRange) {
+    assert(!overlaps(otherRange) && "cannot union overlapping live ranges");
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range. An empty
+  /// interval (start == end) only records `value` as a member of the range.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  /// First index covered by this range (the range must be non-empty).
+  unsigned start() const { return ranges->start(); }
+  /// One past the last index covered by this range (must be non-empty).
+  unsigned end() const { return ranges->stop(); }
+  /// Orders live ranges by their starting point (both must be non-empty).
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  /// Returns the SME tile type of the first value in this range.
+  ArmSMETileType getTileType() const {
+    assert(!values.empty() && "live range has no values");
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  /// The values contained in this live range.
+  SetVector<Value> values;
+
+  /// A set of (non-overlapping) intervals that mark where any value in `values`
+  /// is live.
+  std::unique_ptr<RangeSet> ranges;
+
+  /// The tile ID (or none) assigned to this live range.
+  std::optional<unsigned> tileId;
+};
+
+/// Assigns consecutive numbers to the operations of `function` so that live
+/// ranges can be expressed as intervals of operation indices. Blocks are
+/// visited in topological order (using forward edges), and the operations
+/// within a block are numbered in program order. Each block also consumes one
+/// number just before its first operation, so that block arguments have an
+/// index of their own (one less than the block's first operation).
+///
+/// Only operations directly inside the function's blocks are numbered, so
+/// this is only correct once all ArmSME ops have been converted to CF, i.e.
+/// no ArmSME tile op is nested in a region. This is asserted in builds with
+/// assertions enabled.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  DenseMap<Operation *, unsigned> numbering;
+  unsigned nextNumber = 0;
+  for (Block *block :
+       getTopologicallySortedBlocks(function.getFunctionBody())) {
+    // Reserve a number for the block itself (used by its arguments).
+    ++nextNumber;
+    for (Operation &op : *block) {
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
+      });
+#endif
+      numbering.try_emplace(&op, nextNumber);
+      ++nextNumber;
+    }
+  }
+  return numbering;
+}
+
+/// Builds a live range for every SME tile value in `function` from the MLIR
+/// liveness analysis. In each block, a tile value contributes the half-open
+/// interval from its definition (or from the block's own index, for block
+/// arguments and live-ins) up to the operation at which `liveness` reports its
+/// liveness in that block ending (`LivenessBlockInfo::getEndOperation`). All
+/// per-block intervals of a value accumulate into a single `LiveRange`.
+///
+/// `operationToIndexMap` must be the (non-empty) numbering produced by
+/// `generateOperationNumbering` for `function`.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  assert(!operationToIndexMap.empty() && "expected operation numbering");
+  DenseMap<Value, LiveRange> liveRanges;
+
+  // Records one interval for `value` (ignoring non-tile values), creating its
+  // live range on first sight. `startOp` is the defining operation, or the
+  // block's first operation when `value` is live on block entry; in that case
+  // the interval begins at the index reserved for the block, one before
+  // `startOp`.
+  auto recordInterval = [&](Value value, Operation *startOp,
+                            LivenessBlockInfo const &blockLiveness,
+                            bool liveOnEntry) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    LiveRange &range =
+        liveRanges.try_emplace(value, liveRangeAllocator).first->second;
+    Operation *endOp = blockLiveness.getEndOperation(value, startOp);
+    unsigned startIdx = operationToIndexMap.at(startOp);
+    if (liveOnEntry)
+      startIdx -= 1;
+    unsigned endIdx = operationToIndexMap.at(endOp);
+    range.insert(value, startIdx, endIdx);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const *blockLiveness = liveness.getLiveness(&block);
+    // Values live on entry to the block: first its arguments, then live-ins.
+    for (Value argument : block.getArguments())
+      recordInterval(argument, &block.front(), *blockLiveness,
+                     /*liveOnEntry=*/true);
+    for (Value liveIn : blockLiveness->in())
+      recordInterval(liveIn, &block.front(), *blockLiveness,
+                     /*liveOnEntry=*/true);
+    // Values defined by the block's operations.
+    for (Operation &op : block)
+      for (Value result : op.getResults())
+        recordInterval(result, &op, *blockLiveness, /*liveOnEntry=*/false);
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+///
+/// For every predecessor of the block owning `blockArg`, `callback` is invoked
+/// with the operand that the predecessor's terminator forwards to `blockArg`.
+/// For a `cf.cond_br` this is done for each of its destinations that is the
+/// owning block. Predecessors whose terminator is neither `cf.br` nor
+/// `cf.cond_br` are skipped (there is no default case).
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+ function_ref<void(Value)> callback) {
+ Block *block = blockArg.getOwner();
+ unsigned argNumber = blockArg.getArgNumber();
+ for (Block *pred : block->getPredecessors()) {
+ TypeSwitch<Operation *>(pred->getTerminator())
+ .Case<cf::BranchOp>([&](auto branch) {
+ Value predecessorOperand = branch.getDestOperands()[argNumber];
+ callback(predecessorOperand);
})
- .Default([&](auto) {
- // Otherwise, assume users of _any_ result are dependant.
- for (Value result : user->getResults())
- findDependantOps(result, dependantOps);
+ .Case<cf::CondBranchOp>([&](auto condBranch) {
+ if (condBranch.getFalseDest() == block) {
+ Value predecessorOperand =
+ condBranch.getFalseDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
+ if (condBranch.getTrueDest() == block) {
+ Value predecessorOperand =
+ condBranch.getTrueDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
});
}
}
-struct AssignTileIDsPattern
- : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
- PatternRewriter &rewriter) const override {
- if (tileOp.getTileId())
- return failure();
-
- auto func = tileOp->getParentOfType<FunctionOpInterface>();
- auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
- if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
- func->getDiscardableAttr(name)))
- return unsigned(attr.getInt());
- return defaultVal;
+
+/// Coalesces tile live ranges so that values which can share a tile (avoiding
+/// a tile move between them) end up in the same `LiveRange`. Two values are
+/// merged when their current live ranges do not overlap and either:
+///   * one is a block argument and the other is a value forwarded to it by a
+///     predecessor's branch, or
+///   * one is the result of an ArmSME tile op and the other is one of that
+///     op's tile operands.
+/// All block-argument merges happen before any operand merge.
+///
+/// Returns the distinct, non-empty live ranges sorted by starting point (as
+/// tile allocation expects). The pointers refer to entries of
+/// `initialLiveRanges`, which must outlive the result.
+///
+/// NOTE(review): both merge order and the relative order of ranges with equal
+/// start points follow DenseMap iteration order, which is presumably
+/// pointer-hash based and may therefore vary between runs; consider a
+/// deterministic tie-break.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+  // Maps each value to the live range it currently belongs to. A merge
+  // redirects every value of the absorbed range to the surviving one.
+  DenseMap<Value, LiveRange *> rangeOf;
+  for (auto &entry : initialLiveRanges)
+    rangeOf.insert({entry.first, &entry.second});
+
+  // Folds the live range of `rhs` into that of `lhs`, unless they are already
+  // the same range or they overlap.
+  auto tryMerge = [&](Value lhs, Value rhs) {
+    LiveRange *lhsRange = rangeOf.at(lhs);
+    LiveRange *rhsRange = rangeOf.at(rhs);
+    if (lhsRange == rhsRange || lhsRange->overlaps(*rhsRange))
+      return;
+    lhsRange->unionWith(*rhsRange);
+    for (Value member : rhsRange->values)
+      rangeOf[member] = lhsRange;
+  };
+
+  // Pass 1: merge each block argument with the values its predecessors
+  // forward to it.
+  for (auto &entry : initialLiveRanges) {
+    Value value = entry.first;
+    auto blockArg = dyn_cast<BlockArgument>(value);
+    if (!blockArg)
+      continue;
+    forEachPredecessorTileValue(blockArg, [&](Value incoming) {
+      tryMerge(blockArg, incoming);
+    });
+  }
+
+  // Pass 2: merge the result of each ArmSME tile op with its tile operands.
+  for (auto &entry : initialLiveRanges) {
+    Value result = entry.first;
+    auto tileOp = result.getDefiningOp<ArmSMETileOpInterface>();
+    if (!tileOp)
+      continue;
+    for (Value operand : tileOp->getOperands())
+      if (isValidSMETileVectorType(operand.getType()))
+        tryMerge(result, operand);
+  }
+
+  // Keep one entry per distinct live range, dropping ranges with no
+  // intervals.
+  SetVector<LiveRange *> distinctRanges;
+  for (auto &entry : rangeOf)
+    if (!entry.second->empty())
+      distinctRanges.insert(entry.second);
+
+  // Order by starting point, ready for tile allocation.
+  SmallVector<LiveRange *> sortedRanges(distinctRanges.takeVector());
+  std::sort(sortedRanges.begin(), sortedRanges.end(),
+            [](LiveRange *lhs, LiveRange *rhs) { return *lhs < *rhs; });
+  return sortedRanges;
+}
+
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+ TileAllocator tileAllocator;
+ SetVector<LiveRange *> allocatedRanges;
+
+ auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
----------------
banach-space wrote:
This lambda is quite long and defines two further lambdas inside it. IMO, this makes the logic of `allocateTilesToLiveRanges` hard to follow. Please extract it into a separate function and document it.
Also, does it always return an "in-memory tile ID"?
https://github.com/llvm/llvm-project/pull/90448
More information about the Mlir-commits
mailing list