[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)
Cullen Rhodes
llvmlistbot at llvm.org
Tue Apr 30 06:51:29 PDT 2024
================
@@ -137,172 +129,510 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
}
}
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
- TileMask &tilesInUse) {
- auto masks = getMasks(tileType);
- for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
- if ((tilesInUse & tileMask) == TileMask::kNone) {
- tilesInUse |= tileMask;
- return tileId;
+/// Tracks which SME tiles are in use and hands out in-memory tile IDs.
+class TileAllocator {
+public:
+ /// Allocates and returns a tile ID: the index of the first mask for
+ /// `tileType` that does not overlap any tile in use, or failure if none.
+ FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+ if ((tilesInUse & tileMask) == TileMask::kNone) {
+ tilesInUse |= tileMask;
+ return tileId;
+ }
}
+ return failure();
+ }
+
+ /// Releases a previously allocated tile ID.
+ void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+ TileMask tileMask = getMasks(tileType)[tileId];
+ assert((tilesInUse & tileMask) != TileMask::kNone &&
+ "cannot release unallocated tile!");
+ // NOTE(review): the assert only checks for a partial overlap, but `^=`
+ // assumes every bit of `tileMask` is currently set; consider asserting
+ // `(tilesInUse & tileMask) == tileMask` instead.
+ tilesInUse ^= tileMask;
+ }
+
+ /// Allocates an in-memory tile ID (IDs start at kInMemoryTileIdBase).
+ unsigned allocateInMemoryTileId() {
+ // Note: We never release in-memory tile IDs. We could, which may allow
+ // reusing an allocation, but as we _never_ want to spill an SME tile this
+ // is not optimized.
+ return nextInMemoryTileId++;
}
- return failure();
-}
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
- SetVector<Operation *> &dependantOps) {
- auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
- for (auto [idx, value] : llvm::enumerate(inputValues)) {
- if (value == rootValue)
- findDependantOps(exitValues[idx], dependantOps);
+private:
+ TileMask tilesInUse = TileMask::kNone;
+ unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of a
+/// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
+/// branches. Only `cf.cond_br`s that have SME tile operands are split; each
+/// new block just forwards its operands with a `cf.br`.
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+ SmallVector<cf::CondBranchOp> worklist;
+ function.walk([&](cf::CondBranchOp condBranch) {
+ if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+ return isValidSMETileVectorType(value.getType());
+ })) {
+ worklist.push_back(condBranch);
}
+ });
+
+ auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+ rewriter.setInsertionPointToEnd(source);
+ rewriter.create<cf::BranchOp>(loc, dest, args);
};
- for (Operation *user : rootValue.getUsers()) {
- if (dependantOps.contains(user))
+
+ for (auto condBranch : worklist) {
+ auto loc = condBranch.getLoc();
+ Block *block = condBranch->getBlock();
+ // Splitting at `block->end()` creates two new, empty blocks.
+ auto newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+ condBranch.getTrueDestOperands());
+ insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+ condBranch.getFalseDestOperands());
+ // The operands now travel via the new `cf.br`s; retarget the cond_br.
+ condBranch.getFalseDestOperandsMutable().clear();
+ condBranch.getTrueDestOperandsMutable().clear();
+ condBranch.setSuccessor(newTrueBranch, 0);
+ condBranch.setSuccessor(newFalseBranch, 1);
+ }
+}
+
+/// Inserts tile copies at `cf.br` operations. `cf.cond_br`s with tile operands
+/// are split first (splitCondBranches), so only `cf.br`s need copies.
+void insertCopiesAtBranches(IRRewriter &rewriter,
+ FunctionOpInterface function) {
+ splitCondBranches(rewriter, function);
+ for (Block &block : function.getBlocks()) {
+ Operation *terminator = block.getTerminator();
+ if (!isa<cf::BranchOp>(terminator))
continue;
- dependantOps.insert(user);
- TypeSwitch<Operation *>(user)
- .Case<cf::BranchOp>([&](auto branchOp) {
- // (CF) Follow branch.
- traverseCorrespondingValues(branchOp.getDestOperands(),
- branchOp.getDest()->getArguments());
- })
- .Case<cf::CondBranchOp>([&](auto condBranchOp) {
- // (CF) Follow true branch.
- traverseCorrespondingValues(
- condBranchOp.getTrueOperands(),
- condBranchOp.getTrueDest()->getArguments());
- // (CF) Follow false branch.
- traverseCorrespondingValues(
- condBranchOp.getFalseOperands(),
- condBranchOp.getFalseDest()->getArguments());
- })
- .Case<LoopLikeOpInterface>([&](auto loopOp) {
- // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
- traverseCorrespondingValues(loopOp.getInits(),
- loopOp.getRegionIterArgs());
- })
- .Case<scf::YieldOp>([&](auto yieldOp) {
- // (SCF) Follow yields of (basic) control flow (e.g. for loops).
- auto parent = user->getParentOp();
- traverseCorrespondingValues(user->getOperands(),
- parent->getResults());
+ // Route every tile operand of the branch through a fresh `CopyTileOp`.
+ rewriter.setInsertionPoint(terminator);
+ for (OpOperand &operand : terminator->getOpOperands()) {
+ if (isValidSMETileVectorType(operand.get().getType())) {
+ auto copy =
+ rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+ operand.assign(copy);
+ }
+ }
+ }
+}
+
+/// A range where a tile value is live. The range may contain holes.
+/// `values` holds every tile value that has been coalesced into this range.
+struct LiveRange {
+ using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+ llvm::IntervalMapHalfOpenInfo<unsigned>>;
+ // NOTE(review): keys are uint64_t but the traits (and start()/end()/insert())
+ // use unsigned; consider using a single key type throughout.
+ using Allocator = RangeSet::Allocator;
+ // Dummy mapped value: only the intervals themselves are meaningful.
+ static constexpr uint8_t kValidLiveRange = 0xff;
+
+ LiveRange(Allocator &allocator)
+ : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+ /// Returns true if this range overlaps with `otherRange`.
+ bool overlaps(LiveRange const &otherRange) const {
+ return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+ *otherRange.ranges)
+ .valid();
+ }
+
+ /// Unions this live range with `otherRange`, aborts if the ranges overlap.
+ void unionWith(LiveRange const &otherRange) {
+ for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+ ++it)
+ ranges->insert(it.start(), it.stop(), kValidLiveRange);
+ values.set_union(otherRange.values);
+ }
+
+ /// Inserts an interval [start, end) for `value` into this range. An empty
+ /// interval (start == end) only records `value`.
+ void insert(Value value, unsigned start, unsigned end) {
+ values.insert(value);
+ if (start != end)
+ ranges->insert(start, end, kValidLiveRange);
+ }
+
+ /// Returns true if no intervals have been recorded.
+ bool empty() const { return ranges->empty(); }
+ /// Start of the first interval / end of the last interval.
+ unsigned start() const { return ranges->start(); }
+ unsigned end() const { return ranges->stop(); }
+ /// Orders live ranges by their start point.
+ bool operator<(LiveRange const &other) const {
+ return start() < other.start();
+ }
+
+ /// Returns the SME tile type of the first value in this range.
+ ArmSMETileType getTileType() const {
+ return *getSMETileType(cast<VectorType>(values[0].getType()));
+ }
+
+ /// The live intervals (allocated from a shared RangeSet::Allocator).
+ std::unique_ptr<RangeSet> ranges;
+ /// All tile values covered by this range.
+ SetVector<Value> values;
+ /// The assigned tile ID (may be an in-memory ID); unset until allocation.
+ std::optional<unsigned> tileId;
+};
+
+/// Numbers operations within a function to allow computing live ranges.
+/// Blocks are visited in topological order, and one index is reserved before
+/// each block's first operation for that block's arguments.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+ unsigned index = 0;
+ SetVector<Block *> blocks =
+ getTopologicallySortedBlocks(function.getFunctionBody());
+ DenseMap<Operation *, unsigned> operationToIndexMap;
+ for (Block *block : blocks) {
+ index++; // We want block args to have their own number.
+ for (Operation &op : block->getOperations()) {
+ // This numbering is only correct if no ArmSME tile op is nested inside
+ // another op's region (i.e. control flow has been lowered to CF). In
+ // debug builds, the walk below asserts this.
+#ifndef NDEBUG
+ op.walk([&](ArmSMETileOpInterface nestedOp) {
+ if (&op != nestedOp.getOperation()) {
+ assert(false &&
+ "ArmSME tile allocation does not support nested regions");
+ }
+ });
+#endif
+ operationToIndexMap.try_emplace(&op, index++);
+ }
+ }
+ return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+/// Each tile value gets one LiveRange built from per-block [start, end)
+/// intervals, using the indices from `operationToIndexMap`.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+ LiveRange::Allocator &liveRangeAllocator,
+ Liveness &liveness, FunctionOpInterface function) {
+ DenseMap<Value, LiveRange> liveRanges;
+ // Adds the interval where `value` is live within the current block.
+ auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
+ LivenessBlockInfo const &livenessInfo,
+ bool liveAtBlockEntry = false) {
+ if (!isValidSMETileVectorType(value.getType()))
+ return;
+ auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+ auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+ // Values live at block entry start one index before the block's first op
+ // (the index reserved for block arguments). The range is half-open and
+ // ends at the last user, so it does not cover that user itself.
+ unsigned start =
+ operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+ unsigned end = operationToIndexMap.at(lastUseInBlock);
+ it->second.insert(value, start, end);
+ };
+
+ for (Block &block : function.getBlocks()) {
+ LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
+ // Handle block arguments:
+ for (Value argument : block.getArguments())
+ updateLiveRanges(argument, &block.front(), *livenessInfo,
+ /*liveAtBlockEntry=*/true);
+ // Handle live-ins (values live on entry to this block):
+ for (Value liveIn : livenessInfo->in())
+ updateLiveRanges(liveIn, &block.front(), *livenessInfo,
+ /*liveAtBlockEntry=*/true);
+ // Handle new definitions:
+ for (Operation &op : block) {
+ for (Value result : op.getResults())
+ updateLiveRanges(result, &op, *livenessInfo);
+ }
+ }
+
+ return liveRanges;
+}
+
+/// Calls `callback` on each value that a predecessor forwards to the (tile)
+/// block argument `blockArg`. Only `cf.br` and `cf.cond_br` terminators are
+/// handled; predecessors ending in any other terminator are ignored.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+ function_ref<void(Value)> callback) {
+ Block *block = blockArg.getOwner();
+ unsigned argNumber = blockArg.getArgNumber();
+ for (Block *pred : block->getPredecessors()) {
+ TypeSwitch<Operation *>(pred->getTerminator())
+ .Case<cf::BranchOp>([&](auto branch) {
+ Value predecessorOperand = branch.getDestOperands()[argNumber];
+ callback(predecessorOperand);
})
- .Default([&](auto) {
- // Otherwise, assume users of _any_ result are dependant.
- for (Value result : user->getResults())
- findDependantOps(result, dependantOps);
+ .Case<cf::CondBranchOp>([&](auto condBranch) {
+ if (condBranch.getFalseDest() == block) {
+ Value predecessorOperand =
+ condBranch.getFalseDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
+ if (condBranch.getTrueDest() == block) {
+ Value predecessorOperand =
+ condBranch.getTrueDestOperands()[argNumber];
+ callback(predecessorOperand);
+ }
});
}
}
-struct AssignTileIDsPattern
- : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
- using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
- LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
- PatternRewriter &rewriter) const override {
- if (tileOp.getTileId())
- return failure();
-
- auto func = tileOp->getParentOfType<FunctionOpInterface>();
- auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
- if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
- func->getDiscardableAttr(name)))
- return unsigned(attr.getInt());
- return defaultVal;
+
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+/// Returns the unique, non-empty coalesced ranges sorted by start point.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+ DenseMap<Value, LiveRange *> liveRanges;
+ for (auto &[value, liveRange] : initialLiveRanges) {
+ liveRanges.insert({value, &liveRange});
+ }
+
+ // Merges `b`'s live range into `a`'s (if distinct and disjoint), then points
+ // every value of `b`'s range at the merged range.
+ auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+ LiveRange *aLiveRange = liveRanges.at(a);
+ LiveRange *bLiveRange = liveRanges.at(b);
+ if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+ aLiveRange->unionWith(*bLiveRange);
+ for (Value value : bLiveRange->values)
+ liveRanges[value] = aLiveRange;
+ }
+ };
+
+ // Merge the live ranges of new definitions with their tile operands.
+ auto unifyDefinitionsWithOperands = [&](Value value) {
+ auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ if (!armSMEOp)
+ return;
+ for (auto operand : armSMEOp->getOperands()) {
+ if (isValidSMETileVectorType(operand.getType()))
+ mergeValuesIfNonOverlapping(value, operand);
+ }
+ };
+
+ // Merge the live ranges of block arguments with their predecessors.
+ auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+ auto blockArg = dyn_cast<BlockArgument>(value);
+ if (!blockArg)
+ return;
+ forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+ mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+ });
+ };
+
+ auto applyRule = [&](auto rule) {
+ llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+ };
+
+ // Unify as many live ranges as we can. This prevents unnecessary moves.
+ applyRule(unifyBlockArgumentsWithPredecessors);
+ applyRule(unifyDefinitionsWithOperands);
+
+ // Remove duplicate (and empty) live range entries.
+ SetVector<LiveRange *> uniqueLiveRanges;
+ for (auto [_, liveRange] : liveRanges) {
+ if (!liveRange->empty())
+ uniqueLiveRanges.insert(liveRange);
+ }
+
+ // Sort the new live ranges by starting point (ready for tile allocation).
+ // NOTE(review): `liveRanges`/`initialLiveRanges` are DenseMaps, so both the
+ // merge order and the order of ranges with equal start() follow pointer
+ // hashing, and std::sort is not stable; consider a deterministic tie-break.
+ auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+ std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+ [](LiveRange *a, LiveRange *b) { return *a < *b; });
+ // The explicit move is presumably needed: takeVector()'s SmallVector type
+ // differs from the return type, so C++17 would otherwise copy.
+ return std::move(coalescedLiveRanges);
+}
+
+/// Greedily allocate tile IDs to live ranges, spilling using simple heuristics.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+ TileAllocator tileAllocator;
+ SetVector<LiveRange *> allocatedRanges;
+
+ // Returns the tile ID `newRange` should use: either a real tile taken from a
+ // spilled active range, or a fresh in-memory tile ID.
+ auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
+ unsigned memoryTileId = tileAllocator.allocateInMemoryTileId();
+ auto spillActiveRange = [&](LiveRange *range) {
+ unsigned tileId = *range->tileId;
+ range->tileId = memoryTileId;
+ allocatedRanges.remove(range);
+ return tileId;
};
- auto setDiscardableIntAttr = [&](StringRef name, auto value) {
- rewriter.modifyOpInPlace(tileOp, [&] {
- func->setDiscardableAttr(name,
- rewriter.getI32IntegerAttr((unsigned)value));
- });
+
+ auto isTrivialSpill = [](LiveRange *allocatedRange) {
+ return allocatedRange->values.size() == 1 &&
+ isTriviallyCloneableTileOp(
+ allocatedRange->values[0]
+ .getDefiningOp<ArmSMETileOpInterface>());
};
- std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
- if (!tileType)
- return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
- TileMask tilesInUse =
- static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
- auto tileId = allocateTileId(*tileType, tilesInUse);
- bool tileIsInMemory = failed(tileId);
- if (tileIsInMemory) {
- // If we could not find a real tile ID, use an in-memory tile ID (ID >=
- // 16). A later pass will insert the necessary spills and reloads.
- tileId =
- getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
- tileOp->emitWarning(
- "failed to allocate SME virtual tile to operation, all tile "
- "operations will go through memory, expect degraded performance");
- }
+ // Heuristic: Spill trivially cloneable operations (usually free).
+ if (isTrivialSpill(newRange))
+ return memoryTileId;
+ auto trivialSpill = llvm::find_if(allocatedRanges, isTrivialSpill);
+ if (trivialSpill != allocatedRanges.end())
+ return spillActiveRange(*trivialSpill);
+
+ // Heuristic: Spill the live range that ends last (if it ends no earlier
+ // than `newRange`; otherwise `newRange` itself goes to memory).
+ LiveRange *lastActiveLiveRange = *std::max_element(
+ allocatedRanges.begin(), allocatedRanges.end(),
+ [](LiveRange *a, LiveRange *b) { return a->end() < b->end(); });
+ if (lastActiveLiveRange->end() >= newRange->end())
+ return spillActiveRange(lastActiveLiveRange);
+
+ return memoryTileId;
+ };
- // Set all operations dependent on `tileOp` to use the same tile ID.
- // This is a naive tile allocation scheme, but works for common cases. For
- // example, as this only allocates tile IDs to existing ops, it can't solve
- // cases like this (%tileA and %tileB come from different root operations):
- //
- // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
- // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
- // } else {
- // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
- // }
- //
- // This case would require allocating a new tile for the result of the
- // scf.if, and moving the contents of %tileA or %tileB to result tile (based
- // on the %some_cond).
- // Find all the ops that (transitively) depend on this tile.
- SetVector<Operation *> dependantOps;
- findDependantOps(tileOp->getResult(0), dependantOps);
- auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
- for (auto *op : dependantOps) {
- if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- auto currentTileId = dependantTileOp.getTileId();
- if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
- return dependantTileOp.emitOpError(
- "already assigned different SME virtual tile!");
+ for (LiveRange *newRange : liveRanges) {
+ // Release tiles from live ranges that have ended.
+ allocatedRanges.remove_if([&](LiveRange *allocatedRange) {
+ if (allocatedRange->end() <= newRange->start()) {
+ tileAllocator.releaseTileId(allocatedRange->getTileType(),
+ *allocatedRange->tileId);
+ return true;
}
- }
+ return false;
+ });
- // Rewrite IR.
- if (!tileIsInMemory)
- setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+ // Allocate a tile ID to `newRange`.
+ auto tileId = tileAllocator.allocateTileId(newRange->getTileType());
+ if (succeeded(tileId))
+ newRange->tileId = *tileId;
else
- setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
- rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
- for (auto *op : dependantOps) {
- if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
- rewriter.modifyOpInPlace(
- dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
+ newRange->tileId = chooseSpillUsingHeuristics(newRange);
+
+ // Track the live range as active only if it was assigned a real tile.
+ if (newRange->tileId < kInMemoryTileIdBase)
+ allocatedRanges.insert(newRange);
+ }
+}
+
+/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
+/// Returns failure (with an error emitted) if a conflict would need a tile move.
+LogicalResult assignTileIdsAndResolveTrivialConflicts(
+ IRRewriter &rewriter, FunctionOpInterface function,
+ ArrayRef<LiveRange *> allocatedLiveRanges) {
+ // NOTE(review): `function` is not used in this body.
+ for (LiveRange const *liveRange : allocatedLiveRanges) {
+ auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
+ // A value is on this range's tile if its defining tile op already carries
+ // `tileIdAttr`, or if it is one of this range's values.
+ auto isAllocatedToSameTile = [&](Value value) {
+ if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+ tileOp && tileOp.getTileId() == tileIdAttr)
+ return true;
+ return liveRange->values.contains(value);
+ };
+ for (Value value : liveRange->values) {
+ for (Operation *user : value.getUsers()) {
+ if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+ // Ensure ArmSME ops that don't produce a value still get a tile ID.
+ if (!hasTileResult(tileOp))
+ tileOp.setTileId(tileIdAttr);
+ }
+ }
+ auto copyOp = value.getDefiningOp<CopyTileOp>();
+ if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
+ // Fold redundant copies.
+ rewriter.replaceAllUsesWith(copyOp, copyOp.getTile());
+ } else if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>()) {
+ tileOp.setTileId(tileIdAttr);
+ // Rectify operand tile IDs with result tile IDs.
+ OpOperand *tileOperand = getTileOpOperand(tileOp);
+ if (!tileOperand || isAllocatedToSameTile(tileOperand->get()))
+ continue;
+ auto operandTileOp =
+ tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+ if (!isTriviallyCloneableTileOp(operandTileOp))
+ return tileOp.emitOpError("failed to rectify tile operand with tile "
+ "result (move required)");
+ // Cloning prevents a move/spill (though may require recomputation).
+ rewriter.setInsertionPoint(tileOp);
+ auto clonedOp = operandTileOp.clone();
+ clonedOp.setTileId(tileOp.getTileId());
+ rewriter.insert(clonedOp);
+ if (copyOp)
+ rewriter.replaceAllUsesWith(copyOp, clonedOp->getResult(0));
+ else
+ tileOperand->assign(clonedOp->getResult(0));
+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
+ // Validate that block arguments share a tile with all predecessor values.
+ bool tileMismatch = false;
+ forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+ if (tileMismatch)
+ return;
+ if (!isAllocatedToSameTile(predecessorTile)) {
+ blockArg.getOwner()->getParentOp()->emitOpError(
+ "block argument not allocated to the same tile as "
+ "predecessors");
+ tileMismatch = true;
+ }
+ });
+ if (tileMismatch)
+ return failure();
}
}
+ }
+ return success();
+}
- return success();
+/// Prints live ranges alongside operation names for debugging.
+/// Output goes to stderr; each live range is one column per operation line,
+/// marked 'S' (interval start), 'E' (interval end) or '|' (live).
+void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+ ArrayRef<LiveRange const *> liveRanges,
+ FunctionOpInterface function) {
+ llvm::errs() << "SME Tile Liveness: @" << function.getName()
+ << "\nKey:\nS - Start\nE - End\n| - Live\n";
+ for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
+ llvm::errs() << "^bb" << blockIdx << ":\n";
+ for (Operation &op : block.getOperations()) {
+ unsigned operationIndex = operationToIndexMap.at(&op);
+ for (LiveRange const *range : liveRanges) {
+ char liveness = ' ';
+ for (auto it = range->ranges->begin(); it != range->ranges->end();
+ ++it) {
+ if (it.start() == operationIndex)
+ liveness = (liveness == 'E' ? '|' : 'S');
+ else if (it.stop() == operationIndex)
+ liveness = (liveness == 'S' ? '|' : 'E');
+ else if (operationIndex >= it.start() && operationIndex < it.stop())
+ liveness = '|';
+ }
+ llvm::errs() << liveness;
+ }
+ llvm::errs() << ' ' << op.getName() << '\n';
+ }
}
-};
+ llvm::errs() << "==========\n";
+}
-struct TileAllocationPass
- : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+/// Test pass that runs `arm_sme::allocateSMETiles` on the current operation,
+/// optionally dumping the tile live ranges (`dumpTileLiveRanges` option).
+struct TestTileAllocationPass
+ : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
+ using TestTileAllocationBase::TestTileAllocationBase;
void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- patterns.add<AssignTileIDsPattern>(patterns.getContext());
- GreedyRewriteConfig config;
- // Setting useTopDownTraversal ensures tiles are allocated in program
- // order.
- config.useTopDownTraversal = true;
- if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
- getOperation(), std::move(patterns), config))) {
+ if (failed(arm_sme::allocateSMETiles(getOperation(), dumpTileLiveRanges)))
signalPassFailure();
- }
}
};
} // namespace
-std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
- return std::make_unique<TileAllocationPass>();
+LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
+ bool dumpRanges) {
+ LiveRange::Allocator liveRangeAllocator;
+ IRRewriter rewriter(function.getContext());
+
+ // 1. Insert copy operations at branch operations.
+ insertCopiesAtBranches(rewriter, function);
+
+ // 2. Gather live ranges for each ArmSME tile within the function.
+ Liveness liveness(function);
+ auto operationToIndexMap = generateOperationNumbering(function);
+ auto initialLiveRanges = gatherTileLiveRanges(
+ operationToIndexMap, liveRangeAllocator, liveness, function);
+ if (initialLiveRanges.empty())
+ return success();
+
+ if (dumpRanges) {
+ // Wrangle initial live ranges into a form suitable for printing.
+ auto nonEmpty = llvm::make_filter_range(
+ llvm::make_second_range(initialLiveRanges),
+ [&](LiveRange const &liveRange) { return !liveRange.empty(); });
+ auto initialRanges = llvm::to_vector(llvm::map_range(
+ nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
+ std::sort(initialRanges.begin(), initialRanges.end(),
+ [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
+ llvm::errs() << "\n========== Initial Live Ranges:\n";
+ dumpLiveRanges(operationToIndexMap, initialRanges, function);
+ }
+
+ // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
+ // for tile allocation. E.g. Unify the result of an operation with its
+ // operands.
+ auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
+
+ if (dumpRanges) {
+ llvm::errs() << "\n========== Coalesced Live Ranges:\n";
+ dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
+ }
+
+ // 4. Allocate tile IDs to live ranges.
+ allocateTilesToLiveRanges(coalescedLiveRanges);
+
+ // 5. Assign the tile IDs back to the ArmSME operations.
+ if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
+ coalescedLiveRanges))) {
+ return failure();
+ }
+
+ /// 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+ /// users). This prevents the LLVM conversion needlessly inserting spills.
----------------
c-rhodes wrote:
For consistency with the numbered step comments above (which use `//`), use `//` rather than `///` here:
```suggestion
// 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
// users). This prevents the LLVM conversion needlessly inserting spills.
```
https://github.com/llvm/llvm-project/pull/90448
More information about the Mlir-commits
mailing list