[Mlir-commits] [mlir] [mlir][ArmSME] Lower vector.outerproduct to FMOPA/BFMOPA (PR #65621)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Sep 7 10:00:16 PDT 2023


================
@@ -0,0 +1,93 @@
+// DEFINE: %{entry_point} = test_outerproduct_4x4xf32
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="mode=locally enable-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arm-sme-to-scf \
+// DEFINE:   -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \
+// DEFINE:   -allocate-arm-sme-tiles -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry_point} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// REDEFINE: %{entry_point} = test_outerproduct_no_accumulator_4x4xf32
+// RUN: %{compile} | %{run} | FileCheck %s --check-prefix=CHECK-NO-ACC
+
+func.func @test_outerproduct_4x4xf32() {
+  %c0 = arith.constant 0 : index
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %f10 = arith.constant 10.0 : f32
+
+  %a = vector.splat %f1 : vector<[4]xf32>
+  %b = vector.splat %f2 : vector<[4]xf32>
+  // TODO: vector.splat doesn't support ArmSME.
+  %c = vector.broadcast %f10 : f32 to vector<[4]x[4]xf32>
+
+  %tile = vector.outerproduct %a, %b, %c : vector<[4]xf32>, vector<[4]xf32>
+
+  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
+  %vscale = vector.vscale
+  %min_elts_s = arith.constant 4 : index
+  %svl_s = arith.muli %min_elts_s, %vscale : index
+  %za_s_size = arith.muli %svl_s, %svl_s : index
+
+  // Allocate memory.
+  %mem = memref.alloca(%za_s_size) : memref<?xf32>
+
+  // Store the tile to memory.
+  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
+
+  // Reload and print. The smallest SVL is 128-bits so the tile will be at
+  // least 4x4xf32.
+  //
+  // CHECK:      ( 12, 12, 12, 12
+  // CHECK-NEXT: ( 12, 12, 12, 12
+  // CHECK-NEXT: ( 12, 12, 12, 12
+  // CHECK-NEXT: ( 12, 12, 12, 12
----------------
MacDue wrote:

I think this would show the outer product more clearly if it used something other than a constant splat.

Maybe something like:
```
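%step = llvm.intr.experimental.stepvector : vector<[4]xi32>
// The step vector is i32, so convert it to f32 before the outer product.
%vector = arith.sitofp %step : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>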
```
->
```
( 0, 0, 0, 0
( 0, 1, 2, 3
( 0, 2, 4, 6
( 0, 3, 6, 9
```
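
As a quick sanity check of the values above, here is a minimal host-side sketch in Python (not part of the PR). It assumes vscale = 1, so each vector holds exactly four elements, and the helper name `outer_product` is hypothetical. The optional `acc` argument mirrors the accumulator operand of `vector.outerproduct`.

```python
def outer_product(lhs, rhs, acc=None):
    """Return the outer product of lhs and rhs, optionally added to acc."""
    result = [[a * b for b in rhs] for a in lhs]
    if acc is not None:
        # Element-wise add of the accumulator matrix.
        result = [
            [r + c for r, c in zip(row, acc_row)]
            for row, acc_row in zip(result, acc)
        ]
    return result


# Contents of a 4-element step vector, assuming vscale = 1.
step = [0, 1, 2, 3]

# Print one row per line, in the same shape as the expected output.
for row in outer_product(step, step):
    print("( " + ", ".join(str(v) for v in row))
```

With a step vector every row and column differs, so a transposed or mis-indexed tile would be visible in the output. With the constant splats in the current test (1 * 2 + 10), every element is 12 regardless of layout.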


https://github.com/llvm/llvm-project/pull/65621

