[Mlir-commits] [mlir] [mlir][ArmSME][test] Prepare tests for tile allocation changes (PR #91358)

Benjamin Maxwell llvmlistbot at llvm.org
Tue May 7 09:46:08 PDT 2024


https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/91358

This patch:

 1. Removes some duplicate test cases
 2. Removes unnecessary uses of `-convert-arm-sme-to-llvm`
 3. Ensures tile values have uses via `test.some_use()`

1 and 2 will make these tests easier to update. 3 will be needed as ArmSME operations will be pure.

>From 85223c6b6533a56183ba67e62c114bf769389703 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 May 2024 16:28:29 +0000
Subject: [PATCH] [mlir][ArmSME][test] Prepare tests for tile allocation
 changes

This patch:

 1. Removes some duplicate test cases
 2. Removes unnecessary uses of `-convert-arm-sme-to-llvm`
 3. Ensures tile values have uses via `test.some_use()`

1 and 2 will make these tests easier to update. 3 will be needed as
ArmSME operations will be pure.
---
 .../ArmSMEToLLVM/arm-sme-to-llvm.mlir         |  24 +-
 .../Conversion/ArmSMEToLLVM/unsupported.mlir  |   2 +-
 .../ArmSMEToSCF/arm-sme-to-scf.mlir           |   6 +
 ...cation.mlir => basic-tile-allocation.mlir} | 297 ++++++++++--------
 mlir/test/Dialect/ArmSME/enable-arm-za.mlir   |  20 +-
 .../Dialect/ArmSME/outer-product-fusion.mlir  |   7 +-
 mlir/test/Dialect/ArmSME/tile-zero-masks.mlir |  43 ++-
 7 files changed, 234 insertions(+), 165 deletions(-)
 rename mlir/test/Dialect/ArmSME/{tile-allocation.mlir => basic-tile-allocation.mlir} (52%)

diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 81087cc02099..f48046a8d799 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -25,6 +25,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -36,6 +37,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
@@ -47,6 +49,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -58,6 +61,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
@@ -69,6 +73,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+  "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
@@ -80,6 +85,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -91,6 +97,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
@@ -102,6 +109,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -113,6 +121,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
@@ -124,6 +133,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[16]x[16]xi8>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+  "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -135,6 +145,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xi16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
@@ -146,6 +157,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -157,6 +169,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xi64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
@@ -168,6 +181,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[1]x[1]xi128>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+  "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
@@ -179,6 +193,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -190,6 +205,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
@@ -201,6 +217,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -212,6 +229,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[2]x[2]xf64>
   %tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+  "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
   return
 }
 
@@ -441,7 +459,8 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : v
 func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[4]x[4]xi32>
-  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+  "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -452,7 +471,8 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
 func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
-  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  %tile_update =  arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+  "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
   return
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 59665c471921..15767ff1dec3 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -9,6 +9,6 @@ func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs :
   // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
   // expected-error@+1 {{unsupported type}}
   %0 = arm_sme.outerproduct %lhs, %rhs  acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
-  "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
 }
 
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 6c393bc38af9..a2f2beff78c4 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -20,6 +20,7 @@
 func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -30,6 +31,7 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
 func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
   %c0 = arith.constant 0 : index
   %tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -60,6 +62,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
   %pad = arith.constant 0 : i32
   %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -94,6 +97,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
   %c3 = arith.constant 3 : index
   %mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -104,6 +108,7 @@ func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32
   %pad = arith.constant 0 : i32
   // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
@@ -113,6 +118,7 @@ func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?x
   %c0 = arith.constant 0 : index
   // expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
   %tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+  "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
similarity index 52%
rename from mlir/test/Dialect/ArmSME/tile-allocation.mlir
rename to mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index 9c368dd4fa23..e144bac970a7 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
 
 // -----
 
