[Mlir-commits] [mlir] fb54fec - [mlir][ArmSME] Implement tile allocation

Cullen Rhodes llvmlistbot at llvm.org
Tue Jul 18 01:46:54 PDT 2023


Author: Cullen Rhodes
Date: 2023-07-18T08:46:40Z
New Revision: fb54fec7263ab420e9e02e1ea2b66403cf655a33

URL: https://github.com/llvm/llvm-project/commit/fb54fec7263ab420e9e02e1ea2b66403cf655a33
DIFF: https://github.com/llvm/llvm-project/commit/fb54fec7263ab420e9e02e1ea2b66403cf655a33.diff

LOG: [mlir][ArmSME] Implement tile allocation

This patch adds a pass '-allocate-arm-sme-tiles' to the ArmSME dialect
that implements allocation of SME ZA tiles.

It does this at the 'func.func' op level by replacing
'arm_sme.get_tile_id' ops with 'arith.constant' ops that represent the
tile number. The tiles in use in a given function are tracked by an
integer function attribute 'arm_sme.tiles_in_use' that is a 16-bit tile
mask with a bit for each 128-bit element tile (ZA0.Q-ZA15.Q), the
smallest ZA tile granule. This is initialized on the first
'arm_sme.get_tile_id' rewrite and updated on each subsequent rewrite.
Mixing of different element tile types is supported.

Section B2.3.2 of the SME spec [1] describes how the 128-bit element
tiles overlap with other element tiles.

Depends on D154941

[1] https://developer.arm.com/documentation/ddi0616/aa

Reviewed By: awarzynski

Differential Revision: https://reviews.llvm.org/D154955

Added: 
    mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
    mlir/test/Dialect/ArmSME/tile-allocation.mlir

Modified: 
    mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
    mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
    mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
    mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
    mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 00ac5376ed7d98..133968b60665b0 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -30,6 +30,9 @@ std::unique_ptr<Pass>
 createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
                              const bool enableZA = false);
 
+/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
+std::unique_ptr<Pass> createTileAllocationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7bc39e0534b8a8..3fa1b43eb9e67e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -39,4 +39,16 @@ def EnableArmStreaming
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def TileAllocation
+    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
+  let summary = "Allocate SME tiles";
+  let description = [{
+    This pass does tile allocation for SME "virtual tiles". It is run at the
+    'func.func' op level, replacing 'arm_sme.get_tile_id' ops with (i32) tile
+    ids. An error will be emitted when there are no tiles left.
+  }];
+  let constructor = "mlir::arm_sme::createTileAllocationPass()";
+  let dependentDialects = ["func::FuncDialect"];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 9f4c3a0ce51a1e..247da2a3a4aa11 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
+  TileAllocation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index cb556d8d4dfe6e..e837432410de89 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -50,17 +50,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
     return success();
   }
 };
-
-struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> {
-  using ConvertOpToLLVMPattern<GetTileID>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(GetTileID op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // TODO: implement tile allocation, currently only tile 0 is supported.
-    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, rewriter.getI32Type(), 0);
-    return success();
-  }
-};
 } // namespace
 
 /// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return

