[Mlir-commits] [mlir] [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage` pass (PR #68794)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Oct 25 00:59:39 PDT 2023
================
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s
+
+/// This tests the basic functionality of the -arm-sve-legalize-vector-storage pass.
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[1]xi1>)
+func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[1]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
+ return %reload : vector<[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[2]xi1>)
+func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[2]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
+ return %reload : vector<[2]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>)
+func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[4]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
+ return %reload : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
+// CHECK-SAME: %[[MASK:.*]]: vector<[8]xi1>)
+func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<[8]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+ // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+ %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
+ // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
+ return %reload : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_2d_mask_and_reload_slice(
+// CHECK-SAME: %[[MASK:.*]]: vector<3x[8]xi1>)
+func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
+ // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+ %c0 = arith.constant 0 : index
+ // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
+ %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
+ // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
+ // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
+ memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
+ // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
+ %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
+ // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
+ // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+ %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
+ // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @set_sve_alloca_alignment
+func.func @set_sve_alloca_alignment() {
+ /// This checks the alignment of alloca's of scalable vectors will be
+ /// something the backend can handle. Currently, the backend sets the
+ /// alignment of scalable vectors to their base size (i.e. their size at
+ /// vscale = 1). This works for hardware-sized types, which always get a
+ /// 16-byte alignment. The problem is larger types e.g. vector<[8]xf32> end up
+ /// with alignments larger than 16-bytes (e.g. 32-bytes here), which are
+ /// unsupported. This pass avoids this issue by explicitly setting the
+ /// alignment to 16-bytes for all scalable vectors.
+
+ // CHECK-COUNT-6: alignment = 16
+ %a1 = memref.alloca() : memref<vector<[32]xi8>>
----------------
banach-space wrote:
I guess we should test for `i1` as well?
https://github.com/llvm/llvm-project/pull/68794
More information about the Mlir-commits
mailing list