[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)
Md Asghar Ahmad Shahid
llvmlistbot at llvm.org
Tue Aug 19 02:07:18 PDT 2025
================
@@ -0,0 +1,278 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// These test cases validate replacement of vector transfer ops with equivalent
+/// AMX tile data transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)> // indexing map of the first contraction operand (%vecA)
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)> // indexing map of the second contraction operand (%vecB)
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)> // indexing map of the third contraction operand (%vecC) and the result
+func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>, // test input: all contraction operands are read from buffers and the result is written back to %C
+ %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
+ %c0_f16 = arith.constant 0.0 : f16 // padding operand for the f16 transfer reads of %A and %B
+ %c0_f32 = arith.constant 0.0 : f32 // padding operand for the f32 transfer read of %C
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16 // in-bounds read of a 4x8x2 f16 vector from %A
+ {in_bounds = [true, true, true]} : memref<64x32x16x2xf16>, vector<4x8x2xf16>
+ %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16 // in-bounds read of an 8x16x2 f16 vector from %B
+ {in_bounds = [true, true, true]} : memref<64x16x32x2xf16>, vector<8x16x2xf16>
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32 // in-bounds read of a 4x16 f32 vector from %C
+ {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract // additive contraction: m and n are parallel, k and vnni are reductions
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx] // writes the contraction result to %C at the same indices %vecC was read from
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @transfers_into_amx_tiles(
+// CHECK-SAME: %[[A:.+]]: memref<64x32x16x2xf16>,
+// CHECK-SAME: %[[B:.+]]: memref<64x16x32x2xf16>,
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>,
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+/// Load A into an AMX tile
+// CHECK: %[[A_SUBVIEW:.+]] = memref.subview %[[A]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[A]]
+
+/// Load B into an AMX tile
+// CHECK: %[[B_SUBVIEW:.+]] = memref.subview %[[B]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[B]]
+
+/// Load C into an AMX tile
+// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[C]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result back
+// CHECK: %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK: amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK-NOT: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// AMX tile can be loaded directly from the buffer. However, vector transfer
+/// has to remain due to other users that require data in registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)> // indexing map of the first contraction operand (%vecA)
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)> // indexing map of the second contraction operand (%vecB)
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)> // indexing map of the third contraction operand (%vecC) and the result
+func.func @transfer_read_multiple_users(%C: memref<64x64xf32>, // test input: only %C comes from a buffer; %vecA and %vecB are passed in as vectors
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f32 = arith.constant 0.0 : f32 // padding operand for the transfer read of %C
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32 // in-bounds read of a 4x16 f32 vector; %vecC has two users below
+ {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract // first user of %vecC: additive contraction with m, n parallel and k, vnni reductions
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ %mul = arith.mulf %vecC, %vecD : vector<4x16xf32> // second user of %vecC: elementwise multiply with the contraction result
+ return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @transfer_read_multiple_users(
----------------
shahidact wrote:
Could you please add a count check for `memref.alloca`? I think there are a couple of extra allocas here.
https://github.com/llvm/llvm-project/pull/154114
More information about the Mlir-commits
mailing list