[Mlir-commits] [mlir] [mlir][ArmSME] Add initial SME vector legalization pass (PR #79152)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Jan 26 03:56:32 PST 2024
================
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
+// RUN: -convert-vector-to-arm-sme -allocate-arm-sme-tiles -convert-arm-sme-to-scf \
+// RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \
+// RUN: -convert-arm-sme-to-llvm \
+// RUN: -convert-vector-to-llvm=enable-arm-sve \
+// RUN: -cse -canonicalize -test-lower-to-llvm | \
+// RUN: %mcr_aarch64_cmd \
+// RUN: -e=main -entry-point-result=void \
+// RUN: -march=aarch64 -mattr="+sve,+sme" \
+// RUN: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib,%mlir_arm_runner_utils | \
+// RUN: FileCheck %s
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @fill2DMemrefRows(%memref: memref<?x?xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %rows = memref.dim %memref, %c0 : memref<?x?xf32>
+ %cols = memref.dim %memref, %c1 : memref<?x?xf32>
+ scf.for %i = %c0 to %rows step %c1 {
+ scf.for %j = %c0 to %cols step %c1 {
+ %n = arith.addi %i, %c1 : index
+ %val = arith.index_cast %n : index to i32
+ %val_f32 = arith.sitofp %val : i32 to f32
+ memref.store %val_f32, %memref[%i, %j] : memref<?x?xf32>
+ }
+ }
+ return
+}
+
+func.func @testTransposedReadWithMask() {
+ %in = memref.alloca() : memref<4x16xf32>
+ %out = memref.alloca() : memref<16x4xf32>
+
+ %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
+ %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>
+
+ func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+ /// A mask so we only read the first 2x15 portion of %in.
+ %maskRows = arith.constant 2 : index
+ %maskCols = arith.constant 15 : index
+ %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+ %pad = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+
+ /// A vector.transfer_read with a transpose permutation map. So the input data
+ /// (and mask) have a [4]x[16] shape, but the output is [16]x[4].
+ %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
+ {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+ /// Write the vector back to a memref (that also has a transposed shape).
+ vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+ /// Print the input memref.
+ vector.print str "Input memref:"
+ %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+ /// Print the result memref.
+ vector.print str "(Masked 15x2) transposed result:"
+ %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+ return
+}
+
+func.func @testTransposedWriteWithMask() {
+ %in = memref.alloca() : memref<16x4xf32>
+ %out = memref.alloca() : memref<4x16xf32>
+
+ %fill = arith.constant -1.0 : f32
+ linalg.fill ins(%fill : f32) outs(%out : memref<4x16xf32>)
+
+ %inDyn = memref.cast %in : memref<16x4xf32> to memref<?x?xf32>
+ %outDyn = memref.cast %out : memref<4x16xf32> to memref<?x?xf32>
+
+ func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()
+
+ %pad = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+
+ /// A regular read.
+ %read = vector.transfer_read %inDyn[%c0, %c0], %pad {in_bounds = [true, true]}
+ : memref<?x?xf32>, vector<[16]x[4]xf32>
+
+ /// A mask so we only write the first 3x8 portion of transpose(%in).
+ %maskRows = arith.constant 3 : index
+ %maskCols = arith.constant 8 : index
+ %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
+
+ /// Write out the data with a transpose. Here (like the read test) the mask
+ /// matches the shape of the memory, not the vector.
+ vector.transfer_write %read, %outDyn[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]}
+ : vector<[16]x[4]xf32>, memref<?x?xf32>
+
+ /// Print the input memref.
+ vector.print str "Input memref:"
+ %inUnranked = memref.cast %inDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%inUnranked) : (memref<*xf32>) -> ()
+
+ /// Print the result memref.
+ vector.print str "(Masked 3x8) transposed result:"
+ %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
+ call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()
+
+ return
+}
+
+func.func @main() {
+ /// Set the SVL to 128-bits (i.e. vscale = 1).
+ /// This test is for the use of multiple tiles rather than scalability.
+ %c128 = arith.constant 128 : i32
+ func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+ // CHECK: Input memref:
+ // CHECK: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ // CHECK-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
+ // CHECK-NEXT: [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
+ // CHECK-NEXT: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
+ //
+ // CHECK: (Masked 15x2) transposed result:
----------------
banach-space wrote:
[nit] IMHO, this would be a bit easier to reason about:
```
// CHECK: (Masked 15x2) transposed result:
// ....
%maskRows = arith.constant 2 : index
%maskCols = arith.constant 15 : index
func.call @testTransposedReadWithMask(%maskRows, %maskCols) : (index, index) -> ()
```
Similar comment for other tests.
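
Roughly, something like this (just a sketch, not a definitive version: the generic print string, the exact signature, and the `(index, index) -> ()` call type are my assumptions, and it reuses the existing `#transpose`, `@fill2DMemrefRows` and `@printMemrefF32` helpers from this file; the input-memref print is omitted for brevity):
```
func.func @testTransposedReadWithMask(%maskRows: index, %maskCols: index) {
  %in = memref.alloca() : memref<4x16xf32>
  %out = memref.alloca() : memref<16x4xf32>

  %inDyn = memref.cast %in : memref<4x16xf32> to memref<?x?xf32>
  %outDyn = memref.cast %out : memref<16x4xf32> to memref<?x?xf32>

  func.call @fill2DMemrefRows(%inDyn) : (memref<?x?xf32>) -> ()

  /// The mask dimensions now come from the caller, so the constants live in
  /// @main next to the CHECK lines they correspond to.
  %mask = vector.create_mask %maskRows, %maskCols : vector<[4]x[16]xi1>
  %pad = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index

  /// Masked, transposed read followed by a plain write, as in the original test.
  %readTransposed = vector.transfer_read %inDyn[%c0, %c0], %pad, %mask
    {permutation_map = #transpose, in_bounds = [true, true]} : memref<?x?xf32>, vector<[16]x[4]xf32>
  vector.transfer_write %readTransposed, %outDyn[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[4]xf32>, memref<?x?xf32>

  /// Print the result memref.
  vector.print str "Masked transposed result:"
  %outUnranked = memref.cast %outDyn : memref<?x?xf32> to memref<*xf32>
  call @printMemrefF32(%outUnranked) : (memref<*xf32>) -> ()

  return
}
```
and in `@main`:
```
  // CHECK: Masked transposed result:
  // ...
  %maskRows = arith.constant 2 : index
  %maskCols = arith.constant 15 : index
  func.call @testTransposedReadWithMask(%maskRows, %maskCols) : (index, index) -> ()
```
That way each mask shape sits right next to the CHECK lines it produces.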
https://github.com/llvm/llvm-project/pull/79152