[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)

Benjamin Maxwell llvmlistbot at llvm.org
Fri Nov 24 02:06:48 PST 2023


================
@@ -107,73 +109,85 @@ enum class TileMask : unsigned {
 };
 
 /// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
-  static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
-  static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
-                                                   TileMask::kZA1H};
-  static const SmallVector<TileMask> ZA_S_MASKS = {
-      TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
-  static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+  static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+  static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+  static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+                                            TileMask::kZA2S, TileMask::kZA3S};
+  static constexpr std::array ZA_D_MASKS = {
       TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
       TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
-  static const SmallVector<TileMask> ZA_Q_MASKS = {
+  static constexpr std::array ZA_Q_MASKS = {
       TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
       TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
       TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
       TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
-  switch (cast<IntegerType>(type).getWidth()) {
-  default:
-    llvm_unreachable("unexpected type!");
-  case 8:
+  switch (type) {
+  case ArmSMETileType::ZAB:
     return ZA_B_MASKS;
-  case 16:
+  case ArmSMETileType::ZAH:
     return ZA_H_MASKS;
-  case 32:
+  case ArmSMETileType::ZAS:
     return ZA_S_MASKS;
-  case 64:
+  case ArmSMETileType::ZAD:
     return ZA_D_MASKS;
-  case 128:
+  case ArmSMETileType::ZAQ:
     return ZA_Q_MASKS;
   }
 }
 
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
-                             unsigned &tileID) {
-  auto masks = getMasks(tileIDOp.getType());
-  for (const auto &it : llvm::enumerate(masks)) {
-    const auto tileMask = it.value();
+/// Allocates a tile to 'tileId' or returns an error if there are no tiles left.
+static FailureOr<unsigned> getTile(ArmSMETileType tileType,
+                                   TileMask &tilesInUse) {
+  auto masks = getMasks(tileType);
+  for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
     if ((tilesInUse & tileMask) == TileMask::kNone) {
       tilesInUse |= tileMask;
-      tileID = it.index();
-      return success();
+      return tileId;
     }
   }
-  return tileIDOp.emitError("ran out of SME virtual tiles!");
+  return failure();
 }
 
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+struct AssignTileIDsPattern
+    : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+  LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
                                 PatternRewriter &rewriter) const override {
-    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+    if (tileOp.getTileId())
+      return failure();
+
+    std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+    if (!tileType)
+      return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+    auto func = tileOp->getParentOfType<FunctionOpInterface>();
     TileMask tilesInUse;
-    if (auto tilesInUseAttr =
-            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+    if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
       tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
     else
       tilesInUse = TileMask::kNone;
 
-    unsigned tileID;
-    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
-      return failure();
+    auto tileId = getTile(*tileType, tilesInUse);
+    if (failed(tileId))
+      return tileOp.emitError("ran out of SME virtual tiles!");
 
-    funcOp->setAttr(kTilesInUseAttr,
-                    rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+    func->setAttr(kTilesInUseAttr,
+                  rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+
+    // Find all the ops that (transitively) depend on this tile.
+    SetVector<Operation *> dependantOps;
+    getForwardSlice(tileOp.getOperation(), &dependantOps);
+
+    // Set all operations to use the same tile ID.
+    // This is a navie tile allocation scheme, but works for common cases.
----------------
MacDue wrote:

I think that's correct. What I mean is that correctly allocating tiles in all cases would require more inspection of the program, as well as inserting new ops such as tile moves.

For example, you could end up with a case like:
```
%tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
   scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32>
} else {
   scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32>
}
```
This would need to become something like:
```
%ifResult = arm_sme.get_tile  {tile_id = 2} : vector<[4]x[4]xi32>
%tile = scf.if %some_cond -> vector<[4]x[4]xi32> {
   arm_sme.move_tile %ifResult <- %tileA : vector<[4]x[4]xi32>
   scf.yield %ifResult {tile_id = 2} : vector<[4]x[4]xi32>
} else {
   arm_sme.move_tile %ifResult <- %tileB : vector<[4]x[4]xi32>
   scf.yield %ifResult {tile_id = 2} : vector<[4]x[4]xi32>
}
```

https://github.com/llvm/llvm-project/pull/73253


More information about the Mlir-commits mailing list