[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)

Andrzej WarzyƄski llvmlistbot at llvm.org
Fri Jan 26 03:56:32 PST 2024


================
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
+// RUN:   -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN:   -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN:   -convert-arm-sme-to-llvm \
+// RUN:   -convert-vector-to-llvm=enable-arm-sve \
+// RUN:   -cse -canonicalize -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN:   -e=main -entry-point-result=void \
+// RUN:   -march=aarch64 -mattr="+sve,+sme" \
+// RUN:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %rows = memref.dim %memref, %c0 : memref<?x?xf32>
+  %cols = memref.dim %memref, %c1 : memref<?x?xf32>
+  scf.for %i = %c0 to %rows step %c1 {
+    scf.for %j = %c0 to %cols step %c1 {
+      %n = arith.addi %i, %c1 : index
+      %val = arith.index_cast %n : index to i32
+      %val_f32 = arith.sitofp %val : i32 to f32
+      memref.store %val_f32, %memref[%i, %j] : memref<?x?xf32>
+    }
+  }
+  return
+}
+
+func.func @testTransposedReadWithMask() {
+  %in = memref.alloca() : memref<4x16xf32>
+  %out = memref.alloca() : memref<16x4xf32>
+
+  %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  /// A mask so we only read the first 2x15 portion of %in.
+  %maskRows = arith.constant 2 : index
+  %maskCols = arith.constant 15 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A vector.transfer_read with a transpose permutation map: the input data
+  /// (and mask) have a [4]x[16] shape, but the result vector is [16]x[4].
+  %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
+    {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// Write the vector back to a memref (that also has a transposed shape).
+  vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 15x2) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @testTransposedWriteWithMask() {
+  %in = memref.alloca() : memref<16x4xf32>
+  %out = memref.alloca() : memref<4x16xf32>
+
+  %fill = arith.constant -1.0 : f32
+  linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
+
+  %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
+  %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
+
+  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+  %pad = arith.constant 0.0 : f32
+  %c0 = arith.constant 0 : index
+
+  /// A regular read.
+  %read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
+    : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+  /// A mask so we only write the first 3x8 portion of transpose(%in).
+  %maskRows = arith.constant 3 : index
+  %maskCols = arith.constant 8 : index
+  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+
+  /// Write out the data with a transpose. As in the read test, the mask
+  /// matches the shape of the memory, not the vector.
+  vector.transfer_write %read, %outDyn[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]}
+    : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+  /// Print the input memref.
+  vector.print str "Input memref:"
+  %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+  /// Print the result memref.
+  vector.print str "(Masked 3x8) transposed result:"
+  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+  return
+}
+
+func.func @main() {
+  /// Set the SVL to 128 bits (i.e. vscale = 1).
+  /// This test exercises the use of multiple tiles rather than scalability.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  //      CHECK: Input memref:
+  //      CHECK:  [1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1]
+  // CHECK-NEXT:  [2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2,   2]
+  // CHECK-NEXT:  [3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3,   3]
+  // CHECK-NEXT:  [4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4,   4]
+  //
+  //      CHECK: (Masked 15x2) transposed result:
----------------
banach-space wrote:

[nit] IMHO, this would be a bit easier to reason about:
```
  //      CHECK: (Masked 15x2) transposed result:
  // ....
  %maskRows = arith.constant 2 : index
  %maskCols = arith.constant 15 : index
  func.call @testTransposedReadWithMask(%maskRows, %maskCols) : (index, index) -> ()
```

Similar comment for other tests.
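For anyone skimming the test, a minimal NumPy sketch (illustrative only, not part of the patch; shapes fixed at vscale = 1, so `vector<[4]x[16]xf32>` becomes a 4x16 array) of what the masked transposed read computes:

```python
import numpy as np

# Input: 4x16, row i filled with i+1 (matches @fill2DMemrefRows).
inp = np.tile(np.arange(1, 5, dtype=np.float32)[:, None], (1, 16))

# The mask covers the first 2x15 region of the *memory* (pre-transpose),
# matching `vector.create_mask %maskRows, %maskCols` with rows=2, cols=15.
mask = np.zeros((4, 16), dtype=bool)
mask[:2, :15] = True

# Masked-off lanes take the pad value (0.0), then the permutation map
# transposes the result to 16x4.
pad = np.float32(0.0)
result = np.where(mask, inp, pad).T  # shape (16, 4)
```

So `result[:15, :2]` holds the transposed 2x15 slab of the input and everything else is the pad value, which is what the `(Masked 15x2) transposed result:` CHECK lines verify.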

https://github.com/llvm/llvm-project/pull/79152