[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Nov 28 12:44:04 PST 2023


================
@@ -107,73 +107,151 @@ enum class TileMask : unsigned {
 };
 
 /// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
-  static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
-  static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
-                                                   TileMask::kZA1H};
-  static const SmallVector<TileMask> ZA_S_MASKS = {
-      TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
-  static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+                                            TileMask::kZA2S, TileMask::kZA3S};
+  static constexpr std::array ZA_D_MASKS = {
       TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
       TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
-  static const SmallVector<TileMask> ZA_Q_MASKS = {
+  static constexpr std::array ZA_Q_MASKS = {
       TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
       TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
       TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
       TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
-  switch (cast<IntegerType>(type).getWidth()) {
-  default:
-    llvm_unreachable("unexpected type!");
-  case 8:
+  switch (type) {
+  case ArmSMETileType::ZAB:
     return ZA_B_MASKS;
-  case 16:
+  case ArmSMETileType::ZAH:
     return ZA_H_MASKS;
-  case 32:
+  case ArmSMETileType::ZAS:
     return ZA_S_MASKS;
-  case 64:
+  case ArmSMETileType::ZAD:
     return ZA_D_MASKS;
-  case 128:
+  case ArmSMETileType::ZAQ:
     return ZA_Q_MASKS;
   }
 }
 
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
-                             unsigned &tileID) {
-  auto masks = getMasks(tileIDOp.getType());
-  for (const auto &it : llvm::enumerate(masks)) {
-    const auto tileMask = it.value();
+/// Allocates and returns a tile ID. Returns an error if there are no tiles
+/// left.
+static FailureOr<unsigned> allocateTileId(ArmSMETileType tileType,
+                                          TileMask &tilesInUse) {
+  auto masks = getMasks(tileType);
+  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
     if ((tilesInUse & tileMask) == TileMask::kNone) {
       tilesInUse |= tileMask;
-      tileID = it.index();
-      return success();
+      return tileId;
     }
   }
-  return tileIDOp.emitError("ran out of SME virtual tiles!");
+  return failure();
 }
 
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+/// Collects transitive uses of a root value through control flow. This can
+/// handle basic SCF constructs, along with control flow (br and cond_br).
+/// Simple loops work at the SCF level, while more complex control flow can be
+/// dealt with after lowering to CF. This can be used to implement basic tile
+/// allocation.
+static void findDependantOps(Value rootValue,
+                             SetVector<Operation *> &dependantOps) {
+  auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) {
+    for (auto [idx, value] : llvm::enumerate(inputValues)) {
+      if (value != rootValue)
+        continue;
+      findDependantOps(exitValues[idx], dependantOps);
+    }
+  };
+  for (Operation *user : rootValue.getUsers()) {
+    if (dependantOps.contains(user))
+      continue;
+    dependantOps.insert(user);
+    if (auto branchOp = llvm::dyn_cast<cf::BranchOp>(user)) {
+      // (CF) Follow branch.
+      traverseCorrespondingValues(branchOp.getDestOperands(),
+                                  branchOp.getDest()->getArguments());
+    } else if (auto condBranchOp = llvm::dyn_cast<cf::CondBranchOp>(user)) {
+      // (CF) Follow true branch.
+      traverseCorrespondingValues(condBranchOp.getTrueOperands(),
+                                  condBranchOp.getTrueDest()->getArguments());
+      // (CF) Follow false branch.
+      traverseCorrespondingValues(condBranchOp.getFalseOperands(),
+                                  condBranchOp.getFalseDest()->getArguments());
+    } else if (auto loop = llvm::dyn_cast<LoopLikeOpInterface>(user)) {
+      // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
+      traverseCorrespondingValues(loop.getInits(), loop.getRegionIterArgs());
+    } else if (user->hasTrait<OpTrait::ReturnLike>()) {
+      // (SCF) Follow yields of (basic) control flow (e.g. for loops).
+      auto parent = user->getParentOp();
+      // Don't traverse outside a function.
+      if (llvm::isa<FunctionOpInterface>(parent))
+        continue;
+      traverseCorrespondingValues(user->getOperands(), parent->getResults());
+    } else {
+      // Otherwise, assume users of _any_ result are dependant.
+      for (Value result : user->getResults())
+        findDependantOps(result, dependantOps);
+    }
+  }
+}
+
+struct AssignTileIDsPattern
+    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
                                 PatternRewriter &rewriter) const override {
-    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+    if (tileOp.getTileId())
+      return failure();
+
+    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+    if (!tileType)
+      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
     TileMask tilesInUse;
-    if (auto tilesInUseAttr =
-            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+    if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
       tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
     else
       tilesInUse = TileMask::kNone;
 
-    unsigned tileID;
-    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
-      return failure();
+    auto tileId = allocateTileId(*tileType, tilesInUse);
+    if (failed(tileId))
+      return tileOp.emitError("ran out of SME virtual tiles!");
+
+    func->setAttr(kTilesInUseAttr,
+                  rewriter.getI32IntegerAttr((unsigned)tilesInUse));
 
-    funcOp->setAttr(kTilesInUseAttr,
-                    rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+    // Find all the ops that (transitively) depend on this tile.
+    SetVector<Operation *> dependantOps;
+    findDependantOps(tileOp->getResult(0), dependantOps);
+
+    // Set all operations to use the same tile ID.
+    // This is a naive tile allocation scheme, but works for common cases. For
+    // example, as this only allocates tile IDs to existing ops, it can't solve
+    // cases like:
+    //
+    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
+    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
+    // } else {
+    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
+    // }
+    //
+    // Where %tileA and %tileB come from different root operations. This case
----------------
banach-space wrote:

```suggestion
    // cases like this (%tileA and %tileB come from different root operations):
    //
    // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
    //   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
    // } else {
    //   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
    // }
    //
    // This case ...
```

I suggest moving this sentence higher up; otherwise the comment reads a bit awkwardly. This is a [nit].

https://github.com/llvm/llvm-project/pull/73253


More information about the Mlir-commits mailing list