[Mlir-commits] [mlir] [mlir][ArmSME] Support 2-way widening outer products (PR #78975)

Cullen Rhodes llvmlistbot at llvm.org
Tue Jan 30 01:30:33 PST 2024


================
@@ -0,0 +1,100 @@
+// DEFINE: %{entry} = test_outerproduct_f16f16f32
+// DEFINE: %{widening_opts} = -arm-sme-outer-product-widening
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{widening_opts} \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
+// DEFINE:   -test-lower-to-llvm -o %t
+// DEFINE: %{run} = %mcr_aarch64_cmd %t \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%arm_sme_abi_shlib
+
+// RUN: %{compile}
+
+// RUN: %{run} | FileCheck %s
+
+// Check result is the same when outerproducts are not combined into widening
+// variant.
+
+// REDEFINE: %{widening_opts} =
+// RUN: %{run} | FileCheck %s
+
+func.func @test_outerproduct_f16f16f32() {
+  %undef = llvm.mlir.undef : vector<[4]xf16>
+
+  %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16>
+  %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16>
+  %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16>
+  %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16>
+
+  %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+  %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+  %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+  %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16>
+
+  %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+  %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+  %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+  %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+  %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32>
+  %0 = vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32>
+  %1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32>
+
+  // CHECK:      (  79,  95, 111, 127
----------------
c-rhodes wrote:

I briefly tried this and got incorrect results: the output was all zeroes, and I'm not sure why. The results were correct when I explicitly passed `sme128=on` to QEMU.
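
For reference, the values the test expects can be reproduced outside of MLIR. The snippet below is only a sketch: it assumes a 128-bit streaming vector length (what `sme128=on` selects), so that `vector<[4]x[4]xf32>` is a 4x4 tile, and `expected_tile` is a hypothetical helper name that is not part of the patch. It computes `acc + a0 (x) b0 + a1 (x) b1` in plain Python.

```python
# Sketch only: recompute the tile the test checks, assuming a 128-bit
# streaming vector length so the accumulator is a 4x4 tile of f32.
# Input values are copied from the arith.constant ops in the test.
a0 = [0.0, 2.0, 4.0, 6.0]
b0 = [1.0, 3.0, 5.0, 7.0]
a1 = [8.0, 10.0, 12.0, 14.0]
b1 = [9.0, 11.0, 13.0, 15.0]
ACC = 7.0  # splat accumulator, dense<7.0>


def expected_tile():
    """Return acc + outer(a0, b0) + outer(a1, b1) as a list of rows."""
    return [
        [ACC + a0[i] * b0[j] + a1[i] * b1[j] for j in range(4)]
        for i in range(4)
    ]


if __name__ == "__main__":
    for row in expected_tile():
        print(row)
```

The first row comes out as 79, 95, 111, 127, which matches the CHECK line in the hunk above, so an all-zero tile is a wrong result and not a formatting difference.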

https://github.com/llvm/llvm-project/pull/78975

