[Mlir-commits] [mlir] [MLIR][AArch64] Lower `vector.contract` to SVE FEAT_BF16 operations (PR #147052)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jul 7 01:43:04 PDT 2025
================
@@ -0,0 +1,201 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-bf16' \
+// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \
+// DEFINE: --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+bf16" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (n, k)>,
+ affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the `LowerContractionToSVEBFMMLAPattern`
+//
+// The `vector.contract` in this test performs a matrix multiplication with
+// accumulation,
+//   OUT = ACC + LHS * RHS
+// of two BFloat16 matrices, LHS and RHS, and a Float32 matrix ACC into a Float32 OUT.
+//
+// The test checks both the computed results and that the relevant `ArmSVE`
+// dialect operation (`arm_sve.intr.bfmmla`) is emitted.
+//
+// The pattern above handles (and therefore this test prepares) input/output
+// vectors with specific shapes:
+// * LHS: vector<Mx4xbf16>
+// * RHS: vector<[N]x4xbf16>
+// * ACC, OUT: vector<Mx[N]xf32>
+// Note that the RHS is transposed, i.e.
+//   OUT[m, n] = ACC[m, n] + sum_k LHS[m, k] * RHS[n, k].
+// This data layout makes it efficient to load data into SVE
+// registers in the layout expected by the BFMMLA instruction.
+// Such a `vector.contract` is representative of the code we aim to generate
+// by scalable vectorisation of `linalg.mmt4d`.
+// See mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp
+// for more information and rationale about these shapes.
+//
+// In this specific test, M == 4 and N == 4.
+//
+
+// Allocate and initialise a memref containing test data for use as the ACC
+// operand. The memref has one dynamic dimension whose extent depends on the
+// runtime value of VSCALE.
+//
+// The input parameter `%in` is a vector that is replicated VSCALE times
+// across the columns of the memref.
+func.func private @prepareAccTestData(%in: vector<4x4xf32>) -> memref<4x?xf32> {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+
+ %vs = vector.vscale
+ %d = arith.muli %c4, %vs : index
----------------
banach-space wrote:
[nit] What does `d` stand for in `%d`? Why not `%ub` (upper bound) or `%c4_vscale` (4 x vscale)? A more descriptive name would be helpful; otherwise, it's hard to see what happens in the loop below (`scf.for %j = %c0 to %d step %c4`).
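To make the loop bound concrete, here is a plain-Python reference model of the test data and the computation. It is a sketch under stated assumptions, not code from the PR: the names `prepare_acc_test_data` and `contract_packed` are hypothetical, the body of the replication loop is inferred from the comment on `@prepareAccTestData` (the rest of that MLIR function is not shown here), `vscale = 2` is an assumed runtime value, and bf16 rounding is ignored, which is exact as long as the inputs are small integers.

```python
# Hypothetical plain-Python reference model (not part of the PR).
# Shapes follow the test: M == 4, K == 4, N == 4 * vscale.


def prepare_acc_test_data(block, vscale):
    """Replicate a 4x4 block `vscale` times across the columns of a
    4 x (4 * vscale) matrix, as the comment on @prepareAccTestData describes."""
    c4 = 4
    c4_vscale = c4 * vscale  # upper bound of the column loop (`%d` in the patch)
    acc = [[0.0] * c4_vscale for _ in range(4)]
    # Mirrors `scf.for %j = %c0 to %d step %c4`: one 4-column slice per step.
    for j in range(0, c4_vscale, c4):
        for row in range(4):
            for col in range(c4):
                acc[row][j + col] = block[row][col]
    return acc


def contract_packed(lhs, rhs, acc):
    """OUT = ACC + LHS * RHS^T using the #packed_maps indexing:
    lhs[m][k], rhs[n][k] (RHS transposed), acc[m][n].
    Uses Python floats, i.e. bf16 input rounding is not modelled."""
    out = [row[:] for row in acc]
    for m in range(len(lhs)):
        for n in range(len(rhs)):
            for k in range(len(lhs[0])):
                out[m][n] += lhs[m][k] * rhs[n][k]
    return out


if __name__ == "__main__":
    vscale = 2  # assumed; the real value depends on the (emulated) hardware
    block = [[float(4 * r + c) for c in range(4)] for r in range(4)]
    acc = prepare_acc_test_data(block, vscale)
    lhs = [[1.0] * 4 for _ in range(4)]
    rhs = [[1.0] * 4 for _ in range(4 * vscale)]
    out = contract_packed(lhs, rhs, acc)
    print(len(acc[0]))           # 8
    print(out[0][0], out[0][4])  # 4.0 4.0
```

Naming the bound `c4_vscale` makes it visible that the loop visits columns 0, 4, ..., 4 * vscale - 4, i.e. exactly `vscale` slices.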
https://github.com/llvm/llvm-project/pull/147052