[Mlir-commits] [mlir] [MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM (PR #140573)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jun 16 03:30:01 PDT 2025


================
@@ -0,0 +1,375 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  \
+// DEFINE:   --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+func.func private @printMemrefI32(%ptr : memref<*xi32>)
+
+func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> memref<4x?xi32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0_i32 = arith.constant 0 : i32
+
+  %vs = vector.vscale
+  %d = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%d) : memref<4x?xi32>
+
+  scf.for %j = %c0 to %d step %c4 {
+    vector.transfer_write %in, %mem[%c0, %j] {in_bounds = [true, true]} : vector<4x4xi32>, memref<4x?xi32>
+  }
+
+  return %mem : memref<4x?xi32>
+}
+
+func.func private @prepareLHSTestData(%in: vector<4x8xi8>) -> memref<4x8xi8> {
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %mem = memref.alloc() : memref<4x8xi8>
+  vector.transfer_write %in, %mem[%c0, %c0] {in_bounds = [true, true]} : vector<4x8xi8>, memref<4x8xi8>
+
+  return %mem : memref<4x8xi8>
+}
+
+func.func private @prepareRHSTestData(%in: vector<4x8xi8>) -> memref<?xi8> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  %c0_i8 = arith.constant 0 : i8
+
+  %vs = vector.vscale
+  %d = arith.muli %c4, %vs : index
+  %mem = memref.alloc(%d) : memref<?x8xi8>
+
+  scf.for %i = %c0 to %d step %c4 {
+    vector.transfer_write %in, %mem[%i, %c0] {in_bounds = [true, true]} : vector<4x8xi8>, memref<?x8xi8>
+  }
+
+  %mem_out = memref.collapse_shape %mem [[0, 1]] : memref<?x8xi8> into memref<?xi8>
+  return %mem_out : memref<?xi8>
+}
+
+// CHECK-IR-LABEL: llvm.func @test_smmla
+// CHECK-IR-COUNT-4: arm_sve.intr.smmla
+func.func @test_smmla() {
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+
+  %acc_mem = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xi32>) -> memref<4x?xi32>
+  %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_i32 {in_bounds = [true, true]} : memref<4x?xi32>, vector<4x[4]xi32>
+
+  // Workaround for a crash, see https://github.com/llvm/llvm-project/issues/143670
+  %acc_cast = memref.cast %acc_mem : memref<4x?xi32> to memref<*xi32>
+  call @printMemrefI32(%acc_cast) : (memref<*xi32>) -> ()
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_mem = func.call @prepareLHSTestData(%lhs_cst) : (vector<4x8xi8>) -> memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<4x8xi8>, vector<4x8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> : vector<4x8xi8>
+
+  %rhs_mem = func.call @prepareRHSTestData(%rhs_cst) : (vector<4x8xi8>) -> memref<?xi8>
+  %rhs_flat = vector.transfer_read %rhs_mem[%c0], %c0_i8 {in_bounds = [true]} :  memref<?xi8>, vector<[32]xi8>
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
----------------
banach-space wrote:

> Strictly speaking, it's A * B + C operation, so yes, the accumulation makes it not strictly a matrix multiplication.

I wasn’t referring to the distinction between:
* Matrix multiplication: `C = A * B`
* GEMM: `C += A * B`

I meant specifically the "matrix multiplication" part - that is, how the operand access patterns relate to a typical matmul structure.

> there's not a single form of affine maps that determines if an operation is a matrix multiplication or not, there are at least 8 (depending on two options for input and output layouts: transposed or not with respect to the iteration space)

I see your point, but I think we're talking past each other slightly. I agree that there are multiple affine map permutations corresponding to matmuls with transposed inputs or outputs. But I do think it's fair to say that certain patterns - like the one below - clearly represent [matrix-multiplication access in the conventional sense](https://en.wikipedia.org/wiki/Matrix_multiplication) (i.e. what most people expect):

```mlir
#packed_maps = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
```

The variant used in this test:

```mlir
#packed_maps = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (n, k)>,
  affine_map<(m, n, k) -> (m, n)>
]
```

still clearly fits within that space - it simply reflects a different layout (e.g. B accessed with (n, k) instead of (k, n)), modelling `A * B^T`.
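
Spelled out as index expressions (my paraphrase, not text from the patch), the two sets of maps read:

```latex
% First set of maps (conventional matmul, C = A * B):
\[ C_{mn} = C_{mn} + \sum_{k} A_{mk}\,B_{kn} \]
% Second set of maps (this test, C = A * B^T):
\[ C_{mn} = C_{mn} + \sum_{k} A_{mk}\,B_{nk} \]
```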

This is exactly the access pattern that the i8mm instructions expect, which brings me to:
> I don't think the test program has any dependency on the i8mm extension.

Technically, no - there's no explicit architectural dependency. But the test is deliberately structured to reflect the data layout expected by i8mm, and that's the motivation behind how the affine maps are written. While this layout may be valid outside of i8mm, the intent is to exercise codegen paths that are optimized for i8mm, and that's what the test is meant to capture.
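
For reference, the contraction in this test should look roughly like the sketch below - this is pieced together from the `#packed_maps` and the operand types quoted above rather than copied from the patch, and the `arith.extsi` ops are my assumption about how the smmla lowering pattern is matched:

```mlir
// Sign-extend the i8 operands to i32 so the lowering can match
// extsi + vector.contract and emit arm_sve.intr.smmla (assumed pattern).
%lhs_i32 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
%rhs_i32 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

// A * B^T + C: LHS indexed (m, k), RHS indexed (n, k), ACC indexed (m, n).
%res = vector.contract {indexing_maps = #packed_maps,
                        iterator_types = ["parallel", "parallel", "reduction"],
                        kind = #vector.kind<add>}
    %lhs_i32, %rhs_i32, %acc
    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
```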

My ask is to make it clear that:
* This isn’t a “conventional” matrix multiplication - we’re doing `A * B^T` rather than `A * B`.
* The chosen data layout matches what the `i8mm` extension expects - that’s the motivation for using `A * B^T`.
* This is representative of the code we aim to generate from ops like `linalg.mmt4d` - noting that in the test would clarify its context and expected usage (rough sketch below).
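
On that last point, a rough sketch of the `linalg.mmt4d` connection I have in mind (tile sizes purely illustrative, not taken from the patch):

```mlir
// linalg.mmt4d: C[m1][n1][m0][n0] += sum_{k1,k0} A[m1][k1][m0][k0] * B[n1][k1][n0][k0]
// i.e. every inner (m0 x k0) * (n0 x k0) tile pair is contracted as A * B^T,
// which is exactly the access pattern of the vector.contract in this test.
%res = linalg.mmt4d ins(%A, %B : tensor<?x?x4x8xi8>, tensor<?x?x4x8xi8>)
                    outs(%C : tensor<?x?x4x4xi32>) -> tensor<?x?x4x4xi32>
```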

> > Could you clarify that this is a vector.contract that would be generated from vectorizing linalg.mmt4d?
>
> Ack.

Sorry for not making this clearer earlier - could you add a brief comment in the code to that effect?

https://github.com/llvm/llvm-project/pull/140573


More information about the Mlir-commits mailing list