[Mlir-commits] [mlir] [mlir][ArmSME] Support 4-way widening outer products (PR #79288)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Feb 6 06:54:33 PST 2024


================
@@ -0,0 +1,150 @@
+// DEFINE: %{entry} = main
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   -convert-vector-to-arm-sme -convert-arith-to-arm-sme \
+// DEFINE:   -arm-sme-outer-product-fusion \
+// DEFINE:   -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \
+// DEFINE:   -convert-arm-sme-to-scf -allocate-arm-sme-tiles \
+// DEFINE:   -convert-arm-sme-to-llvm -cse -canonicalize \
+// DEFINE:   -test-lower-to-llvm
+// DEFINE: %{run} = %mcr_aarch64_cmd \
+// DEFINE:   -march=aarch64 -mattr=+sve,+sme \
+// DEFINE:   -e %{entry} -entry-point-result=void \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils,%arm_sme_abi_shlib
+
+// RUN: %{compile} | %{run} | FileCheck %s
+
+// NOTE: QEMU (versions <= 8.2.0, the latest at the time of writing) gives
+// incorrect results for the SME SMOPA 4-way outer product instruction, see:
+//   https://gitlab.com/qemu-project/qemu/-/issues/2083
+// This test is expected to fail (the CHECK lines are correct) until a fixed
+// version of QEMU can be used.
+
+// FIXME: Remove the 'XFAIL' below once a fixed QEMU version is available
+// (and installed on CI buildbot).
+// XFAIL: *
+
+// NOTE: There is no non-widening variant for these types, and this test can't
+// be lowered without the widening pass. Therefore, unlike
+// 'test-outerproduct-f16f16f32.mlir', we can't check that the result is the
+// same with and without the widening pass.
+
+func.func @main() {
+  // Pin the streaming vector length to 128 bits. With 32-bit accumulator
+  // elements this gives vector<[4]xi32> exactly four lanes, so the
+  // fixed-length vector<4xi8> data the tests insert fills each operand and the
+  // printed tile is 4x4, matching the four expected output rows.
+  %c128 = arith.constant 128 : i32
+  func.call @setArmSVLBits(%c128) : (i32) -> ()
+
+  func.call @test_outerproduct_i8i8i32() : () -> ()
+
+  func.call @test_masked_outerproduct_i8i8i32() : () -> ()
+
+  return
+}
+
+func.func @test_outerproduct_i8i8i32() {
+  // Sums four i8 outer products into an i32 tile:
+  //   result[i][j] = sum_{k=0..3} lhs_k[i] * rhs_k[j]
+  // with lhs_k[i] = 4*i + k and rhs_k[j] = 16 + 4*j + k, e.g.
+  //   result[0][0] = 0*16 + 1*17 + 2*18 + 3*19 = 110.
+  // Every input lies in [0, 31], so sign-extending i8 -> i32 loses nothing.
+  // NOTE(review): this extsi + outerproduct chain is presumably the pattern
+  // -arm-sme-outer-product-fusion rewrites into a single 4-way widening outer
+  // product; confirm against the pass's own tests.
+  %undef_i8 = llvm.mlir.undef : vector<[4]xi8>
+
+  // Left-hand operands.
+  %lhs0_fixed = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+  %lhs0 = vector.scalable.insert %lhs0_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs0_i32 = arith.extsi %lhs0 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs1_fixed = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+  %lhs1 = vector.scalable.insert %lhs1_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs1_i32 = arith.extsi %lhs1 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs2_fixed = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+  %lhs2 = vector.scalable.insert %lhs2_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs2_i32 = arith.extsi %lhs2 : vector<[4]xi8> to vector<[4]xi32>
+  %lhs3_fixed = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+  %lhs3 = vector.scalable.insert %lhs3_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %lhs3_i32 = arith.extsi %lhs3 : vector<[4]xi8> to vector<[4]xi32>
+
+  // Right-hand operands.
+  %rhs0_fixed = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+  %rhs0 = vector.scalable.insert %rhs0_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs0_i32 = arith.extsi %rhs0 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs1_fixed = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+  %rhs1 = vector.scalable.insert %rhs1_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs1_i32 = arith.extsi %rhs1 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs2_fixed = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+  %rhs2 = vector.scalable.insert %rhs2_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs2_i32 = arith.extsi %rhs2 : vector<[4]xi8> to vector<[4]xi32>
+  %rhs3_fixed = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+  %rhs3 = vector.scalable.insert %rhs3_fixed, %undef_i8[0] : vector<4xi8> into vector<[4]xi8>
+  %rhs3_i32 = arith.extsi %rhs3 : vector<[4]xi8> to vector<[4]xi32>
+
+  // Accumulate the four outer products; the first one has no accumulator.
+  %acc0 = vector.outerproduct %lhs0_i32, %rhs0_i32 : vector<[4]xi32>, vector<[4]xi32>
+  %acc1 = vector.outerproduct %lhs1_i32, %rhs1_i32, %acc0 : vector<[4]xi32>, vector<[4]xi32>
+  %acc2 = vector.outerproduct %lhs2_i32, %rhs2_i32, %acc1 : vector<[4]xi32>, vector<[4]xi32>
+  %acc3 = vector.outerproduct %lhs3_i32, %rhs3_i32, %acc2 : vector<[4]xi32>, vector<[4]xi32>
+
+  // CHECK:      ( 110,  134,  158,  182 )
+  // CHECK-NEXT: ( 390,  478,  566,  654 )
+  // CHECK-NEXT: ( 670,  822,  974, 1126 )
+  // CHECK-NEXT: ( 950, 1166, 1382, 1598 )
+  vector.print %acc3 : vector<[4]x[4]xi32>
+
+  return
+}
+
+func.func @test_masked_outerproduct_i8i8i32() {
+  %undef = llvm.mlir.undef : vector<[4]xi8>
+
+  %a0_data = arith.constant dense<[0, 4, 8, 12]> : vector<4xi8>
+  %a1_data = arith.constant dense<[1, 5, 9, 13]> : vector<4xi8>
+  %a2_data = arith.constant dense<[2, 6, 10, 14]> : vector<4xi8>
+  %a3_data = arith.constant dense<[3, 7, 11, 15]> : vector<4xi8>
+
+  %b0_data = arith.constant dense<[16, 20, 24, 28]> : vector<4xi8>
+  %b1_data = arith.constant dense<[17, 21, 25, 29]> : vector<4xi8>
+  %b2_data = arith.constant dense<[18, 22, 26, 30]> : vector<4xi8>
+  %b3_data = arith.constant dense<[19, 23, 27, 31]> : vector<4xi8>
+
+  %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a2 = vector.scalable.insert %a2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b2 = vector.scalable.insert %b2_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %a3 = vector.scalable.insert %a3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+  %b3 = vector.scalable.insert %b3_data, %undef[0] : vector<4xi8> into vector<[4]xi8>
+
+  %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+  %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+  %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+  %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+  %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+  %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+  %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+  %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+
----------------
banach-space wrote:

[nit] Add a comment noting that, with these masks, the bottom row of the output matrix is preserved.

https://github.com/llvm/llvm-project/pull/79288


More information about the Mlir-commits mailing list