[Mlir-commits] [mlir] [mlir][ArmSME][test] Prepare tests for tile allocation changes (PR #91358)
Benjamin Maxwell
llvmlistbot at llvm.org
Tue May 7 09:46:08 PDT 2024
https://github.com/MacDue created https://github.com/llvm/llvm-project/pull/91358
This patch:
1. Removes some duplicate test cases
2. Removes unnecessary uses of `-convert-arm-sme-to-llvm`
3. Ensures tile values have uses via `test.some_use()`
Changes 1 and 2 make these tests easier to update. Change 3 is needed because ArmSME operations will be pure.
From 85223c6b6533a56183ba67e62c114bf769389703 Mon Sep 17 00:00:00 2001
From: Benjamin Maxwell <benjamin.maxwell at arm.com>
Date: Tue, 7 May 2024 16:28:29 +0000
Subject: [PATCH] [mlir][ArmSME][test] Prepare tests for tile allocation
changes
This patch:
1. Removes some duplicate test cases
2. Removes unnecessary uses of `-convert-arm-sme-to-llvm`
3. Ensures tile values have uses via `test.some_use()`
Changes 1 and 2 make these tests easier to update. Change 3 is needed
because ArmSME operations will be pure.
---
.../ArmSMEToLLVM/arm-sme-to-llvm.mlir | 24 +-
.../Conversion/ArmSMEToLLVM/unsupported.mlir | 2 +-
.../ArmSMEToSCF/arm-sme-to-scf.mlir | 6 +
...cation.mlir => basic-tile-allocation.mlir} | 297 ++++++++++--------
mlir/test/Dialect/ArmSME/enable-arm-za.mlir | 20 +-
.../Dialect/ArmSME/outer-product-fusion.mlir | 7 +-
mlir/test/Dialect/ArmSME/tile-zero-masks.mlir | 43 ++-
7 files changed, 234 insertions(+), 165 deletions(-)
rename mlir/test/Dialect/ArmSME/{tile-allocation.mlir => basic-tile-allocation.mlir} (52%)
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
index 81087cc02099..f48046a8d799 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir
@@ -25,6 +25,7 @@ func.func @arm_sme_load_tile_slice_hor_i8(%src : memref<?x?xi8>, %mask : vector<
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -36,6 +37,7 @@ func.func @arm_sme_load_tile_slice_hor_i16(%src : memref<?x?xi16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
return
}
@@ -47,6 +49,7 @@ func.func @arm_sme_load_tile_slice_hor_i32(%src : memref<?x?xi32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -58,6 +61,7 @@ func.func @arm_sme_load_tile_slice_hor_i64(%src : memref<?x?xi64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
return
}
@@ -69,6 +73,7 @@ func.func @arm_sme_load_tile_slice_hor_i128(%src : memref<?x?xi128>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+ "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
return
}
@@ -80,6 +85,7 @@ func.func @arm_sme_load_tile_slice_hor_f16(%src : memref<?x?xf16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -91,6 +97,7 @@ func.func @arm_sme_load_tile_slice_hor_bf16(%src : memref<?x?xbf16>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
return
}
@@ -102,6 +109,7 @@ func.func @arm_sme_load_tile_slice_hor_f32(%src : memref<?x?xf32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -113,6 +121,7 @@ func.func @arm_sme_load_tile_slice_hor_f64(%src : memref<?x?xf64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
return
}
@@ -124,6 +133,7 @@ func.func @arm_sme_load_tile_slice_ver_i8(%src : memref<?x?xi8>, %mask : vector<
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[16]x[16]xi8>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi8>, vector<[16]xi1>, vector<[16]x[16]xi8>
+ "test.some_use" (%tile_update) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -135,6 +145,7 @@ func.func @arm_sme_load_tile_slice_ver_i16(%src : memref<?x?xi16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xi16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi16>, vector<[8]xi1>, vector<[8]x[8]xi16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xi16>) -> ()
return
}
@@ -146,6 +157,7 @@ func.func @arm_sme_load_tile_slice_ver_i32(%src : memref<?x?xi32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi32>, vector<[4]xi1>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -157,6 +169,7 @@ func.func @arm_sme_load_tile_slice_ver_i64(%src : memref<?x?xi64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xi64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi64>, vector<[2]xi1>, vector<[2]x[2]xi64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xi64>) -> ()
return
}
@@ -168,6 +181,7 @@ func.func @arm_sme_load_tile_slice_ver_i128(%src : memref<?x?xi128>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[1]x[1]xi128>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xi128>, vector<[1]xi1>, vector<[1]x[1]xi128>
+ "test.some_use" (%tile_update) : (vector<[1]x[1]xi128>) -> ()
return
}
@@ -179,6 +193,7 @@ func.func @arm_sme_load_tile_slice_ver_f16(%src : memref<?x?xf16>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf16>, vector<[8]xi1>, vector<[8]x[8]xf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -190,6 +205,7 @@ func.func @arm_sme_load_tile_slice_ver_bf16(%src : memref<?x?xbf16>, %mask : vec
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xbf16>, vector<[8]xi1>, vector<[8]x[8]xbf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
return
}
@@ -201,6 +217,7 @@ func.func @arm_sme_load_tile_slice_ver_f32(%src : memref<?x?xf32>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xf32>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -212,6 +229,7 @@ func.func @arm_sme_load_tile_slice_ver_f64(%src : memref<?x?xf64>, %mask : vecto
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[2]x[2]xf64>
%tile_update = arm_sme.load_tile_slice %src[%c0], %mask, %tile, %tile_slice_index layout<vertical> : memref<?x?xf64>, vector<[2]xi1>, vector<[2]x[2]xf64>
+ "test.some_use" (%tile_update) : (vector<[2]x[2]xf64>) -> ()
return
}
@@ -441,7 +459,8 @@ func.func @arm_sme_store_tile_slice_ver_f64(%tile_slice_index : index, %mask : v
func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[4]x[4]xi32>
- arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index : vector<[4]xi32> into vector<[4]x[4]xi32>
+ "test.some_use" (%tile_update) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -452,7 +471,8 @@ func.func @arm_sme_move_vector_to_tile_slice_hor_i32(%vector : vector<[4]xi32>,
func.func @arm_sme_move_vector_to_tile_slice_ver_bf16(%vector : vector<[8]xbf16>, %tile_slice_index : index) -> () {
%c0 = arith.constant 0 : index
%tile = arm_sme.get_tile : vector<[8]x[8]xbf16>
- arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ %tile_update = arm_sme.move_vector_to_tile_slice %vector, %tile, %tile_slice_index layout<vertical> : vector<[8]xbf16> into vector<[8]x[8]xbf16>
+ "test.some_use" (%tile_update) : (vector<[8]x[8]xbf16>) -> ()
return
}
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 59665c471921..15767ff1dec3 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -9,6 +9,6 @@ func.func @arm_sme_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs :
// expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}}
// expected-error@+1 {{unsupported type}}
%0 = arm_sme.outerproduct %lhs, %rhs acc(%acc) : vector<[16]xi8>, vector<[16]xi8>
- "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> ()
+ "test.some_use"(%0) : (vector<[16]x[16]xi8>) -> ()
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index 6c393bc38af9..a2f2beff78c4 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -20,6 +20,7 @@
func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -30,6 +31,7 @@ func.func @arm_sme_tile_load_hor(%src : memref<?x?xi32>) {
func.func @arm_sme_tile_load_ver(%src : memref<?x?xi32>) {
%c0 = arith.constant 0 : index
%tile = arm_sme.tile_load %src[%c0, %c0] layout<vertical> : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -60,6 +62,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_pad_zero(%src : memref<?x?xi32>)
%pad = arith.constant 0 : i32
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -94,6 +97,7 @@ func.func @arm_sme_tile_load_hor_with_mask_and_nonzero_pad(%src : memref<?x?xi32
%c3 = arith.constant 3 : index
%mask = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1>
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -104,6 +108,7 @@ func.func @arm_sme_tile_load_zero_pad__unsupported_mask_op(%src : memref<?x?xi32
%pad = arith.constant 0 : i32
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
@@ -113,6 +118,7 @@ func.func @arm_sme_tile_load_nonzero_pad__unsupported_mask_op(%src : memref<?x?x
%c0 = arith.constant 0 : index
// expected-error@+1 {{failed to legalize operation 'arm_sme.tile_load' that was explicitly marked illegal}}
%tile = arm_sme.tile_load %src[%c0, %c0], %pad, %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
+ "test.some_use" (%tile) : (vector<[4]x[4]xi32>) -> ()
return
}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
similarity index 52%
rename from mlir/test/Dialect/ArmSME/tile-allocation.mlir
rename to mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index 9c368dd4fa23..e144bac970a7 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,9 +1,10 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
// -----
+// Note: Tile IDs >= 16 are in-memory tile IDs (i.e. spills).
+
// CHECK-LABEL: mixed_tiles
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
func.func @mixed_tiles() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
// CHECK-NEXT: tile_id = 0
@@ -18,76 +19,61 @@ func.func @mixed_tiles() {
// CHECK-NEXT: tile_id = 7
%za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
// ZA15.Q is still free.
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_b
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_b() {
// CHECK-NEXT: tile_id = 0
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- return
-}
-
-// -----
-
-func.func @za_b__out_of_tiles() {
- %za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[16]x[16]xi8>
+ "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+ "test.some_use"(%next_tile) : (vector<[16]x[16]xi8>) -> ()
return
}
// -----
+// CHECK-LABEL: za_b_overlapping_za_q
func.func @za_b_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_b = arm_sme.get_tile : vector<[16]x[16]xi8>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
- return
-}
-
-// -----
-
-// CHECK-LABEL: za0_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
-func.func @za0_h() {
- // CHECK-NEXT: tile_id = 0
- %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
+ "test.some_use"(%za0_b) : (vector<[16]x[16]xi8>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_h
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h() {
// CHECK-NEXT: tile_id = 0
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
// CHECK-NEXT: tile_id = 1
%za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- return
-}
-
-// -----
-
-// CHECK-LABEL: za_h__out_of_tiles
-func.func @za_h__out_of_tiles() {
- // CHECK-NEXT: tile_id = 0
- %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- // CHECK-NEXT: tile_id = 1
- %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[8]x[8]xi16>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%next_tile) : (vector<[8]x[8]xi16>) -> ()
return
}
// -----
// CHECK-LABEL: za_h_overlapping_za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h_overlapping_za_s() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
// CHECK-NEXT: tile_id = 0
@@ -98,13 +84,15 @@ func.func @za_h_overlapping_za_s() {
// ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
// CHECK-NEXT: tile_id = 3
%za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
// CHECK-LABEL: za_h_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_h_overlapping_za_d() {
// ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
// CHECK-NEXT: tile_id = 0
@@ -121,40 +109,55 @@ func.func @za_h_overlapping_za_d() {
// ZA7.Q, ZA15.Q
// CHECK-NEXT: tile_id = 7
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
+// CHECK-LABEL: za_h_overlapping_za_q
func.func @za_h_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
- %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // CHECK-NEXT: tile_id = 1
+ %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
+ %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
+ %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
+ %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
+ %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
+ %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
+ %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
+ %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
- return
-}
-
-// -----
-
-// CHECK-LABEL: za0_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
-func.func @za0_s() {
- // CHECK-NEXT: tile_id = 0
- %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "test.some_use"(%za0_h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_s
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_s() {
// CHECK-NEXT: tile_id = 0
%za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
@@ -164,25 +167,20 @@ func.func @za_s() {
%za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK-NEXT: tile_id = 3
%za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- return
-}
-
-// -----
-
-func.func @za_s__out_of_tiles() {
- %za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- %za1_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- %za2_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- %za3_s = arm_sme.get_tile : vector<[4]x[4]xi32>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[4]x[4]xi32>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za2_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%next_tile) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
// CHECK-LABEL: za_s_overlapping_za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_s_overlapping_za_d() {
// ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
// CHECK-NEXT: tile_id = 0
@@ -199,44 +197,67 @@ func.func @za_s_overlapping_za_d() {
// ZA7.Q, ZA15.Q
// CHECK-NEXT: tile_id = 7
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za1_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za2_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
+// CHECK-LABEL: za_s_overlapping_za_q
func.func @za_s_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_s = arm_sme.get_tile : vector<[4]x[4]xi32>
+ // CHECK-NEXT: tile_id = 1
%za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 2
%za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
%za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
%za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 6
%za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
%za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
%za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 10
%za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
%za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 14
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
- return
-}
-
-// -----
-
-// CHECK-LABEL: za0_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
-func.func @za0_d() {
- // CHECK-NEXT: tile_id = 0
- %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_d
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_d() {
// CHECK-NEXT: tile_id = 0
%za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
@@ -254,62 +275,80 @@ func.func @za_d() {
%za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
// CHECK-NEXT: tile_id = 7
%za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- return
-}
-
-// -----
-
-func.func @za_d__out_of_tiles() {
- %za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za1_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za2_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za3_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za4_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za5_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za6_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- %za7_d = arm_sme.get_tile : vector<[2]x[2]xi64>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[2]x[2]xi64>
+ "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za1_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za2_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za3_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za4_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za5_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za6_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za7_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%next_tile) : (vector<[2]x[2]xi64>) -> ()
return
}
// -----
+// CHECK-LABEL: za_d_overlapping_za_q
func.func @za_d_overlapping_za_q() {
+ // CHECK-NEXT: tile_id = 0
%za0_d = arm_sme.get_tile : vector<[2]x[2]xi64>
+ // CHECK-NEXT: tile_id = 1
%za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 2
%za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 3
%za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 4
%za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 5
%za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 6
%za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 7
%za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 9
%za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 10
%za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 11
%za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 12
%za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 13
%za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 14
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ // CHECK-NEXT: tile_id = 15
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
- return
-}
-
-// -----
-
-// CHECK-LABEL: za0_q
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
-func.func @za0_q() {
- // CHECK-NEXT: tile_id = 0
- %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za4_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za12_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
// -----
// CHECK-LABEL: za_q
-// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
func.func @za_q() {
// CHECK-NEXT: tile_id = 0
%za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
@@ -343,29 +382,25 @@ func.func @za_q() {
%za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
// CHECK-NEXT: tile_id = 15
%za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- return
-}
-
-// -----
-
-func.func @za_q__out_of_tiles() {
- %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za1_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za2_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za3_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za4_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za5_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za6_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za7_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za8_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za9_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za10_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za11_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za12_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za13_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za14_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- %za15_q = arm_sme.get_tile : vector<[1]x[1]xi128>
- // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+ // Next tile is in-memory:
+ // CHECK-NEXT: tile_id = 16
%next_tile = arm_sme.get_tile : vector<[1]x[1]xi128>
+ "test.some_use"(%za0_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za1_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za2_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za3_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za4_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za5_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za6_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za7_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za8_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za9_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za10_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za11_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za12_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za13_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za14_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%za15_q) : (vector<[1]x[1]xi128>) -> ()
+ "test.some_use"(%next_tile) : (vector<[1]x[1]xi128>) -> ()
return
}
diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
index a20203d7e557..d3325513a848 100644
--- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
+++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir
@@ -1,10 +1,9 @@
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=IN-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=OUT-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=INOUT-ZA
-// RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=PRESERVES-ZA
-// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za | FileCheck %s -check-prefix=ENABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming | FileCheck %s -check-prefix=DISABLE-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=in-za | FileCheck %s -check-prefix=IN-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=out-za | FileCheck %s -check-prefix=OUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=inout-za | FileCheck %s -check-prefix=INOUT-ZA
+// RUN: mlir-opt %s -enable-arm-streaming=za-mode=preserves-za | FileCheck %s -check-prefix=PRESERVES-ZA
// CHECK-LABEL: @declaration
func.func private @declaration()
@@ -22,11 +21,4 @@ func.func private @declaration()
// DISABLE-ZA-LABEL: @arm_new_za
// DISABLE-ZA-NOT: arm_new_za
// DISABLE-ZA-SAME: attributes {arm_streaming}
-// NO-ARM-STREAMING-LABEL: @arm_new_za
-// NO-ARM-STREAMING-NOT: arm_new_za
-// NO-ARM-STREAMING-NOT: arm_streaming
-// NO-ARM-STREAMING-NOT: arm_in_za
-// NO-ARM-STREAMING-NOT: arm_out_za
-// NO-ARM-STREAMING-NOT: arm_inout_za
-// NO-ARM-STREAMING-NOT: arm_preserves_za
func.func @arm_new_za() { return }
diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
index 01f54a4cf186..4887d611643f 100644
--- a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
+++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file | FileCheck %s
// CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32
// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>,
@@ -929,6 +929,7 @@ func.func @outerproduct_widening_4way__missing_acc(
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
// Missing accumulator breaks use-def chain.
%3 = arm_sme.outerproduct %a3_ext, %b3_ext : vector<[4]xi32>, vector<[4]xi32>
+ "test.some_use"(%2) : (vector<[4]x[4]xi32>) -> ()
return %3 : vector<[4]x[4]xi32>
}
@@ -1014,7 +1015,7 @@ func.func @outerproduct_widening_2way__cant_erase(
%acc = arith.constant dense<1.0> : vector<[4]x[4]xf32>
%0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32>
- "fake.use"(%0) : (vector<[4]x[4]xf32>) -> ()
+ "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> ()
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
return %1 : vector<[4]x[4]xf32>
@@ -1048,7 +1049,7 @@ func.func @outerproduct_widening_4way__multi_use_cant_erase(
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
- "fake.use"(%1) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> ()
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
%3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index 04412e4db1c5..cac2dcc24d10 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -9,6 +9,7 @@
func.func @zero_za_b() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
%zero_za0b = arm_sme.zero : vector<[16]x[16]xi8>
+ "test.some_use"(%zero_za0b) : (vector<[16]x[16]xi8>) -> ()
return
}
@@ -16,10 +17,12 @@ func.func @zero_za_b() {
// CHECK-LABEL: zero_za_h
func.func @zero_za_h() {
- // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
+ "test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
+ "test.some_use"(%zero_za1h) : (vector<[8]x[8]xf16>) -> ()
return
}
@@ -27,14 +30,18 @@ func.func @zero_za_h() {
// CHECK-LABEL: zero_za_s
func.func @zero_za_s() {
- // CHECK: arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
- // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
- // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
- // CHECK-NEXT: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
+ // CHECK: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
return
}
@@ -42,21 +49,29 @@ func.func @zero_za_s() {
// CHECK-LABEL: zero_za_d
func.func @zero_za_d() {
- // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
%zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
%zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
%zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
%zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
%zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
%zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
- // CHECK-NEXT: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
+ "test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za1d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za2d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za3d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za4d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za5d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za6d) : (vector<[2]x[2]xi64>) -> ()
+ "test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
More information about the Mlir-commits mailing list