[Mlir-commits] [mlir] [mlir][ArmSVE] Add `-arm-sve-legalize-vector-storage` pass (PR #68794)

Benjamin Maxwell llvmlistbot at llvm.org
Wed Oct 25 02:49:26 PDT 2023


================
@@ -0,0 +1,169 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -arm-sve-legalize-vector-storage -split-input-file -verify-diagnostics | FileCheck %s
+
+/// This tests the basic functionality of the -arm-sve-legalize-vector-storage pass.
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv1i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[1]xi1>)
+func.func @store_and_reload_sve_predicate_nxv1i1(%mask: vector<[1]xi1>) -> vector<[1]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[1]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[1]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[1]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[1]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[1]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[1]xi1>
+  return %reload : vector<[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv2i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[2]xi1>)
+func.func @store_and_reload_sve_predicate_nxv2i1(%mask: vector<[2]xi1>) -> vector<[2]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[2]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[2]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[2]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[2]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[2]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[2]xi1>
+  return %reload : vector<[2]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv4i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[4]xi1>)
+func.func @store_and_reload_sve_predicate_nxv4i1(%mask: vector<[4]xi1>) -> vector<[4]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[4]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[4]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[4]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[4]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[4]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[4]xi1>
+  return %reload : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_and_reload_sve_predicate_nxv8i1(
+// CHECK-SAME:                                         %[[MASK:.*]]: vector<[8]xi1>)
+func.func @store_and_reload_sve_predicate_nxv8i1(%mask: vector<[8]xi1>) -> vector<[8]xi1> {
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<[8]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<[8]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<[8]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[ALLOCA]][] : memref<vector<[16]xi1>>
+  // CHECK-NEXT: %[[MASK:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+  %reload = memref.load %alloca[] : memref<vector<[8]xi1>>
+  // CHECK-NEXT: return %[[MASK]] : vector<[8]xi1>
+  return %reload : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @store_2d_mask_and_reload_slice(
+// CHECK-SAME:                                  %[[MASK:.*]]: vector<3x[8]xi1>)
+func.func @store_2d_mask_and_reload_slice(%mask: vector<3x[8]xi1>) -> vector<[8]xi1> {
+  // CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() {alignment = 2 : i64} : memref<vector<3x[16]xi1>>
+  %alloca = memref.alloca() : memref<vector<3x[8]xi1>>
+  // CHECK-NEXT: %[[SVBOOL:.*]] = arm_sve.convert_to_svbool %[[MASK]] : vector<3x[8]xi1>
+  // CHECK-NEXT: memref.store %[[SVBOOL]], %[[ALLOCA]][] : memref<vector<3x[16]xi1>>
+  memref.store %mask, %alloca[] : memref<vector<3x[8]xi1>>
+  // CHECK-NEXT: %[[UNPACK:.*]] = vector.type_cast %[[ALLOCA]] : memref<vector<3x[16]xi1>> to memref<3xvector<[16]xi1>>
+  %unpack = vector.type_cast %alloca : memref<vector<3x[8]xi1>> to memref<3xvector<[8]xi1>>
+  // CHECK-NEXT: %[[RELOAD:.*]] = memref.load %[[UNPACK]][%[[C0]]] : memref<3xvector<[16]xi1>>
+  // CHECK-NEXT: %[[SLICE:.*]] = arm_sve.convert_from_svbool %[[RELOAD]] : vector<[8]xi1>
+  %slice = memref.load %unpack[%c0] : memref<3xvector<[8]xi1>>
+  // CHECK-NEXT: return %[[SLICE]] : vector<[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @set_sve_alloca_alignment
+func.func @set_sve_alloca_alignment() {
+  /// This checks the alignment of alloca's of scalable vectors will be
+  /// something the backend can handle. Currently, the backend sets the
+  /// alignment of scalable vectors to their base size (i.e. their size at
+  /// vscale = 1). This works for hardware-sized types, which always get a
+  /// 16-byte alignment. The problem is larger types e.g. vector<[8]xf32> end up
+  /// with alignments larger than 16-bytes (e.g. 32-bytes here), which are
+  /// unsupported. This pass avoids this issue by explicitly setting the
+  /// alignment to 16-bytes for all scalable vectors.
+
+  // CHECK-COUNT-6: alignment = 16
+  %a1 = memref.alloca() : memref<vector<[32]xi8>>
----------------
MacDue wrote:

Those are covered by the other tests in this file. They can't be tested on their own here, as the pass tries to replace them with svbools and would fail because it can't update the "prevent.dce" op.

https://github.com/llvm/llvm-project/pull/68794


More information about the Mlir-commits mailing list