+// Note: Tile IDs >= 16 are in-memory tile IDs (i.e. spills).
+
 // CHECK-LABEL: mixed_tiles
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
 func.func @mixed_tiles() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -18,76 +19,61 @@ func.func @mixed_tiles() {
   // CHECK-NEXT: tile_id = 7
   %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // ZA15.Q is still free.
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_b
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_b() {
   // CHECK-NEXT: tile_id = 0
   %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  return
-}
-
-// -----
-
-func.func @za_b__out_of_tiles() {
-  %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
+  "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%next_tile) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_b_overlapping_za_q
 func.func @za_b_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
-func.func @za0_h() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+  "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h() {
   // CHECK-NEXT: tile_id = 0
   %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
   // CHECK-NEXT: tile_id = 1
   %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za_h__out_of_tiles
-func.func @za_h__out_of_tiles() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  // CHECK-NEXT: tile_id = 1
-  %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%next_tile) : (vector<[8]x[8]xi16>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h_overlapping_za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_s() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -98,13 +84,15 @@ func.func @za_h_overlapping_za_s() {
   // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 3
   %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_h_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_h_overlapping_za_d() {
   // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
   // CHECK-NEXT: tile_id = 0
@@ -121,40 +109,55 @@ func.func @za_h_overlapping_za_d() {
   // ZA7.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 7
   %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_h_overlapping_za_q
 func.func @za_h_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
-  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // CHECK-NEXT: tile_id = 1
+  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
+  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
+  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
+  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
+  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
+  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
+  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
+  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
-func.func @za0_s() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s() {
   // CHECK-NEXT: tile_id = 0
   %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
@@ -164,25 +167,20 @@ func.func @za_s() {
   %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
   // CHECK-NEXT: tile_id = 3
   %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  return
-}
-
-// -----
-
-func.func @za_s__out_of_tiles() {
-  %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+  "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za2_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%next_tile) : (vector<[4]x[4]xi32>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_s_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_s_overlapping_za_d() {
   // ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
   // CHECK-NEXT: tile_id = 0
@@ -199,44 +197,67 @@ func.func @za_s_overlapping_za_d() {
   // ZA7.Q, ZA15.Q
   // CHECK-NEXT: tile_id = 7
   %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za2_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_s_overlapping_za_q
 func.func @za_s_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+  // CHECK-NEXT: tile_id = 1
   %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 2
   %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
   %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
   %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 6
   %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
   %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
   %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 10
   %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
   %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
   %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 14
   %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
   %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
-func.func @za0_d() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_d() {
   // CHECK-NEXT: tile_id = 0
   %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
@@ -254,62 +275,80 @@ func.func @za_d() {
   %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
   // CHECK-NEXT: tile_id = 7
   %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  return
-}
-
-// -----
-
-func.func @za_d__out_of_tiles() {
-  %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
+  "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za2_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za4_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za6_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%next_tile) : (vector<[2]x[2]xi64>) -> ()
   return
 }
 
 // -----
 
+// CHECK-LABEL: za_d_overlapping_za_q
 func.func @za_d_overlapping_za_q() {
+  // CHECK-NEXT: tile_id = 0
   %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+  // CHECK-NEXT: tile_id = 1
   %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 2
   %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 3
   %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 4
   %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 5
   %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 6
   %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 7
   %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 9
   %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 10
   %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 11
   %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 12
   %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 13
   %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 14
   %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  // CHECK-NEXT: tile_id = 15
   %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: za0_q
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
-func.func @za0_q() {
-  // CHECK-NEXT: tile_id = 0
-  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+  "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za4_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za12_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
 
 // -----
 
 // CHECK-LABEL: za_q
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
 func.func @za_q() {
   // CHECK-NEXT: tile_id = 0
   %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
@@ -343,29 +382,25 @@ func.func @za_q() {
   %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
   // CHECK-NEXT: tile_id = 15
   %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  return
-}
-
-// -----
-
-func.func @za_q__out_of_tiles() {
-  %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // Next tile is in-memory:
+  // CHECK-NEXT: tile_id = 16
   %next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+  "test.some_use"(%za0_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za4_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za8_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za12_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+  "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index a20203d7e557..d3325513a848 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,10 +1,9 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=IN-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=OUT-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=INOUT-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=PRESERVES-ZA
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za | FileCheck %s -check-prefix=IN-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za | FileCheck %s -check-prefix=OUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za | FileCheck %s -check-prefix=INOUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za | FileCheck %s -check-prefix=PRESERVES-ZA
 
 // CHECK-LABEL: @declaration
 func.func private @declaration()
@@ -22,11 +21,4 @@ func.func private @declaration()
 // DISABLE-ZA-LABEL: @arm_new_za
 // DISABLE-ZA-NOT: arm_new_za
 // DISABLE-ZA-SAME: attributes {arm_streaming}
-// NO-ARM-STREAMING-LABEL: @arm_new_za
-// NO-ARM-STREAMING-NOT: arm_new_za
-// NO-ARM-STREAMING-NOT: arm_streaming
-// NO-ARM-STREAMING-NOT: arm_in_za
-// NO-ARM-STREAMING-NOT: arm_out_za
-// NO-ARM-STREAMING-NOT: arm_inout_za
-// NO-ARM-STREAMING-NOT: arm_preserves_za
 func.func @arm_new_za() { return }
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 01f54a4cf186..4887d611643f 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file | FileCheck %s
 
 // CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32
 // CHECK-SAME:    %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
@@ -929,6 +929,7 @@ func.func @outerproduct_widening_4way__missing_acc(
   %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
   // Missing accumulator breaks use-def chain.
   %3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
+  "test.some_use"(%2) : (vector<[4]x[4]xi32>) -> ()
 
   return %3 : vector<[4]x[4]xi32>
 }
@@ -1014,7 +1015,7 @@ func.func @outerproduct_widening_2way__cant_erase(
 
   %acc = arith.constant dense<1.0> : vector<[4]x[4]xf32>
   %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
-  "fake.use"(%0) : (vector<[4]x[4]xf32>) -> ()
+  "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
   %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
 
   return %1 : vector<[4]x[4]xf32>
@@ -1048,7 +1049,7 @@ func.func @outerproduct_widening_4way__multi_use_cant_erase(
 
   %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
   %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
-  "fake.use"(%1) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
   %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
   %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
 
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 04412e4db1c5..cac2dcc24d10 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -9,6 +9,7 @@
 func.func @zero_za_b() {
   // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
   %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
+  "test.some_use"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
   return
 }
 
@@ -16,10 +17,12 @@ func.func @zero_za_b() {
 
 // CHECK-LABEL: zero_za_h
 func.func @zero_za_h() {
-  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
   %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
   %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
+  "test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
+  "test.some_use"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
   return
 }
 
@@ -27,14 +30,18 @@ func.func @zero_za_h() {
 
 // CHECK-LABEL: zero_za_s
 func.func @zero_za_s() {
-  // CHECK:      arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
   %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
-  // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
   %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
-  // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
   %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
-  // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
   %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+  "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+  "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
@@ -42,21 +49,29 @@ func.func @zero_za_s() {
 
 // CHECK-LABEL: zero_za_d
 func.func @zero_za_d() {
-  // CHECK:      "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
   %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
   %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
   %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
   %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
   %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
   %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
   %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
-  // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
+  // CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
   %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
+  "test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
+  "test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
   return
 }



More information about the Mlir-commits mailing list