[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

Benjamin Maxwell llvmlistbot at llvm.org
Tue May 7 02:24:45 PDT 2024


================
@@ -137,172 +129,508 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
-                                          TileMask &tilesInUse) {
-  auto masks = getMasks(tileType);
-  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
-    if ((tilesInUse & tileMask) == TileMask::kNone) {
-      tilesInUse |= tileMask;
-      return tileId;
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
     }
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) != TileMask::kNone &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
+  }
+
+  /// Allocates an in-memory tile ID.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
   }
-  return failure();
-}
 
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
-                             SetVector<Operation *> &dependantOps) {
-  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
-    for (auto [idx, value] : llvm::enumerate(inputValues)) {
-      if (value == rootValue)
-        findDependantOps(exitValues[idx], dependantOps);
+private:
+  TileMask tilesInUse = TileMask::kNone;
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+// Add new intermediate blocks for the true and false destinations of a
+// `cf.cond_br`. This prevents spurious liveness overlaps due to copies at
+// branches.
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
     }
+  });
+
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
   };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    condBranch.getFalseDestOperandsMutable().clear();
+    condBranch.getTrueDestOperandsMutable().clear();
+    condBranch.setSuccessor(newTrueBranch, 0);
+    condBranch.setSuccessor(newFalseBranch, 1);
+  }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+void insertCopiesAtBranches(IRRewriter &rewriter,
+                            FunctionOpInterface function) {
+  splitCondBranches(rewriter, function);
+  for (Block &block : function.getBlocks()) {
+    Operation *terminator = block.getTerminator();
+    if (!isa<cf::BranchOp>(terminator))
       continue;
-    dependantOps.insert(user);
-    TypeSwitch<Operation *>(user)
-        .Case<cf::BranchOp>([&](auto branchOp) {
-          // (CF) Follow branch.
-          traverseCorrespondingValues(branchOp.getDestOperands(),
-                                      branchOp.getDest()->getArguments());
-        })
-        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
-          // (CF) Follow true branch.
-          traverseCorrespondingValues(
-              condBranchOp.getTrueOperands(),
-              condBranchOp.getTrueDest()->getArguments());
-          // (CF) Follow false branch.
-          traverseCorrespondingValues(
-              condBranchOp.getFalseOperands(),
-              condBranchOp.getFalseDest()->getArguments());
-        })
-        .Case<LoopLikeOpInterface>([&](auto loopOp) {
-          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
-          traverseCorrespondingValues(loopOp.getInits(),
-                                      loopOp.getRegionIterArgs());
-        })
-        .Case<scf::YieldOp>([&](auto yieldOp) {
-          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
-          auto parent = user->getParentOp();
-          traverseCorrespondingValues(user->getOperands(),
-                                      parent->getResults());
+    rewriter.setInsertionPoint(terminator);
+    for (OpOperand &operand : terminator->getOpOperands()) {
+      if (isValidSMETileVectorType(operand.get().getType())) {
+        auto copy =
+            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+        operand.assign(copy);
+      }
+    }
+  }
+}
+
+/// A range where a tile value is live. The range may contain holes.
+struct LiveRange {
+  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  unsigned start() const { return ranges->start(); }
+  unsigned end() const { return ranges->stop(); }
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  std::unique_ptr<RangeSet> ranges;
+  SetVector<Value> values;
+  std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  unsigned index = 0;
+  SetVector<Block *> blocks =
+      getTopologicallySortedBlocks(function.getFunctionBody());
+  DenseMap<Operation *, unsigned> operationToIndexMap;
+  for (Block *block : blocks) {
+    index++; // We want block args to have their own number.
+    for (Operation &op : block->getOperations()) {
+      // This is only correct if all ArmSME have been converted to CF.
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
+      });
+#endif
+      operationToIndexMap.try_emplace(&op, index++);
+    }
+  }
+  return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  DenseMap<Value, LiveRange> liveRanges;
+  auto updateLiveRanges = [&](Value value, Operation *firstUseOrDef,
+                              LivenessBlockInfo const &livenessInfo,
+                              bool liveAtBlockEntry = false) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    auto it = liveRanges.try_emplace(value, liveRangeAllocator).first;
+    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    unsigned start =
+        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+    unsigned end = operationToIndexMap.at(lastUseInBlock);
+    it->second.insert(value, start, end);
+  };
----------------
MacDue wrote:

This is already factored out into a lambda function; I don't see how making it a static function helps with clarity at all (since it would mean adding a bunch more parameters for captures -- I'm not passing everything as arguments, and it's never used outside this function).

I've renamed things and added a few more comments though. Note that renaming `start`/`end` to `liveRangeStart` and `liveRangeEnd` is not necessarily correct: `[start, end)` is an interval in a live range, but it could be in the middle of the live range (e.g. for a live-in that crosses several basic blocks). Also note that this function does not always add a new live range (again, in the case of live-ins, it'll update the existing live range by adding a new interval).


https://github.com/llvm/llvm-project/pull/90448


More information about the Mlir-commits mailing list