[Mlir-commits] [mlir] [mlir][ArmSME] Use liveness information in the tile allocator (PR #90448)

Andrzej Warzyński llvmlistbot at llvm.org
Wed May 8 08:59:09 PDT 2024


================
@@ -137,172 +129,551 @@ static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
   }
 }
 
-/// Allocates and returns a tile ID. Returns an error if there are no tiles
-/// left.
-static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
-                                          TileMask &tilesInUse) {
-  auto masks = getMasks(tileType);
-  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
-    if ((tilesInUse & tileMask) == TileMask::kNone) {
-      tilesInUse |= tileMask;
-      return tileId;
+class TileAllocator {
+public:
+  /// Allocates and returns a tile ID. Fails if there are no tiles left.
+  FailureOr<unsigned> allocateTileId(ArmSMETileType tileType) {
+    auto masks = getMasks(tileType);
+    for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
+      if ((tilesInUse & tileMask) == TileMask::kNone) {
+        tilesInUse |= tileMask;
+        return tileId;
+      }
     }
+    return failure();
+  }
+
+  /// Releases a previously allocated tile ID.
+  void releaseTileId(ArmSMETileType tileType, unsigned tileId) {
+    TileMask tileMask = getMasks(tileType)[tileId];
+    assert((tilesInUse & tileMask) != TileMask::kNone &&
+           "cannot release unallocated tile!");
+    tilesInUse ^= tileMask;
+  }
+
+  /// Allocates an in-memory tile ID.
+  unsigned allocateInMemoryTileId() {
+    // Note: We never release in-memory tile IDs. We could, which may allow
+    // reusing an allocation, but as we _never_ want to spill an SME tile this
+    // is not optimized.
+    return nextInMemoryTileId++;
   }
-  return failure();
-}
 
-/// Collects transitive uses of a root value through control flow. This can
-/// handle basic SCF constructs, along with control flow (br and cond_br).
-/// Simple loops work at the SCF level, while more complex control flow can be
-/// dealt with after lowering to CF. This is used to implement basic tile
-/// allocation.
-static void findDependantOps(Value rootValue,
-                             SetVector<Operation *> &dependantOps) {
-  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
-    for (auto [idx, value] : llvm::enumerate(inputValues)) {
-      if (value == rootValue)
-        findDependantOps(exitValues[idx], dependantOps);
+private:
+  TileMask tilesInUse = TileMask::kNone;
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///    cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+///  ^bb1_copy:
+///    cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ^bb2_copy:
+///    cf.br ^bb2
+///  ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
     }
+  });
+
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
   };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    rewriter.modifyOpInPlace(condBranch, [&] {
+      condBranch.getFalseDestOperandsMutable().clear();
+      condBranch.getTrueDestOperandsMutable().clear();
+      condBranch.setSuccessor(newTrueBranch, 0);
+      condBranch.setSuccessor(newFalseBranch, 1);
+    });
+  }
+}
+
+/// Splits conditional branches (see `splitCondBranches`), then inserts tile
+/// copies at `cf.br` operations.
+///
+///  BEFORE:
+///  ```mlir
+///  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+///  ```
+void insertCopiesAtBranches(IRRewriter &rewriter,
+                            FunctionOpInterface function) {
+  splitCondBranches(rewriter, function);
+  for (Block &block : function.getBlocks()) {
+    Operation *terminator = block.getTerminator();
+    if (!isa<cf::BranchOp>(terminator))
       continue;
-    dependantOps.insert(user);
-    TypeSwitch<Operation *>(user)
-        .Case<cf::BranchOp>([&](auto branchOp) {
-          // (CF) Follow branch.
-          traverseCorrespondingValues(branchOp.getDestOperands(),
-                                      branchOp.getDest()->getArguments());
-        })
-        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
-          // (CF) Follow true branch.
-          traverseCorrespondingValues(
-              condBranchOp.getTrueOperands(),
-              condBranchOp.getTrueDest()->getArguments());
-          // (CF) Follow false branch.
-          traverseCorrespondingValues(
-              condBranchOp.getFalseOperands(),
-              condBranchOp.getFalseDest()->getArguments());
-        })
-        .Case<LoopLikeOpInterface>([&](auto loopOp) {
-          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
-          traverseCorrespondingValues(loopOp.getInits(),
-                                      loopOp.getRegionIterArgs());
-        })
-        .Case<scf::YieldOp>([&](auto yieldOp) {
-          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
-          auto parent = user->getParentOp();
-          traverseCorrespondingValues(user->getOperands(),
-                                      parent->getResults());
+    rewriter.setInsertionPoint(terminator);
+    for (OpOperand &operand : terminator->getOpOperands()) {
+      if (isValidSMETileVectorType(operand.get().getType())) {
+        auto copy =
+            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
+      }
+    }
+  }
+}
+
+/// A live range for a (collection of) tile values. A live range is built up of
+/// intervals [start, end) which represent parts of the program where the value
+/// needs to be live (i.e. in an SME virtual tile).
+struct LiveRange {
+  using RangeSet = llvm::IntervalMap<uint64_t, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  unsigned start() const { return ranges->start(); }
+  unsigned end() const { return ranges->stop(); }
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  std::unique_ptr<RangeSet> ranges;
+  SetVector<Value> values;
+  std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  unsigned index = 0;
+  SetVector<Block *> blocks =
+      getTopologicallySortedBlocks(function.getFunctionBody());
+  DenseMap<Operation *, unsigned> operationToIndexMap;
+  for (Block *block : blocks) {
+    index++; // We want block args to have their own number.
+    for (Operation &op : block->getOperations()) {
+      // This is only correct if all ArmSME have been converted to CF.
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
+      });
+#endif
+      operationToIndexMap.try_emplace(&op, index++);
+    }
+  }
+  return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  DenseMap<Value, LiveRange> liveRanges;
+  /// Defines or updates a live range for an SME tile value. Live-ins may update
+  /// an existing live range (rather than define a new one).
+  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
+                                          LivenessBlockInfo const &livenessInfo,
+                                          bool liveAtBlockEntry = false) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    // Find or create a live range for `value`.
+    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
+    LiveRange &valueLiveRange = it->second;
+    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
+    unsigned start =
+        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+    unsigned end = operationToIndexMap.at(lastUseInBlock);
+    valueLiveRange.insert(value, start, end);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
+    // Handle block arguments:
+    for (Value argument : block.getArguments())
+      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
+    // Handle live-ins:
+    for (Value liveIn : livenessInfo->in())
+      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
+    // Handle new definitions:
+    for (Operation &op : block) {
+      for (Value result : op.getResults())
+        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
+    }
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+                                        function_ref<void(Value)> callback) {
+  Block *block = blockArg.getOwner();
+  unsigned argNumber = blockArg.getArgNumber();
+  for (Block *pred : block->getPredecessors()) {
+    TypeSwitch<Operation *>(pred->getTerminator())
+        .Case<cf::BranchOp>([&](auto branch) {
+          Value predecessorOperand = branch.getDestOperands()[argNumber];
+          callback(predecessorOperand);
         })
-        .Default([&](auto) {
-          // Otherwise, assume users of _any_ result are dependant.
-          for (Value result : user->getResults())
-            findDependantOps(result, dependantOps);
+        .Case<cf::CondBranchOp>([&](auto condBranch) {
+          if (condBranch.getFalseDest() == block) {
+            Value predecessorOperand =
+                condBranch.getFalseDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
+          if (condBranch.getTrueDest() == block) {
+            Value predecessorOperand =
+                condBranch.getTrueDestOperands()[argNumber];
+            callback(predecessorOperand);
+          }
         });
   }
 }
-struct AssignTileIDsPattern
-    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
-                                PatternRewriter &rewriter) const override {
-    if (tileOp.getTileId())
-      return failure();
-
-    auto func = tileOp->getParentOfType<FunctionOpInterface>();
-    auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) {
-      if (auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
-              func->getDiscardableAttr(name)))
-        return unsigned(attr.getInt());
-      return defaultVal;
+
+/// Coalesce live ranges where it would prevent unnecessary tile moves.
+SmallVector<LiveRange *>
+coalesceTileLiveRanges(DenseMap<Value, LiveRange> &initialLiveRanges) {
+  DenseMap<Value, LiveRange *> liveRanges;
+  for (auto &[value, liveRange] : initialLiveRanges) {
+    liveRanges.insert({value, &liveRange});
+  }
+
+  auto mergeValuesIfNonOverlapping = [&](Value a, Value b) {
+    LiveRange *aLiveRange = liveRanges.at(a);
+    LiveRange *bLiveRange = liveRanges.at(b);
+    if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) {
+      aLiveRange->unionWith(*bLiveRange);
+      for (Value value : bLiveRange->values)
+        liveRanges[value] = aLiveRange;
+    }
+  };
+
+  // Merge the live ranges of new definitions with their tile operands.
+  auto unifyDefinitionsWithOperands = [&](Value value) {
+    auto armSMEOp = value.getDefiningOp<ArmSMETileOpInterface>();
+    if (!armSMEOp)
+      return;
+    for (auto operand : armSMEOp->getOperands()) {
+      if (isValidSMETileVectorType(operand.getType()))
+        mergeValuesIfNonOverlapping(value, operand);
+    }
+  };
+
+  // Merge the live ranges of block arguments with their predecessors.
+  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+    auto blockArg = dyn_cast<BlockArgument>(value);
+    if (!blockArg)
+      return;
+    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+    });
+  };
+
+  auto applyRule = [&](auto rule) {
+    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+  };
+
+  // Unify as many live ranges as we can. This prevents unnecessary moves.
+  applyRule(unifyBlockArgumentsWithPredecessors);
+  applyRule(unifyDefinitionsWithOperands);
+
+  // Remove duplicate live range entries.
+  SetVector<LiveRange *> uniqueLiveRanges;
+  for (auto [_, liveRange] : liveRanges) {
+    if (!liveRange->empty())
+      uniqueLiveRanges.insert(liveRange);
+  }
+
+  // Sort the new live ranges by starting point (ready for tile allocation).
+  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+            [](LiveRange *a, LiveRange *b) { return *a < *b; });
+  return std::move(coalescedLiveRanges);
+}
+
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
+/// Note: This does not attempt to fill holes in live/allocated ranges.
+void allocateTilesToLiveRanges(ArrayRef<LiveRange *> liveRanges) {
+  TileAllocator tileAllocator;
+  SetVector<LiveRange *> allocatedRanges;
+
+  auto chooseSpillUsingHeuristics = [&](LiveRange *newRange) {
+    unsigned memoryTileId = tileAllocator.allocateInMemoryTileId();
+    auto spillActiveRange = [&](LiveRange *range) {
+      unsigned tileId = *range->tileId;
+      range->tileId = memoryTileId;
+      allocatedRanges.remove(range);
+      return tileId;
     };
-    auto setDiscardableIntAttr = [&](StringRef name, auto value) {
-      rewriter.modifyOpInPlace(tileOp, [&] {
-        func->setDiscardableAttr(name,
-                                 rewriter.getI32IntegerAttr((unsigned)value));
-      });
+
+    auto isTrivialSpill = [](LiveRange *allocatedRange) {
+      return allocatedRange->values.size() == 1 &&
+             isTriviallyCloneableTileOp(
+                 allocatedRange->values[0]
+                     .getDefiningOp<ArmSMETileOpInterface>());
     };
 
-    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
-    if (!tileType)
-      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
-
-    TileMask tilesInUse =
-        static_cast<TileMask>(getDiscardableIntAttr(kTilesInUseAttr));
-    auto tileId = allocateTileId(*tileType, tilesInUse);
-    bool tileIsInMemory = failed(tileId);
-    if (tileIsInMemory) {
-      // If we could not find a real tile ID, use an in-memory tile ID (ID >=
-      // 16). A later pass will insert the necessary spills and reloads.
-      tileId =
-          getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase);
-      tileOp->emitWarning(
-          "failed to allocate SME virtual tile to operation, all tile "
-          "operations will go through memory, expect degraded performance");
-    }
+    // Heuristic: Spill trivially copyable operations (usually free).
+    if (isTrivialSpill(newRange))
+      return memoryTileId;
+    auto trivialSpill = llvm::find_if(allocatedRanges, isTrivialSpill);
+    if (trivialSpill != allocatedRanges.end())
+      return spillActiveRange(*trivialSpill);
 
-    // Set all operations dependent on `tileOp` to use the same tile ID.
-    // This is a naive tile allocation scheme, but works for common cases. For
-    // example, as this only allocates tile IDs to existing ops, it can't solve
-    // cases like this (%tileA and %tileB come from different root operations):
-    //
-    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
-    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
-    // } else {
-    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
-    // }
-    //
-    // This case would require allocating a new tile for the result of the
-    // scf.if, and moving the contents of %tileA or %tileB to result tile (based
-    // on the %some_cond).
-    // Find all the ops that (transitively) depend on this tile.
-    SetVector<Operation *> dependantOps;
-    findDependantOps(tileOp->getResult(0), dependantOps);
-    auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId);
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
-        auto currentTileId = dependantTileOp.getTileId();
-        if (currentTileId && unsigned(currentTileId.getInt()) != tileId)
-          return dependantTileOp.emitOpError(
-              "already assigned different SME virtual tile!");
+    // Heuristic: Spill the live range that ends last.
+    LiveRange *lastActiveLiveRange = *std::max_element(
+        allocatedRanges.begin(), allocatedRanges.end(),
+        [](LiveRange *a, LiveRange *b) { return a->end() < b->end(); });
+    if (lastActiveLiveRange->end() >= newRange->end())
+      return spillActiveRange(lastActiveLiveRange);
+
+    return memoryTileId;
+  };
+
+  for (LiveRange *newRange : liveRanges) {
+    // Release tiles from live ranges that have ended.
+    allocatedRanges.remove_if([&](LiveRange *allocatedRange) {
+      if (allocatedRange->end() <= newRange->start()) {
+        tileAllocator.releaseTileId(allocatedRange->getTileType(),
+                                    *allocatedRange->tileId);
+        return true;
       }
-    }
+      return false;
+    });
 
