[Mlir-commits] [mlir] [mlir][ArmSME] Switch to an attribute-based tile allocation scheme (PR #73253)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Nov 24 02:07:29 PST 2023
================
@@ -107,73 +109,85 @@ enum class TileMask : unsigned {
};
/// Returns the set of masks relevant for the given type.
-static ArrayRef<TileMask> getMasks(Type type) {
- static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
- static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
- TileMask::kZA1H};
- static const SmallVector<TileMask> ZA_S_MASKS = {
- TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
- static const SmallVector<TileMask> ZA_D_MASKS = {
+static ArrayRef<TileMask> getMasks(ArmSMETileType type) {
+ static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
+ static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
+ static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
+ TileMask::kZA2S, TileMask::kZA3S};
+ static constexpr std::array ZA_D_MASKS = {
TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
- static const SmallVector<TileMask> ZA_Q_MASKS = {
+ static constexpr std::array ZA_Q_MASKS = {
TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
- switch (cast<IntegerType>(type).getWidth()) {
- default:
- llvm_unreachable("unexpected type!");
- case 8:
+ switch (type) {
+ case ArmSMETileType::ZAB:
return ZA_B_MASKS;
- case 16:
+ case ArmSMETileType::ZAH:
return ZA_H_MASKS;
- case 32:
+ case ArmSMETileType::ZAS:
return ZA_S_MASKS;
- case 64:
+ case ArmSMETileType::ZAD:
return ZA_D_MASKS;
- case 128:
+ case ArmSMETileType::ZAQ:
return ZA_Q_MASKS;
}
+ // No `default:` above so -Wswitch still flags unhandled enumerators; this
+ // covers invalid enum values and silences -Wreturn-type on GCC.
+ llvm_unreachable("unknown type in getMasks");
}
-/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
-static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
- unsigned &tileID) {
- auto masks = getMasks(tileIDOp.getType());
- for (const auto &it : llvm::enumerate(masks)) {
- const auto tileMask = it.value();
+/// Allocates the first free tile of the given type and returns its tile ID.
+/// On success the chosen tile's mask is OR'd into 'tilesInUse'. Returns
+/// failure (without emitting a diagnostic) if no tile of that type is free;
+/// the caller is responsible for reporting the error.
+static FailureOr<unsigned> getTile(ArmSMETileType tileType,
+ TileMask &tilesInUse) {
+ auto masks = getMasks(tileType);
+ for (auto [tileId, tileMask] : llvm::enumerate(masks)) {
if ((tilesInUse & tileMask) == TileMask::kNone) {
tilesInUse |= tileMask;
- tileID = it.index();
- return success();
+ return tileId;
}
}
- return tileIDOp.emitError("ran out of SME virtual tiles!");
+ return failure();
}
-struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(GetTileID tileIDOp,
+struct AssignTileIDsPattern
+ : public OpInterfaceRewritePattern<ArmSMETileOpInterface> {
+ using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+ LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp,
PatternRewriter &rewriter) const override {
- auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+ if (tileOp.getTileId())
+ return failure();
+
+ std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
+ if (!tileType)
+ return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile");
+
+ auto func = tileOp->getParentOfType<FunctionOpInterface>();
TileMask tilesInUse;
- if (auto tilesInUseAttr =
- funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+ if (auto tilesInUseAttr = func->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
else
tilesInUse = TileMask::kNone;
- unsigned tileID;
- if (failed(getTile(tileIDOp, tilesInUse, tileID)))
- return failure();
+ auto tileId = getTile(*tileType, tilesInUse);
+ if (failed(tileId))
+ return tileOp.emitError("ran out of SME virtual tiles!");
- funcOp->setAttr(kTilesInUseAttr,
- rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+ func->setAttr(kTilesInUseAttr,
+ rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+
+ // Find all the ops that (transitively) depend on this tile.
+ SetVector<Operation *> dependantOps;
+ getForwardSlice(tileOp.getOperation(), &dependantOps);
+
+ // Set all operations to use the same tile ID.
+ // This is a naive tile allocation scheme, but works for common cases.
----------------
MacDue wrote:
But this simple allocation works for the use cases we currently need :)
https://github.com/llvm/llvm-project/pull/73253
More information about the Mlir-commits
mailing list