[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)

Md Asghar Ahmad Shahid llvmlistbot at llvm.org
Tue Aug 19 02:07:18 PDT 2025


================
@@ -0,0 +1,278 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// These test cases validate replacement of vector transfer ops with equivalent
+/// AMX tile data transfers.
+
+/// Contraction over iterators (m, n, k, vnni): A is indexed as (m, k, vnni),
+/// B as (k, n, vnni) and the accumulator C as (m, n). Both `k` and `vnni` are
+/// reduction dimensions, i.e. the operands carry a packed trailing `vnni` dim.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+/// Every operand is read from a memref and the result is written back to a
+/// memref, so each transfer below is expected to become a direct AMX tile
+/// load/store.
+func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
+    %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
+  %c0_f16 = arith.constant 0.0 : f16
+  %c0_f32 = arith.constant 0.0 : f32
+  // A operand: a 4x8x2 (m, k, vnni) slice of the innermost three dims of %A.
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<64x32x16x2xf16>, vector<4x8x2xf16>
+  // B operand: an 8x16x2 (k, n, vnni) slice of the innermost three dims of %B.
+  %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<64x16x32x2xf16>, vector<8x16x2xf16>
+  // Accumulator: a 4x16 (m, n) slice of %C.
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+  // Expected to be lowered to amx.tile_mulf on the three loaded tiles.
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  // The result goes back to the same location of %C the accumulator came from.
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+  return
+}
+
+// CHECK-LABEL: @transfers_into_amx_tiles(
+// CHECK-SAME:    %[[A:.+]]: memref<64x32x16x2xf16>,
+// CHECK-SAME:    %[[B:.+]]: memref<64x16x32x2xf16>,
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>,
+// CHECK-SAME:    %[[IDX:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Load A into an AMX tile
+// CHECK:       %[[A_SUBVIEW:.+]] = memref.subview %[[A]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
+// CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
+// CHECK:       %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[A]]
+
+/// Load B into an AMX tile
+// CHECK:       %[[B_SUBVIEW:.+]] = memref.subview %[[B]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
+// CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
+// CHECK:       %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[B]]
+
+/// Load C into an AMX tile
+// CHECK:       %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[C]]
+
+/// Perform tile multiplication
+// CHECK:       %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME:    %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result back
+// CHECK:       %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK-NOT:   vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// An AMX tile can be loaded directly from the buffer. However, the vector
+/// transfer has to remain because other users require the data in registers.
+
+/// Same contraction layout as above: A is (m, k, vnni), B is (k, n, vnni) and
+/// the accumulator is (m, n), with `k` and `vnni` reduced.
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+/// Only the accumulator is read from memory here; A and B arrive as vector
+/// function arguments.
+func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f32 = arith.constant 0.0 : f32
+  // %vecC has two users: the contraction and the arith.mulf below.
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  // Non-contraction user of %vecC; it is why the vector read cannot be erased.
+  %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+  return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @transfer_read_multiple_users(
----------------
shahidact wrote:

Could you please add a count check for `memref.alloca`? I think there are a couple of extra allocas here.

https://github.com/llvm/llvm-project/pull/154114


More information about the Mlir-commits mailing list