diff  --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
new file mode 100644
index 00000000000000..e0462a6dc12413
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -0,0 +1,198 @@
+//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass allocates SME tiles at the 'func.func' op level for
+// 'arm_sme.get_tile_id' ops. It does this using a 16-bit tile mask that has a
+// bit for each 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile
+// granule.
+//
+// The 128-bit tiles overlap with other element tiles as follows (see section
+// B2.3.2 of SME spec [1]):
+//
+//   Tile    Overlaps
+//   ---------------------------------------------------------------------------
+//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
+//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
+//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
+//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
+//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
+//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
+//   ZA0.D   ZA0.Q, ZA8.Q
+//   ZA1.D   ZA1.Q, ZA9.Q
+//   ZA2.D   ZA2.Q, ZA10.Q
+//   ZA3.D   ZA3.Q, ZA11.Q
+//   ZA4.D   ZA4.Q, ZA12.Q
+//   ZA5.D   ZA5.Q, ZA13.Q
+//   ZA6.D   ZA6.Q, ZA14.Q
+//   ZA7.D   ZA7.Q, ZA15.Q
+//
+// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
+// that is initialized during the first 'arm_sme.get_tile_id' rewrite and
+// updated on each subsequent rewrite.
+//
+// [1] https://developer.arm.com/documentation/ddi0616/aa
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#define DEBUG_TYPE "allocate-arm-sme-tiles"
+
+namespace mlir {
+namespace arm_sme {
+#define GEN_PASS_DEF_TILEALLOCATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace arm_sme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+static constexpr char kTilesInUseAttr[] = "arm_sme.tiles_in_use";
+
+enum class TileMask : unsigned {
+  // clang-format off
+  kZA0B  = 0xffff, // 1111 1111 1111 1111
+
+  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
+  kZA1H  = 0x5555, // 0101 0101 0101 0101
+
+  kZA0S  = 0x8888, // 1000 1000 1000 1000
+  kZA1S  = 0x4444, // 0100 0100 0100 0100
+  kZA2S  = 0x2222, // 0010 0010 0010 0010
+  kZA3S  = 0x1111, // 0001 0001 0001 0001
+
+  kZA0D  = 0x8080, // 1000 0000 1000 0000
+  kZA1D  = 0x4040, // 0100 0000 0100 0000
+  kZA2D  = 0x2020, // 0010 0000 0010 0000
+  kZA3D  = 0x1010, // 0001 0000 0001 0000
+  kZA4D  = 0x808,  // 0000 1000 0000 1000
+  kZA5D  = 0x404,  // 0000 0100 0000 0100
+  kZA6D  = 0x202,  // 0000 0010 0000 0010
+  kZA7D  = 0x101,  // 0000 0001 0000 0001
+
+  kZA0Q  = 0x8000, // 1000 0000 0000 0000
+  kZA1Q  = 0x4000, // 0100 0000 0000 0000
+  kZA2Q  = 0x2000, // 0010 0000 0000 0000
+  kZA3Q  = 0x1000, // 0001 0000 0000 0000
+  kZA4Q  = 0x800,  // 0000 1000 0000 0000
+  kZA5Q  = 0x400,  // 0000 0100 0000 0000
+  kZA6Q  = 0x200,  // 0000 0010 0000 0000
+  kZA7Q  = 0x100,  // 0000 0001 0000 0000
+  kZA8Q  = 0x80,   // 0000 0000 1000 0000
+  kZA9Q  = 0x40,   // 0000 0000 0100 0000
+  kZA10Q = 0x20,   // 0000 0000 0010 0000
+  kZA11Q = 0x10,   // 0000 0000 0001 0000
+  kZA12Q = 0x8,    // 0000 0000 0000 1000
+  kZA13Q = 0x4,    // 0000 0000 0000 0100
+  kZA14Q = 0x2,    // 0000 0000 0000 0010
+  kZA15Q = 0x1,    // 0000 0000 0000 0001
+
+  kNone = 0x0,     // 0000 0000 0000 0000
+  // clang-format on
+
+  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
+};
+
+/// Returns the set of masks relevant for the given type.
+static ArrayRef<TileMask> getMasks(Type type) {
+  static const SmallVector<TileMask> ZA_B_MASKS = {TileMask::kZA0B};
+  static const SmallVector<TileMask> ZA_H_MASKS = {TileMask::kZA0H,
+                                                   TileMask::kZA1H};
+  static const SmallVector<TileMask> ZA_S_MASKS = {
+      TileMask::kZA0S, TileMask::kZA1S, TileMask::kZA2S, TileMask::kZA3S};
+  static const SmallVector<TileMask> ZA_D_MASKS = {
+      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
+      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
+  static const SmallVector<TileMask> ZA_Q_MASKS = {
+      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
+      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
+      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
+      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
+  switch (cast<IntegerType>(type).getWidth()) {
+  default:
+    llvm_unreachable("unexpected type!");
+  case 8:
+    return ZA_B_MASKS;
+  case 16:
+    return ZA_H_MASKS;
+  case 32:
+    return ZA_S_MASKS;
+  case 64:
+    return ZA_D_MASKS;
+  case 128:
+    return ZA_Q_MASKS;
+  }
+}
+
+/// Allocates a tile to 'tileID' or returns an error if there are no tiles left.
+static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
+                             unsigned &tileID) {
+  auto masks = getMasks(tileIDOp.getType());
+  for (const auto &it : llvm::enumerate(masks)) {
+    const auto tileMask = it.value();
+    if ((tilesInUse & tileMask) == TileMask::kNone) {
+      tilesInUse |= tileMask;
+      tileID = it.index();
+      return success();
+    }
+  }
+  return tileIDOp.emitError("ran out of SME virtual tiles!");
+}
+
+struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+                                PatternRewriter &rewriter) const override {
+    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+    TileMask tilesInUse;
+    if (auto tilesInUseAttr =
+            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+      tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
+    else
+      tilesInUse = TileMask::kNone;
+
+    unsigned tileID;
+    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
+      return failure();
+
+    funcOp->setAttr(kTilesInUseAttr,
+                    rewriter.getI32IntegerAttr((unsigned)tilesInUse));
+
+    auto tileType = tileIDOp.getType();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        tileIDOp, tileType, rewriter.getIntegerAttr(tileType, tileID));
+    return success();
+  }
+};
+
+struct TileAllocationPass
+    : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+    patterns.add<GetTileIDConversion>(patterns.getContext());
+    target.addLegalOp<arith::ConstantOp>();
+    target.addIllegalOp<GetTileID>();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
+  return std::make_unique<TileAllocationPass>();
+}

