[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

Benjamin Maxwell llvmlistbot at llvm.org
Fri May 10 08:19:39 PDT 2024


================
@@ -137,172 +138,629 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
-                                          TileMask &tilesInUse) {
-  auto masks = getMasks(tileType);
-  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
-    if ((tilesInUse & tileMask) == TileMask::kNone) {
-      tilesInUse |= tileMask;
-      return tileId;
+/// Hands out SME virtual tile IDs, tracking which tiles are currently in use
+/// via a `TileMask`, and hands out in-memory tile IDs (used for spills),
+/// starting at `kInMemoryTileIdBase`.
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
+  /// The lowest tile ID of `tileType` whose mask is entirely free is chosen,
+  /// and its mask is marked as in use.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
     }
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID. The tile's entire mask must be
+  /// in use (as set by `allocateTileId`). The mask is cleared rather than
+  /// toggled so that a mismatched release can never mark unallocated tiles as
+  /// in use (e.g. in release builds where the assert is compiled out).
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) == tileMask &&
+           "cannot release unallocated tile!");
+    tilesInUse &= ~tileMask;
   }
-  return failure();
-}
 
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
-                             SetVector<Operation *> &dependantOps) {
-  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
-    for (auto [idx, value] : llvm::enumerate(inputValues)) {
-      if (value == rootValue)
-        findDependantOps(exitValues[idx], dependantOps);
+  /// Allocates an in-memory tile ID. IDs are handed out consecutively and
+  /// are never reused.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
+  }
+
+private:
+  /// Bitmask of the SME tiles currently allocated.
+  TileMask tilesInUse = TileMask::kNone;
+  /// The next in-memory tile ID to hand out.
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches. Each new block holds a single `cf.br`
+/// that forwards the original destination operands, and the `cf.cond_br` is
+/// retargeted at the new blocks with no operands of its own.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+///  ^bb1_copy:
+///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ^bb2_copy:
+///    cf.br ^bb2
+///  ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  // Collect first, then rewrite, so the walk never sees the blocks/ops we add.
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
     }