-    // Rewrite IR.
-    if (!tileIsInMemory)
-      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
+    // Allocate a tile ID to `newRange`.
+    auto tileId = tileAllocator.allocateTileId(newRange->getTileType());
+    if (succeeded(tileId))
+      newRange->tileId = *tileId;
     else
-      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
-    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+      newRange->tileId = chooseSpillUsingHeuristics(newRange);
+
+    // Insert the live range into the allocated ranges.
+    if (newRange->tileId < kInMemoryTileIdBase)
+      allocatedRanges.insert(newRange);
+  }
+}
+
+/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts.
+LogicalResult assignTileIdsAndResolveTrivialConflicts(
+    IRRewriter &rewriter, FunctionOpInterface function,
+    ArrayRef<LiveRange *> allocatedLiveRanges) {
+  for (LiveRange const *liveRange : allocatedLiveRanges) {
+    auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId);
+    auto isAllocatedToSameTile = [&](Value value) {
+      if (auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+          tileOp && tileOp.getTileId() == tileIdAttr)
+        return true;
+      return liveRange->values.contains(value);
+    };
+    for (Value value : liveRange->values) {
+      for (Operation *user : value.getUsers()) {
+        if (auto tileOp = dyn_cast<ArmSMETileOpInterface>(user)) {
+          // Ensure ArmSME ops that don't produce a value still get a tile ID.
+          if (!hasTileResult(tileOp))
+            rewriter.modifyOpInPlace(tileOp,
+                                     [&] { tileOp.setTileId(tileIdAttr); });
+        }
+      }
+      auto copyOp = value.getDefiningOp<CopyTileOp>();
+      if (copyOp && isAllocatedToSameTile(copyOp.getTile())) {
----------------
banach-space wrote:

> I think you may be misunderstanding that guideline. It's for doing early exits when checking preconditions (which is not what this is doing).

The guidelines leave room for interpretation, and that's intentional. There's always a subjective element to all of this - I am saying that **for me** it is hard to extract the structure of this nested `if/else-if` block (and what the actual cases are). So I'm asking for this to be clarified and simplified.

> I'll add some more comments 👍

Yes, please. Though splitting the actual blocks (via `continue` as suggested above) would also help reduce the cognitive load. I might be mistaken, but it sounds like, for you, the code would be equally readable.

https://github.com/llvm/llvm-project/pull/90448


More information about the Mlir-commits mailing list