diff  --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
new file mode 100644
index 00000000000000..a481516d4c15f9
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -0,0 +1,377 @@
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: mixed_tiles
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
+func.func @mixed_tiles() {
+  // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // ZA3.D ZA3.Q, ZA11.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // ZA7.Q
+  // CHECK-NEXT: arith.constant 7
+  %za7_q = arm_sme.get_tile_id : i128
+  // ZA15.Q is still free.
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_b
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_b() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_b = arm_sme.get_tile_id : i8
+  return
+}
+
+// -----
+
+func.func @za_b__out_of_tiles() {
+  %za0_b = arm_sme.get_tile_id : i8
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i8
+  return
+}
+
+// -----
+
+func.func @za_b_overlapping_za_q() {
+  %za0_b = arm_sme.get_tile_id : i8
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_h
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
+func.func @za0_h() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_h
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_h() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: arith.constant 1
+  %za1_h = arm_sme.get_tile_id : i16
+  return
+}
+
+// -----
+
+func.func @za_h__out_of_tiles() {
+  %za0_h = arm_sme.get_tile_id : i16
+  %za1_h = arm_sme.get_tile_id : i16
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i16
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_h_overlapping_za_s
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_h_overlapping_za_s() {
+  // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_s = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_h_overlapping_za_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_h_overlapping_za_d() {
+  // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // ZA1.Q, ZA9.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_d = arm_sme.get_tile_id : i64
+  // ZA3.Q, ZA11.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // ZA5.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 5
+  %za5_d = arm_sme.get_tile_id : i64
+  // ZA7.Q, ZA15.Q
+  // CHECK-NEXT: arith.constant 7
+  %za7_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_h_overlapping_za_q() {
+  %za0_h = arm_sme.get_tile_id : i16
+  %za0_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za4_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za8_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za12_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_s
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
+func.func @za0_s() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_s = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_s
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_s() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: arith.constant 2
+  %za2_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: arith.constant 3
+  %za3_s = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+func.func @za_s__out_of_tiles() {
+  %za0_s = arm_sme.get_tile_id : i32
+  %za1_s = arm_sme.get_tile_id : i32
+  %za2_s = arm_sme.get_tile_id : i32
+  %za3_s = arm_sme.get_tile_id : i32
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_s_overlapping_za_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_s_overlapping_za_d() {
+  // ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_s = arm_sme.get_tile_id : i32
+  // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 2
+  %za2_s = arm_sme.get_tile_id : i32
+  // ZA3.Q, ZA11.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // ZA7.Q, ZA15.Q
+  // CHECK-NEXT: arith.constant 7
+  %za7_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_s_overlapping_za_q() {
+  %za0_s = arm_sme.get_tile_id : i32
+  %za1_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
+func.func @za0_d() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_d() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 1
+  %za1_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 2
+  %za2_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 4
+  %za4_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 5
+  %za5_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 6
+  %za6_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 7
+  %za7_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_d__out_of_tiles() {
+  %za0_d = arm_sme.get_tile_id : i64
+  %za1_d = arm_sme.get_tile_id : i64
+  %za2_d = arm_sme.get_tile_id : i64
+  %za3_d = arm_sme.get_tile_id : i64
+  %za4_d = arm_sme.get_tile_id : i64
+  %za5_d = arm_sme.get_tile_id : i64
+  %za6_d = arm_sme.get_tile_id : i64
+  %za7_d = arm_sme.get_tile_id : i64
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_d_overlapping_za_q() {
+  %za0_d = arm_sme.get_tile_id : i64
+  %za1_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za4_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za12_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_q
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
+func.func @za0_q() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_q = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_q
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_q() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 1
+  %za1_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 2
+  %za2_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 3
+  %za3_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 4
+  %za4_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 5
+  %za5_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 6
+  %za6_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 7
+  %za7_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 8
+  %za8_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 9
+  %za9_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 10
+  %za10_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 11
+  %za11_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 12
+  %za12_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 13
+  %za13_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 14
+  %za14_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 15
+  %za15_q = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+func.func @za_q__out_of_tiles() {
+  %za0_q = arm_sme.get_tile_id : i128
+  %za1_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za4_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za8_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za12_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}

diff  --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
index 31a49a422192f7..e2991d18a03a1c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
 // RUN:   --entry-function=entry \


        


More information about the Mlir-commits mailing list