+  });
+
+  // Appends `cf.br dest(args)` to the end of `source`.
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
   };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    // Splitting at `block->end()` creates an empty block placed after `block`.
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    // The operands now live on the new `cf.br`s; drop them from the
+    // `cf.cond_br` and point it at the intermediate blocks.
+    rewriter.modifyOpInPlace(condBranch, [&] {
+      condBranch.getFalseDestOperandsMutable().clear();
+      condBranch.getTrueDestOperandsMutable().clear();
+      condBranch.setSuccessor(newTrueBranch, 0);
+      condBranch.setSuccessor(newFalseBranch, 1);
+    });
+  }
+}
+
+/// Inserts tile copies at `cf.br` operations. Only blocks terminated by a
+/// `cf.br` are changed: every SME tile operand of that branch is replaced by
+/// a fresh `arm_sme.copy_tile` of it, created immediately before the branch.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+///  cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
+///  ```
+void insertCopiesAtBranches(IRRewriter &rewriter,
+                            FunctionOpInterface function) {
+  for (Block &block : function.getBlocks()) {
+    Operation *terminator = block.getTerminator();
+    if (!isa<cf::BranchOp>(terminator))
       continue;
-    dependantOps.insert(user);
-    TypeSwitch<Operation *>(user)
-        .Case<cf::BranchOp>([&](auto branchOp) {
-          // (CF) Follow branch.
-          traverseCorrespondingValues(branchOp.getDestOperands(),
-                                      branchOp.getDest()->getArguments());
-        })
-        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
-          // (CF) Follow true branch.
-          traverseCorrespondingValues(
-              condBranchOp.getTrueOperands(),
-              condBranchOp.getTrueDest()->getArguments());
-          // (CF) Follow false branch.
-          traverseCorrespondingValues(
-              condBranchOp.getFalseOperands(),
-              condBranchOp.getFalseDest()->getArguments());
-        })
-        .Case<LoopLikeOpInterface>([&](auto loopOp) {
-          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
-          traverseCorrespondingValues(loopOp.getInits(),
-                                      loopOp.getRegionIterArgs());
-        })
-        .Case<scf::YieldOp>([&](auto yieldOp) {
-          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
-          auto parent = user->getParentOp();
-          traverseCorrespondingValues(user->getOperands(),
-                                      parent->getResults());
+    rewriter.setInsertionPoint(terminator);
+    // Each tile operand gets its own copy (even if the same value is passed
+    // more than once), so every block argument starts a new live range.
+    for (OpOperand &operand : terminator->getOpOperands()) {
+      if (isValidSMETileVectorType(operand.get().getType())) {
+        auto copy =
+            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
+      }
+    }
+  }
+}
+
+/// Prepares the IR for tile allocation. It does this by first 'splitting'
+/// conditional branches (see `splitCondBranches`), then inserting tile copies
+/// at branch operations. The conditional branches are split to prevent the
+/// copies needed for them from overlapping between the true and false paths
+/// of the branch (see `tile-allocation-copies.mlir` and
+/// `tile-allocation-liveness.mlir` for examples). The copies break up live
+/// ranges and ensure that the semantics of the program are preserved when
+/// moving out of SSA.
+void preprocessForTileAllocation(IRRewriter &rewriter,
+                                 FunctionOpInterface function) {
+  // Order matters: copies are only inserted at `cf.br`s, so the tile operands
+  // of `cf.cond_br`s must first be moved onto new `cf.br`s by the split.
+  splitCondBranches(rewriter, function);
+  insertCopiesAtBranches(rewriter, function);
+}
+
+/// A live range for a (collection of) tile values. A live range is built up of
+/// non-overlapping intervals [start, end) which represent parts of the program
+/// where a value in the range needs to be live (i.e. in an SME virtual tile).
+/// Note that as the intervals are non-overlapping all values within a live
+/// range can be allocated to the same SME virtual tile.
+struct LiveRange {
+  // Interval endpoints are operation indices, which are `unsigned` throughout
+  // the allocator, so the map's key type matches the key type of its traits
+  // and `start()`/`end()` need no narrowing conversion.
+  using RangeSet = llvm::IntervalMap<unsigned, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  /// The single value stored for every interval; only the interval keys are
+  /// ever inspected.
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange` (both its intervals and its
+  /// values). The two ranges must not overlap; callers are expected to check
+  /// `overlaps()` first.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range. `value` is
+  /// always recorded; an empty interval (`start == end`) adds no interval.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  /// Returns true if this range contains no intervals.
+  bool empty() const { return ranges->empty(); }
+  /// Start of the first interval (the range must not be empty).
+  unsigned start() const { return ranges->start(); }
+  /// End of the last interval (the range must not be empty).
+  unsigned end() const { return ranges->stop(); }
+  /// Orders live ranges by their start point.
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  /// Returns the SME tile type of this range, taken from its first value.
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  /// The values contained in this live range.
+  SetVector<Value> values;
+
+  /// A set of (non-overlapping) intervals that mark where any value in `values`
+  /// is live.
+  std::unique_ptr<RangeSet> ranges;
+
+  /// The tile ID (or none) assigned to this live range.
+  std::optional<unsigned> tileId;
+};
+
+/// Assigns every operation in `function` a number, so that live ranges can be
+/// expressed as index intervals. Blocks are visited in topological order
+/// (following forward edges) and operations are numbered consecutively within
+/// each block; one extra number is reserved just before each block's first
+/// operation for that block's arguments. The numbering is only meaningful when
+/// no ArmSME tile operation is nested inside another operation's region (e.g.
+/// after SCF has been lowered to CF); debug builds assert this.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  DenseMap<Operation *, unsigned> numbering;
+  unsigned nextIndex = 0;
+  for (Block *block :
+       getTopologicallySortedBlocks(function.getFunctionBody())) {
+    // Reserve a slot for this block's arguments.
+    ++nextIndex;
+    for (Operation &op : *block) {
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
+      });
+#endif
+      numbering.try_emplace(&op, nextIndex);
+      ++nextIndex;
+    }
+  }
+  return numbering;
+}
+
+/// Computes a live range for every SME tile value in `function` from the MLIR
+/// liveness analysis. `operationToIndexMap` must number every operation in the
+/// function (see `generateOperationNumbering`), and all interval storage is
+/// taken from `liveRangeAllocator`. A tile value that is live in several blocks
+/// contributes one interval per block to its live range.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  assert(!operationToIndexMap.empty() && "expected operation numbering");
+  DenseMap<Value, LiveRange> liveRanges;
+
+  // Adds to the live range of `value` (creating it on first sight) the
+  // interval from `startOp` to the operation where `value`'s liveness in this
+  // block ends, as reported by `getEndOperation`. For values live on entry to
+  // the block, `startOp` is the block's first operation and the interval is
+  // widened by one to include the index reserved for the block's arguments.
+  // Non-tile values are ignored.
+  auto addBlockInterval = [&](Value value, Operation *startOp,
+                              LivenessBlockInfo const &blockInfo,
+                              bool liveOnEntry) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    LiveRange &range =
+        liveRanges.try_emplace(value, liveRangeAllocator).first->second;
+    Operation *endOp = blockInfo.getEndOperation(value, startOp);
+    unsigned startIdx = operationToIndexMap.at(startOp);
+    if (liveOnEntry)
+      startIdx -= 1;
+    unsigned endIdx = operationToIndexMap.at(endOp);
+    range.insert(value, startIdx, endIdx);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const *blockInfo = liveness.getLiveness(&block);
+    // Block arguments and live-in values are live from the top of the block.
+    for (Value argument : block.getArguments())
+      addBlockInterval(argument, &block.front(), *blockInfo,
+                       /*liveOnEntry=*/true);
+    for (Value liveIn : blockInfo->in())
+      addBlockInterval(liveIn, &block.front(), *blockInfo,
+                       /*liveOnEntry=*/true);
+    // Values defined in this block are live from their defining operation.
+    for (Operation &op : block)
+      for (Value result : op.getResults())
+        addBlockInterval(result, &op, *blockInfo, /*liveOnEntry=*/false);
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+/// For each predecessor of `blockArg`'s block, `callback` is invoked with the
+/// branch operand at `blockArg`'s argument number: the destination operand of
+/// a `cf.br`, or of each `cf.cond_br` destination that targets this block.
+/// Predecessors terminated by any other operation are skipped (there is no
+/// `Default` case).
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+                                        function_ref<void(Value)> callback) {
+  Block *block = blockArg.getOwner();
+  unsigned argNumber = blockArg.getArgNumber();
+  for (Block *pred : block->getPredecessors()) {
+    TypeSwitch<Operation *>(pred->getTerminator())
+        .Case<cf::BranchOp>([&](auto branch) {
+          Value predecessorOperand = branch.getDestOperands()[argNumber];
+          callback(predecessorOperand);
         })
-        .Default([&](auto) {
-          // Otherwise, assume users of _any_ result are dependant.
-          for (Value result : user->getResults())
-            findDependantOps(result, dependantOps);
+        .Case<cf::CondBranchOp>([&](auto condBranch) {
+          // Both destinations are checked independently, as either (or both)
+          // may be this block.
+          if (condBranch.getFalseDest() == block) {
+            Value predecessorOperand =
+                condBranch.getFalseDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
+          if (condBranch.getTrueDest() == block) {
+            Value predecessorOperand =
+                condBranch.getTrueDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
         });
   }
 }
-struct AssignTileIDsPattern
-    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
-                                PatternRewriter &rewriter) const override {
-    if (tileOp.getTileId())
-      return failure();
-
-    auto func = tileOp->getParentOfType<FunctionOpInterface>();
-    auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
-      if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
-              func->getDiscardableAttr(name)))
-        return unsigned(attr.getInt());
-      return defaultVal;
+
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+/// Block arguments are merged with the values flowing into them from
+/// predecessors, then results of ArmSME tile ops are merged with their tile
+/// operands, in each case only when the two live ranges do not overlap.
+/// Returns the distinct, non-empty live ranges sorted by start point. The
+/// `LiveRange`s in `initialLiveRanges` are updated in place and the returned
+/// pointers point into that map.
+///
+/// NOTE(review): merging is order-dependent (one merge can block a later one)
+/// and is driven by iterating `DenseMap`s keyed by `Value`, whose iteration
+/// order is not tied to program order; `std::sort` is also unstable for ranges
+/// with equal start points. The result may therefore differ between runs --
+/// consider a deterministic (program-order) iteration and tie-break.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+  // Maps each tile value to its current (possibly merged) live range.
+  // Initially every value points at its own entry in `initialLiveRanges`.
+  DenseMap<Value, LiveRange *> liveRanges;
+  for (auto &[value, liveRange] : initialLiveRanges) {
+    liveRanges.insert({value, &liveRange});
+  }
+
+  // Merge the live ranges of values `a` and `b` into one (if they do not
+  // overlap). After this, the values `a` and `b` will both point to the same
+  // live range (which will contain multiple values).
+  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+    LiveRange *aLiveRange = liveRanges.at(a);
+    LiveRange *bLiveRange = liveRanges.at(b);
+    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+      aLiveRange->unionWith(*bLiveRange);
+      // `bLiveRange` is left intact but no longer referenced by any value.
+      for (Value value : bLiveRange->values)
+        liveRanges[value] = aLiveRange;
+    }
+  };
+
+  // Merge the live ranges of new definitions with their tile operands.
+  auto unifyDefinitionsWithOperands = [&](Value value) {
+    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+    if (!armSMEOp)
+      return;
+    for (auto operand : armSMEOp->getOperands()) {
+      if (isValidSMETileVectorType(operand.getType()))
+        mergeValuesIfNonOverlapping(value, operand);
+    }
+  };
+
+  // Merge the live ranges of block arguments with their predecessors.
+  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+    auto blockArg = dyn_cast<BlockArgument>(value);
+    if (!blockArg)
+      return;
+    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+    });
+  };
+
+  // Applies `rule` to every tile value (the keys of `initialLiveRanges`).
+  auto applyRule = [&](auto rule) {
+    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+  };
+
+  // Unify as many live ranges as we can. This prevents unnecessary moves.
+  applyRule(unifyBlockArgumentsWithPredecessors);
+  applyRule(unifyDefinitionsWithOperands);
+
+  // Remove duplicate live range entries.
+  SetVector<LiveRange *> uniqueLiveRanges;
+  for (auto [_, liveRange] : liveRanges) {
+    if (!liveRange->empty())
+      uniqueLiveRanges.insert(liveRange);
+  }
+
+  // Sort the new live ranges by starting point (ready for tile allocation).
+  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+            [](LiveRange *a, LiveRange *b) { return *a < *b; });
+  return std::move(coalescedLiveRanges);
+}
+
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+  TileAllocator tileAllocator;
+  SetVector<LiveRange *> allocatedRanges;
+
+  auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
+    unsigned memoryTileId = tileAllocator.allocateInMemoryTileId();
+    auto spillActiveRange = [&](LiveRange *range) {
+      unsigned tileId = *range->tileId;
+      range->tileId = memoryTileId;
+      allocatedRanges.remove(range);
+      return tileId;
     };
-    auto setDiscardableIntAttr = [&](StringRef name, auto value) {
-      rewriter.modifyOpInPlace(tileOp, [&] {
-        func->setDiscardableAttr(name,
-                                 rewriter.getI32IntegerAttr((unsigned)value));
-      });
+
+    auto isTrivialSpill = [](LiveRange *allocatedRange) {
+      return allocatedRange->values.size() == 1 &&
+             isTriviallyCloneableTileOp(
+                 allocatedRange->values[0]
+                     .getDefiningOp<ArmSMETileOpInterface>());
     };
 
-    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
-    if (!tileType)
-      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
-    TileMask tilesInUse =
-        static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
-    auto tileId = allocateTileId(*tileType, tilesInUse);
-    bool tileIsInMemory = failed(tileId);
-    if (tileIsInMemory) {
-      // If we could not find a real tile ID, use an in-memory tile ID (ID >=
-      // 16). A later pass will insert the necessary spills and reloads.
-      tileId =
-          getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
-      tileOp->emitWarning(
-          "failed to allocate SME virtual tile to operation, all tile "
-          "operations will go through memory, expect degraded performance");
-    }
+    // Heuristic: Spill trivially copyable operations (usually free).
+    if (isTrivialSpill(newRange))
+      return memoryTileId;
+    auto trivialSpill = llvm::find_if(allocatedRanges, isTrivialSpill);
+    if (trivialSpill != allocatedRanges.end())
+      return spillActiveRange(*trivialSpill);
+
+    // Heuristic: Spill the live range that ends last.
+    LiveRange *lastActiveLiveRange = *std::max_element(
+        allocatedRanges.begin(), allocatedRanges.end(),
+        [](LiveRange *a, LiveRange *b) { return a->end() < b->end(); });
+    if (lastActiveLiveRange->end() >= newRange->end())
+      return spillActiveRange(lastActiveLiveRange);
+
+    return memoryTileId;
+  };
 
-    // Set all operations dependent on `tileOp` to use the same tile ID.
-    // This is a naive tile allocation scheme, but works for common cases. For
-    // example, as this only allocates tile IDs to existing ops, it can't solve
-    // cases like this (%tileA and %tileB come from different root operations):
-    //
-    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
-    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
-    // } else {
-    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
-    // }
-    //
-    // This case would require allocating a new tile for the result of the
-    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
-    // on the %some_cond).
-    // Find all the ops that (transitively) depend on this tile.
-    SetVector<Operation *> dependantOps;
-    findDependantOps(tileOp->getResult(0), dependantOps);
-    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        auto currentTileId = dependantTileOp.getTileId();
-        if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
-          return dependantTileOp.emitOpError(
-              "already assigned different SME virtual tile!");
+  for (LiveRange *newRange : liveRanges) {
+    // Release tiles from live ranges that have ended.
+    allocatedRanges.remove_if([&](LiveRange *allocatedRange) {
+      if (allocatedRange->end() <= newRange->start()) {
+        tileAllocator.releaseTileId(allocatedRange->getTileType(),
+                                    *allocatedRange->tileId);
+        return true;
----------------
MacDue wrote:

Not really... none of these names were good (including my old ones). 

- `liveRanges` -> `liveRangesSortedByStartPoint`
- `newRange` -> `nextRange` (it's the next live range in program order)
- `allocatedRange` -> `activeRange` 
- `allocatedRanges` -> `activeRanges` (these are the live ranges that are currently _active_, not all live ranges that have been allocated a tile ID)

https://github.com/llvm/llvm-project/pull/90448


More information about the Mlir-commits mailing list