[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Aug 19 03:23:33 PDT 2025


================
@@ -0,0 +1,278 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// These test cases validate replacement of vector transfer ops with equivalent
+/// AMX tile data transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
+    %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
+  %c0_f16 = arith.constant 0.0 : f16
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<64x32x16x2xf16>, vector<4x8x2xf16>
+  %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<64x16x32x2xf16>, vector<8x16x2xf16>
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+  return
+}
+
+// CHECK-LABEL: @transfers_into_amx_tiles(
+// CHECK-SAME:    %[[A:.+]]: memref<64x32x16x2xf16>,
+// CHECK-SAME:    %[[B:.+]]: memref<64x16x32x2xf16>,
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>,
+// CHECK-SAME:    %[[IDX:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Load A into an AMX tile
+// CHECK:       %[[A_SUBVIEW:.+]] = memref.subview %[[A]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
+// CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
+// CHECK:       %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[A]]
+
+/// Load B into an AMX tile
+// CHECK:       %[[B_SUBVIEW:.+]] = memref.subview %[[B]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
+// CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
+// CHECK:       %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[B]]
+
+/// Load C into an AMX tile
+// CHECK:       %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[C]]
+
+/// Perform tile multiplication
+// CHECK:       %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME:    %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result back
+// CHECK:       %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK-NOT:   vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// AMX tile can be loaded directly from the buffer. However, vector transfer
+/// has to remain due to other users that require data in registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+  return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @transfer_read_multiple_users(
----------------
adam-smnk wrote:

Added extra comments to clarify this case.

I don't want to add too many checks to avoid extra noise.
This test case focuses only on the rewrite of this one particular transfer that has multiple uses.

The extra allocations are inserted to cover loads for the remaining operands. This default allocation behavior is covered in more detail in `contract-to-amx.mlir`.

https://github.com/llvm/llvm-project/pull/154114


More information about the Mlir-commits mailing list