[Mlir-commits] [mlir] [mlir][ArmSME][test] Prepare tests for tile allocation changes (PR #91358)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue May 7 09:46:38 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

<details>
<summary>Changes</summary>

This patch:

 1. Removes some duplicate test cases
 2. Removes unnecessary uses of `-convert-arm-sme-to-llvm`
 3. Ensures tile values have uses via `test.some_use()`

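Items 1 and 2 will make these tests easier to update. Item 3 will be needed because ArmSME operations will become pure, so any operation whose result is unused could otherwise be removed as dead code before the test checks it.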

---

Patch is 38.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91358.diff


7 Files Affected:

- (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+22-2) 
- (modified) mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir (+1-1) 
- (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (+6) 
- (renamed) mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir (+166-131) 
- (modified) mlir/test/Dialect/ArmSME/enable-arm-za.mlir (+6-14) 
- (modified) mlir/test/Dialect/ArmSME/outer-product-fusion.mlir (+4-3) 
- (modified) mlir/test/Dialect/ArmSME/tile-zero-masks.mlir (+29-14) 


``````````diff
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 81087cc02099..f48046a8d799 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -25,6 +25,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -36,6 +37,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
@@ -47,6 +49,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -58,6 +61,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
@@ -69,6 +73,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+  "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
@@ -80,6 +85,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -91,6 +97,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
@@ -102,6 +109,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -113,6 +121,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
@@ -124,6 +133,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -135,6 +145,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
@@ -146,6 +157,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -157,6 +169,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
@@ -168,6 +181,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+  "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
@@ -179,6 +193,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -190,6 +205,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
@@ -201,6 +217,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -212,6 +229,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
@@ -441,7 +459,8 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : v
 func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -452,7 +471,8 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
 func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  %tile_update =  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 59665c471921..15767ff1dec3 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -9,6 +9,6 @@ func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs :
   // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = arm_sme.outerproduct %lhs, %rhs  acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
-  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 6c393bc38af9..a2f2beff78c4 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -20,6 +20,7 @@
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -30,6 +31,7 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
 func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -60,6 +62,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
   %pad = arith.constant 0 : i32
   %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -94,6 +97,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
   %c3 = arith.constant 3 : index
   %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -104,6 +108,7 @@ func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32
   %pad = arith.constant 0 : i32
   // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -113,6 +118,7 @@ func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?x
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
similarity index 52%
rename from mlir/test/Dialect/ArmSME/tile-allocation.mlir
rename to mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index 9c368dd4fa23..e144bac970a7 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
 
 // -----
 
+// Note: Tile IDs >= 16 are in-memory tile IDs (i.e. spills).
+
 // CHECK-LABEL: mixed_tiles
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
 func.func @mixed_tiles() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -18,76 +19,61 @@ func.func @mixed_tiles() {
   // CHECK-NEXT: tile_id = 7
   %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // ZA15.Q is still free.
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_b
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_b() {
   // CHECK-NEXT: tile_id = 0
   %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  return
-}
-
-// -----
-
-func.func @za_b__out_of_tiles() {
-  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%next_tile) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_b_overlapping_za_q
 func.func @za_b_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
-func.func @za0_h() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h() {
   // CHECK-NEXT: tile_id = 0
   %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // CHECK-NEXT: tile_id = 1
   %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za_h__out_of_tiles
-func.func @za_h__out_of_tiles() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  // CHECK-NEXT: tile_id = 1
-  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%next_tile) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h_overlapping_za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_s() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -98,13 +84,15 @@ func.func @za_h_overlapping_za_s() {
   // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 3
   %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_d() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -121,40 +109,55 @@ func.func @za_h_overlapping_za_d() {
   // ZA7.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 7
   %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_h_overlapping_za_q
 func.func @za_h_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // CHECK-NEXT: tile_id = 1
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
-func.func @za0_s() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s() {
   // CHECK-NEXT: tile_id = 0
   %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
@@ -164,25 +167,20 @@ func.func @za_s() {
   %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK-NEXT: tile_id = 3
   %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  return
-}
-
-// -----
-
-func.func @za_s__out_of_tiles() {
-  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/91358


More information about the Mlir-commits mailing list