[Mlir-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)
Andrzej Warzyński
llvmlistbot at llvm.org
Wed Jun 11 03:26:42 PDT 2025
banach-space wrote:
Thanks for all the updates, @momchil-velikov 🙏🏻
Just for context: Momchil and I discussed this offline - in its current form, the test crashes (after producing valid results). I’ve created a minimal repro here:
* https://github.com/llvm/llvm-project/issues/143670
Given that these tests (i.e., SVE e2e tests that require emulation) are only run by [clang-aarch64-full-2stage](https://github.com/llvm/llvm-zorg/blob/8741e96a3f5b633222d1cd3b3f2ac6223fdbdd3b/buildbot/osuosl/master/config/builders.py#L471-L495) - which we closely monitor - I'm inclined to land this with a temporary workaround (on top of other updates suggested below).
---
Looking at the current version of the tests, I do think there’s room to improve code reuse and to make the test structure more VLA-friendly (Vector-Length Agnostic). For instance, the function below mixes fixed and scalable shapes:
```mlir
func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> vector<4x[4]xi32> {
  // ...
}
```
It also returns a scalable vector from a function - something we haven’t independently tested and should probably avoid until validated.
---
Since this comment is getting long and I’m effectively suggesting a refactor (after many solid improvements already), I’ve gone ahead and rewritten the tests myself. Below is a single test file that combines:
* `contraction-smmla-4x8x4.mlir`,
* `contraction-ummla-4x8x4.mlir`,
* `contraction-summla-4x8x4.mlir`,
* `contraction-usmmla-4x8x4.mlir`.
Summary of changes:
* Replaced custom init logic (via `arith.constant`) with `llvm.intr.stepvector` to keep it simple but sufficient for a smoke test.
* Unified test data across SMMLA, UMMLA, SUMMLA, USMMLA - since all inputs are non-negative, the output is the same.
* Added comments, a workaround for #143670, and a few TODOs.
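Why can all four variants share one set of test data? For i8 values below 128, sign extension and zero extension produce the same result. The Python sketch below is a minimal illustration of that. The helper names `sext_i8`, `zext_i8` and `dot8` are made up here and do not come from any MLIR API. The sketch models both extensions on raw 8-bit patterns and shows where they would diverge, which is what the TODO in the test about using negative numbers is after.

```python
def sext_i8(b: int) -> int:
    """Sign-extend an 8-bit pattern (0..255), like arith.extsi i8 -> i32."""
    return b - 256 if b >= 128 else b


def zext_i8(b: int) -> int:
    """Zero-extend an 8-bit pattern (0..255), like arith.extui i8 -> i32."""
    return b


def dot8(lhs, rhs, ext_lhs, ext_rhs) -> int:
    """One 8-element dot product, i.e. one output element of a 2x8 * 8x2 i8mm tile."""
    return sum(ext_lhs(a) * ext_rhs(b) for a, b in zip(lhs, rhs))


# Row 0 of the LHS and row 1 of the RHS as produced by stepvector at VL = 128:
# every byte is < 128, so all four extension combinations agree.
lhs = list(range(8))
rhs = list(range(8, 16))
results = {
    "smmla": dot8(lhs, rhs, sext_i8, sext_i8),
    "ummla": dot8(lhs, rhs, zext_i8, zext_i8),
    "usmmla": dot8(lhs, rhs, zext_i8, sext_i8),
    "summla": dot8(lhs, rhs, sext_i8, zext_i8),
}
assert len(set(results.values())) == 1

# A byte >= 0x80 (0xFF is -1 as a signed i8) makes signed and unsigned diverge.
neg = [0xFF] + [0] * 7
one = [1] * 8
print(dot8(neg, one, sext_i8, sext_i8), dot8(neg, one, zext_i8, zext_i8))  # -1 255
```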
To me, this format accomplishes 3 key goals:
* Maximizes code reuse
* Verifies codegen paths for all i8mm instructions
* Smoke-tests run-time correctness under scalable vector lengths
If this makes sense to you, please feel free to reuse the file. Alternatively, I’d be happy to push it to your branch directly.
–Andrzej
---
**NEW TEST FILE**
```mlir
// REQUIRES: arm-emulator
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
// DEFINE: --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
// DEFINE: -o %t
// DEFINE: %{entry_point} = main
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s

#packed_maps = [
  affine_map<(d0, d1, d2) -> (d0, d2)>,
  affine_map<(d0, d1, d2) -> (d1, d2)>,
  affine_map<(d0, d1, d2) -> (d0, d1)>
]

//=============================================================================
// Helper methods to allocate+initialise test data
//=============================================================================

// Allocate and initialise a memref of 16 x vscale elements of type: i32. This
// matches the requirements for the accumulator for i8mm, it is precisely
// * 4 x [4] elements.
func.func private @getFlatMemRef_i32() -> memref<?xi32> {
  %c0 = arith.constant 0 : index
  %c16 = arith.constant 16 : index
  %vscale = vector.vscale
  %c16_vscale = arith.muli %vscale, %c16 : index
  %flat_mem = memref.alloc(%c16_vscale) : memref<?xi32>
  %vector_i32 = llvm.intr.stepvector : vector<[16]xi32>
  vector.transfer_write %vector_i32, %flat_mem[%c0] : vector<[16]xi32>, memref<?xi32>
  return %flat_mem : memref<?xi32>
}

// Allocate and initialise a memref of 32 x vscale elements of type: i8. This
// matches the requirements for the RHS for i8mm, it is precisely
// * [4] x 8 elements.
func.func private @getFlatMemRef_i8_scalable() -> memref<?xi8> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %vscale = vector.vscale
  %vscale_times_32 = arith.muli %vscale, %c32 : index
  %flat_mem = memref.alloc(%vscale_times_32) : memref<?xi8>
  %vector_i8 = llvm.intr.stepvector : vector<[32]xi8>
  vector.transfer_write %vector_i8, %flat_mem[%c0] : vector<[32]xi8>, memref<?xi8>
  return %flat_mem : memref<?xi8>
}

// Allocate and initialise a memref of 32 elements of type: i8. This
// matches the requirements for the LHS for i8mm, it is precisely:
// * 4 x 8 elements.
func.func private @getFlatMemRef_i8_fixed() -> memref<?xi8> {
  %c0 = arith.constant 0 : index
  %c32 = arith.constant 32 : index
  %flat_mem = memref.alloc(%c32) : memref<?xi8>
  %vector_i8 = llvm.intr.stepvector : vector<32xi8>
  vector.transfer_write %vector_i8, %flat_mem[%c0] : vector<32xi8>, memref<?xi8>
  return %flat_mem : memref<?xi8>
}

//=============================================================================
// Main entry point for test.
//=============================================================================
func.func @main() {
  // NOTE: Update this value to some other valid value of VL (i.e. supported by
  // SVE) to see the impact of "scalability". The expected output below assumes
  // VL = 128 bits.
  // FIXME: https://github.com/llvm/llvm-project/issues/143670
  %c128 = arith.constant 128 : i32
  func.call @setArmVLBits(%c128) : (i32) -> ()

  %c0_idx = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %c0_i8 = arith.constant 0 : i8

  //---------------------------------------------------------------------------
  // 1. GENERATE TEST DATA
  //---------------------------------------------------------------------------
  // 1.1. Accumulator test data
  %acc_flat = func.call @getFlatMemRef_i32() : () -> memref<?xi32>
  %flat_vec = vector.transfer_read %acc_flat[%c0_idx], %c0_i32 {in_bounds = [true]} : memref<?xi32>, vector<[16]xi32>
  %acc = vector.shape_cast %flat_vec : vector<[16]xi32> to vector<4x[4]xi32>

  // 1.2. LHS test data
  %lhs_flat = func.call @getFlatMemRef_i8_fixed() : () -> memref<?xi8>
  %lhs_flat_vec = vector.transfer_read %lhs_flat[%c0_idx], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<32xi8>
  %lhs = vector.shape_cast %lhs_flat_vec : vector<32xi8> to vector<4x8xi8>

  // 1.3. RHS test data
  %rhs_flat = func.call @getFlatMemRef_i8_scalable() : () -> memref<?xi8>
  %rhs_flat_vec = vector.transfer_read %rhs_flat[%c0_idx], %c0_i8 {in_bounds = [true]} : memref<?xi8>, vector<[32]xi8>
  %rhs = vector.shape_cast %rhs_flat_vec : vector<[32]xi8> to vector<[4]x8xi8>

  //---------------------------------------------------------------------------
  // 2. "EXTEND" THE LHS + RHS VECTORS
  // This is what the i8mm lowering expects: i8 inputs that are sign- or
  // zero-extended to i32.
  //---------------------------------------------------------------------------
  %lhs_si = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
  %lhs_ui = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
  %rhs_si = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
  %rhs_ui = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

  //---------------------------------------------------------------------------
  // 3. MATRIX MULTIPLICATION
  //---------------------------------------------------------------------------
  // 3.1. SMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.smmla
  %res_smmla = vector.contract {indexing_maps = #packed_maps,
                                iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>} %lhs_si, %rhs_si, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  // 3.2. UMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.ummla
  %res_ummla = vector.contract {indexing_maps = #packed_maps,
                                iterator_types = ["parallel", "parallel", "reduction"],
                                kind = #vector.kind<add>} %lhs_ui, %rhs_ui, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  // 3.3. USMMLA
  // CHECK-IR-COUNT-4: arm_sve.intr.usmmla
  %res_usmmla = vector.contract {indexing_maps = #packed_maps,
                                 iterator_types = ["parallel", "parallel", "reduction"],
                                 kind = #vector.kind<add>} %lhs_ui, %rhs_si, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  // 3.4. SUMMLA
  // NOTE: SVE has no SUMMLA instruction, so this case is also expected to be
  // lowered to USMMLA.
  // CHECK-IR-COUNT-4: arm_sve.intr.usmmla
  %res_summla = vector.contract {indexing_maps = #packed_maps,
                                 iterator_types = ["parallel", "parallel", "reduction"],
                                 kind = #vector.kind<add>} %lhs_si, %rhs_ui, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>

  //---------------------------------------------------------------------------
  // 4. DISPLAY THE RESULTS OF THE MULTIPLICATION
  // TODO: Implement this and use instead:
  //  * vector.print %res_smmla : vector<4x[4]xi32>
  //---------------------------------------------------------------------------
  vector.print str "RESULT (smmla):\n"
  %s0 = vector.extract %res_smmla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %s1 = vector.extract %res_smmla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %s2 = vector.extract %res_smmla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %s3 = vector.extract %res_smmla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %s0 : vector<[4]xi32>
  vector.print %s1 : vector<[4]xi32>
  vector.print %s2 : vector<[4]xi32>
  vector.print %s3 : vector<[4]xi32>

  vector.print str "RESULT (ummla):\n"
  %u0 = vector.extract %res_ummla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %u1 = vector.extract %res_ummla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %u2 = vector.extract %res_ummla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %u3 = vector.extract %res_ummla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %u0 : vector<[4]xi32>
  vector.print %u1 : vector<[4]xi32>
  vector.print %u2 : vector<[4]xi32>
  vector.print %u3 : vector<[4]xi32>

  vector.print str "RESULT (usmmla):\n"
  %us0 = vector.extract %res_usmmla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %us1 = vector.extract %res_usmmla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %us2 = vector.extract %res_usmmla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %us3 = vector.extract %res_usmmla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %us0 : vector<[4]xi32>
  vector.print %us1 : vector<[4]xi32>
  vector.print %us2 : vector<[4]xi32>
  vector.print %us3 : vector<[4]xi32>

  vector.print str "RESULT (summla):\n"
  %su0 = vector.extract %res_summla[0] : vector<[4]xi32> from vector<4x[4]xi32>
  %su1 = vector.extract %res_summla[1] : vector<[4]xi32> from vector<4x[4]xi32>
  %su2 = vector.extract %res_summla[2] : vector<[4]xi32> from vector<4x[4]xi32>
  %su3 = vector.extract %res_summla[3] : vector<[4]xi32> from vector<4x[4]xi32>
  vector.print %su0 : vector<[4]xi32>
  vector.print %su1 : vector<[4]xi32>
  vector.print %su2 : vector<[4]xi32>
  vector.print %su3 : vector<[4]xi32>

  // With all inputs non-negative, the results are identical for all types of
  // extensions (this holds while VL <= 512 bits, i.e. while every i8 value
  // produced by stepvector is < 128).
  // TODO: Use negative numbers to demonstrate the run-time difference between
  // e.g. UMMLA and SMMLA.

  // CHECK-LABEL: RESULT (smmla):
  // CHECK: ( 140, 365, 590, 815 )
  // CHECK: ( 368, 1105, 1842, 2579 )
  // CHECK: ( 596, 1845, 3094, 4343 )
  // CHECK: ( 824, 2585, 4346, 6107 )

  // CHECK-LABEL: RESULT (ummla):
  // CHECK: ( 140, 365, 590, 815 )
  // CHECK: ( 368, 1105, 1842, 2579 )
  // CHECK: ( 596, 1845, 3094, 4343 )
  // CHECK: ( 824, 2585, 4346, 6107 )

  // CHECK-LABEL: RESULT (usmmla):
  // CHECK: ( 140, 365, 590, 815 )
  // CHECK: ( 368, 1105, 1842, 2579 )
  // CHECK: ( 596, 1845, 3094, 4343 )
  // CHECK: ( 824, 2585, 4346, 6107 )

  // CHECK-LABEL: RESULT (summla):
  // CHECK: ( 140, 365, 590, 815 )
  // CHECK: ( 368, 1105, 1842, 2579 )
  // CHECK: ( 596, 1845, 3094, 4343 )
  // CHECK: ( 824, 2585, 4346, 6107 )

  //---------------------------------------------------------------------------
  // 5. WORKAROUND
  // This extra printing should not be required, but the test crashes without it.
  // FIXME: https://github.com/llvm/llvm-project/issues/143670
  //---------------------------------------------------------------------------
  %res_smmla_flat = vector.shape_cast %res_smmla : vector<4x[4]xi32> to vector<[16]xi32>
  vector.transfer_write %res_smmla_flat, %acc_flat[%c0_idx] : vector<[16]xi32>, memref<?xi32>
  %acc_cast = memref.cast %acc_flat : memref<?xi32> to memref<*xi32>
  call @printMemrefI32(%acc_cast) : (memref<*xi32>) -> ()

  //---------------------------------------------------------------------------
  // 6. BUFFER DEALLOCATION
  //---------------------------------------------------------------------------
  memref.dealloc %acc_flat : memref<?xi32>
  memref.dealloc %rhs_flat : memref<?xi8>
  memref.dealloc %lhs_flat : memref<?xi8>
  return
}

func.func private @printMemrefI32(%ptr : memref<*xi32>)
func.func private @setArmVLBits(%bits : i32)
```
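The expected values in the CHECK lines depend on VL, because both the accumulator layout and the number of RHS rows scale with vscale. The Python sketch below is a minimal reference model for re-deriving those values by hand after changing VL. It computes res = acc + lhs * rhs^T. It assumes that VL is a multiple of 128 bits and at most 512 bits, which keeps every i8 input below 128. The names `expected_rows` and `format_row` are hypothetical and not part of the test infrastructure.

```python
def expected_rows(vl_bits: int) -> list[list[int]]:
    """Reference model for the test above: res = acc + lhs * rhs^T.

    All operands are filled by llvm.intr.stepvector and reshaped row-major,
    so each element's value is simply its flat index.
    """
    # Assumption: VL is a multiple of 128 and at most 512 bits, so every i8
    # input stays below 128 (extsi and extui then agree).
    if vl_bits % 128 != 0 or not 128 <= vl_bits <= 512:
        raise ValueError("expected a multiple of 128 between 128 and 512")
    vscale = vl_bits // 128
    n = 4 * vscale  # runtime size of the scalable dimension [4]
    acc = [[n * i + j for j in range(n)] for i in range(4)]  # 4 x [4] (i32)
    lhs = [[8 * i + k for k in range(8)] for i in range(4)]  # 4 x 8 (i8)
    rhs = [[8 * j + k for k in range(8)] for j in range(n)]  # [4] x 8 (i8)
    return [
        [acc[i][j] + sum(lhs[i][k] * rhs[j][k] for k in range(8)) for j in range(n)]
        for i in range(4)
    ]


def format_row(row: list[int]) -> str:
    """Mimic how vector.print renders a 1-D vector, e.g. "( 1, 2, 3 )"."""
    return "( " + ", ".join(str(x) for x in row) + " )"


for row in expected_rows(128):
    print(format_row(row))
```

With `vl_bits=128`, the first row is `( 140, 365, 590, 815 )`. With `vl_bits=256`, it is `( 140, 365, 590, 815, 1040, 1265, 1490, 1715 )`.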
https://github.com/llvm/llvm-project/pull/140573