[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)

Adam Siemieniuk llvmlistbot at llvm.org
Tue Aug 19 04:07:20 PDT 2025


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/154114

From 7a24f6862397d5ef86535ebf7da9291ff5155874 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Wed, 6 Aug 2025 16:32:12 +0200
Subject: [PATCH 1/4] [mlir][amx] Direct AMX data transfers

Extends the Vector to AMX conversion to attempt to populate AMX tiles
directly from memory.

When possible, contraction producers and consumers are replaced by
AMX tile data transfer operations. This shortens the data path by
skipping intermediate register loads and stores.
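
Rough before/after sketch (types are abbreviated and value names are only
illustrative; the exact IR is covered by the new transfer-to-amx.mlir tests).
An accumulator read, contraction, and write-back such as

  %acc = vector.transfer_read %C[%idx, %idx], %c0_f32
    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
  %res = vector.contract ... %vecA, %vecB, %acc
    : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
  vector.transfer_write %res, %C[%idx, %idx]
    {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>

can now be rewritten to move data directly through AMX tiles:

  %sub = memref.subview %C[%idx, %idx] [4, 16] [1, 1] : memref<64x64xf32> to ...
  %accTile = amx.tile_load %sub[%c0, %c0] : ...
  %resTile = amx.tile_mulf %lhsTile, %rhsTile, %accTile : ...
  amx.tile_store %sub[%c0, %c0], %resTile : ...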
---
 .../Conversion/VectorToAMX/VectorToAMX.cpp    | 170 ++++++++++-
 .../VectorToAMX/transfer-to-amx.mlir          | 278 ++++++++++++++++++
 2 files changed, 434 insertions(+), 14 deletions(-)
 create mode 100644 mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir

diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index a11e9b2624300..23194d0a4359b 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -10,7 +10,6 @@
 
 #include "mlir/Dialect/AMX/AMXDialect.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -21,6 +20,8 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "llvm/Support/DebugLog.h"
+
 #include <numeric>
 
 namespace mlir {
@@ -30,6 +31,8 @@ namespace mlir {
 
 using namespace mlir;
 
+#define DEBUG_TYPE "vector-to-amx"
+
 namespace {
 
 /// Return true if vector shape is compatible with AMX tiles.
@@ -49,8 +52,10 @@ static bool verifyAmxShape(VectorType vec) {
   // 3D shape indicates VNNI packed layout.
   if (vec.getRank() == 3) {
     int64_t vnniFactor = 32 / elemBitWidth;
-    if (shape.back() != vnniFactor)
+    if (shape.back() != vnniFactor) {
+      LDBG() << "invalid VNNI packing factor";
       return false;
+    }
     cols *= vnniFactor;
   }
 
@@ -60,7 +65,7 @@ static bool verifyAmxShape(VectorType vec) {
   return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
 }
 
-/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+/// Check if contraction operands are in AMX-compatible packed VNNI layout.
 static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
                                      vector::ContractionOp contractOp) {
   VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
@@ -172,9 +177,9 @@ static LogicalResult validateOperands(PatternRewriter &rewriter,
   return success();
 }
 
-/// Collapses the two innermost dimensions together.
-static Value collapseLastDim(PatternRewriter &rewriter,
-                             TypedValue<MemRefType> memref) {
+/// Collapse the two innermost dimensions together.
+static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
+                                              TypedValue<MemRefType> memref) {
   int64_t rank = memref.getType().getRank();
   SmallVector<ReassociationIndices> reassocIndices;
   for (auto i : llvm::seq<int64_t>(0, rank - 2))
@@ -184,21 +189,144 @@ static Value collapseLastDim(PatternRewriter &rewriter,
                                          reassocIndices);
 }
 
-/// Loads vector values to an AMX tile.
+/// Attempt to create an AMX tile load/store operation equivalent to the given
+/// vector transfer `xferOp`.
+/// This allows skipping the longer route through registers and a temporary
+/// buffer that is otherwise required to move data to/from an AMX tile.
+static Operation *
+loadStoreFromTransfer(PatternRewriter &rewriter,
+                      VectorTransferOpInterface xferOp, bool isPacked,
+                      TypedValue<amx::TileType> tileToStore = nullptr) {
+  if (!xferOp)
+    return nullptr;
+  if (xferOp.hasOutOfBoundsDim() ||
+      !xferOp.getPermutationMap().isMinorIdentity())
+    return nullptr;
+
+  // Extra checks in case of a write op.
+  // Stores must not be packed.
+  if (isa<vector::TransferWriteOp>(xferOp) &&
+      (!tileToStore || isPacked ||
+       tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
+    return nullptr;
+
+  // Check for a memref source buffer.
+  // AMX data transfer requires at least 2D shape to correctly
+  // infer stride between rows.
+  Value base = xferOp.getBase();
+  auto memTy = dyn_cast<MemRefType>(base.getType());
+  if (!memTy || memTy.getRank() < 2)
+    return nullptr;
+  int64_t memRank = memTy.getRank();
+
+  // Check that the source buffer has enough contiguous elements to load whole
+  // AMX tile row.
+  //
+  // To ensure correctness, the validation is conservative and expects the
+  // buffer's innermost dimensions to be statically known, equal to or larger
+  // than the vector row length, and equal to the VNNI dimension if applicable.
+  //
+  // This check could be relaxed to accept more arbitrarily shaped buffers as
+  // long as there are enough contiguous elements to load a whole row.
+  if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
+    return nullptr;
+  VectorType vecTy = xferOp.getVectorType();
+  ArrayRef<int64_t> vecShape = vecTy.getShape();
+  ArrayRef<int64_t> memShape = memTy.getShape();
+  if (memShape.back() < vecShape.back())
+    return nullptr;
+  if (isPacked &&
+      (memShape.back() != vecShape.back() ||
+       memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
+    return nullptr;
+
+  // Load values directly from the buffer to an AMX tile.
+  PatternRewriter::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(xferOp);
+  Location loc = xferOp.getLoc();
+
+  // Create a subview of the source buffer based on the transfer op to resolve
+  // offsets.
+  SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
+  int64_t vecRank = vecTy.getRank();
+  assert(memRank >= vecRank &&
+         "expected buffer rank to be at least the vector rank");
+  SmallVector<int64_t> shape(memRank - vecRank, 1);
+  shape.append(vecShape.begin(), vecShape.end());
+  TypedValue<MemRefType> src =
+      memref::SubViewOp::create(
+          rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
+          getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
+          .getResult();
+
+  // Collapse the VNNI dimension in case of packing.
+  if (isPacked)
+    src = collapseLastDim(rewriter, src);
+  int64_t rows = vecShape[0];
+  int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
+                                 std::multiplies<int64_t>());
+  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIndices(src.getType().getRank(), zeroIndex);
+
+  Operation *amxTileOp = nullptr;
+  if (isa<vector::TransferReadOp>(xferOp)) {
+    amxTileOp =
+        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndices);
+  } else if (isa<vector::TransferWriteOp>(xferOp)) {
+    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndices,
+                                         tileToStore);
+  }
+
+  return amxTileOp;
+}
+
+/// Attempt to create an AMX tile load operation equivalent to the given
+/// vector transfer `readOp`.
+/// Returns the loaded AMX tile if successful.
+static FailureOr<TypedValue<amx::TileType>>
+loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
+                 bool isPacked) {
+  amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
+      loadStoreFromTransfer(rewriter, readOp, isPacked));
+  if (!loadOp)
+    return failure();
+  return loadOp.getRes();
+}
+
+/// Attempt to create an AMX tile store operation equivalent to the given
+/// vector transfer `writeOp`.
+static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
+                                       vector::TransferWriteOp writeOp,
+                                       TypedValue<amx::TileType> tileToStore) {
+  return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
+                                       tileToStore));
+}
+
+/// Load vector values to an AMX tile.
 static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
                                           TypedValue<VectorType> vec) {
   Location loc = vec.getLoc();
-  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
 
-  // Transfer the vector to a tile through an intermediate buffer.
   VectorType vecTy = vec.getType();
+  bool isPacked = vecTy.getRank() == 3;
+
+  // Try to load tile directly from vector producer's buffer.
+  auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
+  FailureOr<TypedValue<amx::TileType>> tile =
+      loadFromTransfer(rewriter, readOp, isPacked);
+  if (succeeded(tile))
+    return *tile;
+
+  // Transfer the vector to a tile through an intermediate buffer.
   Value buf = memref::AllocaOp::create(
       rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
   vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
 
   // Collapse the VNNI dimension in case of packing.
-  bool isPacked = vecTy.getRank() == 3;
   if (isPacked)
     buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
 
@@ -212,17 +340,17 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
                                  {zeroIndex, zeroIndex});
 }
 
-/// Stores an AMX tile in a vector.
+/// Store an AMX tile in a vector.
 static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
                                         TypedValue<amx::TileType> tile) {
   Location loc = tile.getLoc();
-  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
 
   // Transfer the tile to a vector through an intermediate buffer.
   amx::TileType tileTy = tile.getType();
   Value buf = memref::AllocaOp::create(
       rewriter, loc,
       MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
   SmallVector<Value> indices(2, zeroIndex);
   amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
 
@@ -258,8 +386,22 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
                                         lhsTile, rhsTile, accTile);
     }
 
-    Value res = storeTile(rewriter, tileMul);
-    rewriter.replaceOp(contractOp, res);
+    // If the contraction result is only written back to memory, try to replace
+    // the vector op with an AMX store directly.
+    Value res = contractOp.getResult();
+    if (res.hasOneUse()) {
+      auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
+      LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
+      if (succeeded(storeRes)) {
+        rewriter.eraseOp(writeOp);
+        rewriter.eraseOp(contractOp);
+        return success();
+      }
+    }
+
+    // Load the result back into a vector.
+    Value newResult = storeTile(rewriter, tileMul);
+    rewriter.replaceOp(contractOp, newResult);
 
     return success();
   }
diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
new file mode 100644
index 0000000000000..de1e0e9fdaeb2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
@@ -0,0 +1,278 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// These test cases validate replacement of vector transfer ops with equivalent
+/// AMX tile data transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
+    %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
+  %c0_f16 = arith.constant 0.0 : f16
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<64x32x16x2xf16>, vector<4x8x2xf16>
+  %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<64x16x32x2xf16>, vector<8x16x2xf16>
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+  return
+}
+
+// CHECK-LABEL: @transfers_into_amx_tiles(
+// CHECK-SAME:    %[[A:.+]]: memref<64x32x16x2xf16>,
+// CHECK-SAME:    %[[B:.+]]: memref<64x16x32x2xf16>,
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>,
+// CHECK-SAME:    %[[IDX:.+]]: index
+// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+
+/// Load A into an AMX tile
+// CHECK:       %[[A_SUBVIEW:.+]] = memref.subview %[[A]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
+// CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
+// CHECK:       %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[A]]
+
+/// Load B into an AMX tile
+// CHECK:       %[[B_SUBVIEW:.+]] = memref.subview %[[B]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
+// CHECK-SAME:    {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
+// CHECK:       %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[B]]
+
+/// Load C into an AMX tile
+// CHECK:       %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK-SAME:    {{\[}}%[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT:   vector.transfer_read %[[C]]
+
+/// Perform tile multiplication
+// CHECK:       %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME:    %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result back
+// CHECK:       %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME:    {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK:       amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK-NOT:   vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// The AMX tile can be loaded directly from the buffer. However, the vector
+/// transfer has to remain due to other users that require data in registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+  return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @transfer_read_multiple_users(
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>,
+// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
+// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
+// CHECK: memref.alloca
+// CHECK: amx.tile_store
+// CHECK: vector.transfer_read
+// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]]
+
+// -----
+
+/// As the contraction has multiple users, the result has to be loaded back
+/// from the AMX tile into registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_contract_multiple_users(%C: memref<64x64xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %vecC: vector<4x16xf32>, %idx: index) -> vector<4x16xf32> {
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+  %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+  return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_contract_multiple_users(
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>
+// CHECK:     %[[TILE_MUL:.+]] = amx.tile_mulf
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_out_of_bounds(%C: memref<64x64xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %vecC: vector<4x16xf32>, %idx: index) {
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, false]} : vector<4x16xf32>, memref<64x64xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_out_of_bounds(
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_non_identity_map(%C: memref<64x64xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %vecC: vector<4x16xf32>, %idx: index) {
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_non_identity_map(
+// CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// AMX tile transfers require row elements to be contiguous.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_non_contiguous_row(
+    %A: memref<8x128x2xf16, strided<[256, 4, 1]>>,
+    %vecB: vector<8x16x2xf16>, %vecC: vector<4x16xf32>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f16 = arith.constant 0.0 : f16
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]}
+    : memref<8x128x2xf16, strided<[256, 4, 1]>>, vector<4x8x2xf16>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_contiguous_row(
+// CHECK-SAME:    %[[A:.+]]: memref<8x128x2xf16, strided<[256, 4, 1]>>
+// CHECK: vector.transfer_read %[[A]]
+
+// -----
+
+/// Buffer shape checks are conservative to avoid problems with deriving the
+/// stride between AMX tile rows.
+/// When in doubt, vector operations are left to perform the initial transfers.
+/// Afterwards, data can be placed in a contiguous temporary buffer, which
+/// ensures the correct layout for AMX transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_1D_buffer(%C: memref<512xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecC = vector.transfer_read %C[%idx], %c0_f32
+    {permutation_map = affine_map<(d0) -> (0, d0)>,
+    in_bounds = [true, true]} : memref<512xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_1D_buffer(
+// CHECK-SAME:    %[[C:.+]]: memref<512xf32>
+// CHECK: vector.transfer_read %[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_invalid_buffer_row_shape(%C: memref<5x2x4x4xf32>,
+    %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecC = vector.transfer_read %C[%idx, %idx, %idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<5x2x4x4xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_invalid_buffer_row_shape(
+// CHECK-SAME:    %[[C:.+]]: memref<5x2x4x4xf32>
+// CHECK: vector.transfer_read %[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_buffer_non_packed_source_shape(%A: memref<8x64x64xf16>,
+    %vecB: vector<8x16x2xf16>, %vecC: vector<4x16xf32>,
+    %idx: index) -> vector<4x16xf32> {
+  %c0_f16 = arith.constant 0.0 : f16
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<8x64x64xf16>, vector<4x8x2xf16>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_buffer_non_packed_source_shape(
+// CHECK-SAME:    %[[A:.+]]: memref<8x64x64xf16>
+// CHECK: vector.transfer_read %[[A]]

From d39f2106b5b00d4a7e3f05978585a159170c8a0c Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 19 Aug 2025 12:16:50 +0200
Subject: [PATCH 2/4] Dynamic shapes + refine test desc

---
 .../Conversion/VectorToAMX/VectorToAMX.cpp    |  4 +-
 .../VectorToAMX/transfer-to-amx.mlir          | 89 +++++++++++++++++--
 2 files changed, 86 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index 23194d0a4359b..730ee842463f4 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -233,10 +233,12 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
   VectorType vecTy = xferOp.getVectorType();
   ArrayRef<int64_t> vecShape = vecTy.getShape();
   ArrayRef<int64_t> memShape = memTy.getShape();
-  if (memShape.back() < vecShape.back())
+  if (memShape.back() == ShapedType::kDynamic ||
+      memShape.back() < vecShape.back())
     return nullptr;
   if (isPacked &&
       (memShape.back() != vecShape.back() ||
+       memShape[memShape.size() - 2] == ShapedType::kDynamic ||
        memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
     return nullptr;
 
diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
index de1e0e9fdaeb2..f631cfb4ac603 100644
--- a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
+++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
@@ -6,7 +6,7 @@
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>
-func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
+func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
     %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
   %c0_f16 = arith.constant 0.0 : f16
   %c0_f32 = arith.constant 0.0 : f32
@@ -26,7 +26,7 @@ func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
   return
 }
 
-// CHECK-LABEL: @transfers_into_amx_tiles(
+// CHECK-LABEL: @transfers_static_dims(
 // CHECK-SAME:    %[[A:.+]]: memref<64x32x16x2xf16>,
 // CHECK-SAME:    %[[B:.+]]: memref<64x16x32x2xf16>,
 // CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>,
@@ -70,6 +70,40 @@ func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
 
 // -----
 
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfer_dynamic_outer_dims(%A: memref<?x?x16x2xf16>,
+    %B: memref<?x?x32x2xf16>, %C: memref<?x64xf32>, %idx: index) {
+  %c0_f16 = arith.constant 0.0 : f16
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<?x?x16x2xf16>, vector<4x8x2xf16>
+  %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<?x?x32x2xf16>, vector<8x16x2xf16>
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<?x64xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, true]} : vector<4x16xf32>, memref<?x64xf32>
+  return
+}
+
+// CHECK-LABEL: @transfer_dynamic_outer_dims(
+// CHECK-SAME:    %[[A:.+]]: memref<?x?x16x2xf16>,
+// CHECK-SAME:    %[[B:.+]]: memref<?x?x32x2xf16>,
+// CHECK-SAME:    %[[C:.+]]: memref<?x64xf32>
+// CHECK-NOT:  vector.transfer_read %[[A]]
+// CHECK-NOT:  vector.transfer_read %[[B]]
+// CHECK-NOT:  vector.transfer_read %[[C]]
+// CHECK-NOT:  vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
 /// AMX tile can be loaded directly from the buffer. However, vector transfer
 /// has to remain due to other users that require data in registers.
 
@@ -93,14 +127,22 @@ func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
 
 // CHECK-LABEL: @transfer_read_multiple_users(
 // CHECK-SAME:    %[[C:.+]]: memref<64x64xf32>,
+
+/// Load to AMX tile directly from buffer.
 // CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
 // CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+
+/// Vector read remains to load data for the other non-AMX consumer.
 // CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
+
+/// Contraction uses the directly loaded tile.
 // CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
-// CHECK: memref.alloca
-// CHECK: amx.tile_store
-// CHECK: vector.transfer_read
-// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]]
+
+/// The consumer uses both the original C value and the contraction result.
+// CHECK: %[[RES_BUF:.+]] = memref.alloca
+// CHECK: amx.tile_store %[[RES_BUF]]
+// CHECK: %[[RES_VEC:.+]] = vector.transfer_read %[[RES_BUF]]
+// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]], %[[RES_VEC]]
 
 // -----
 
@@ -233,6 +275,41 @@ func.func @negative_1D_buffer(%C: memref<512xf32>,
 
 // -----
 
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_dynamic_shapes(%A: memref<?x?x?x2xf16>,
+    %B: memref<?x?x2xf16>, %C: memref<?x?xf32>, %idx: index) {
+  %c0_f16 = arith.constant 0.0 : f16
+  %c0_f32 = arith.constant 0.0 : f32
+  %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<?x?x?x2xf16>, vector<4x8x2xf16>
+  %vecB = vector.transfer_read %B[%idx, %idx, %idx], %c0_f16
+    {in_bounds = [true, true, true]} : memref<?x?x2xf16>, vector<8x16x2xf16>
+  %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+    {in_bounds = [true, true]} : memref<?x?xf32>, vector<4x16xf32>
+  %vecD = vector.contract
+    {kind = #vector.kind<add>,
+    indexing_maps = [#map, #map1, #map2],
+    iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+    %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+  vector.transfer_write %vecD, %C[%idx, %idx]
+    {in_bounds = [true, true]} : vector<4x16xf32>, memref<?x?xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_dynamic_shapes(
+// CHECK-SAME:    %[[A:.+]]: memref<?x?x?x2xf16>,
+// CHECK-SAME:    %[[B:.+]]: memref<?x?x2xf16>,
+// CHECK-SAME:    %[[C:.+]]: memref<?x?xf32>
+// CHECK:  vector.transfer_read %[[A]]
+// CHECK:  vector.transfer_read %[[B]]
+// CHECK:  vector.transfer_read %[[C]]
+// CHECK:  vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>

From a4545227f86d4727515e50b6b62a6618fd85a503 Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 19 Aug 2025 12:17:51 +0200
Subject: [PATCH 3/4] Typo

---
 mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
index f631cfb4ac603..8fab4cf1f7ed1 100644
--- a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
+++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
@@ -73,7 +73,7 @@ func.func @transfers_static_dims(%A: memref<64x32x16x2xf16>,
 #map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
 #map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
 #map2 = affine_map<(m, n, k, vnni) -> (m, n)>
-func.func @transfer_dynamic_outer_dims(%A: memref<?x?x16x2xf16>,
+func.func @transfers_dynamic_outer_dims(%A: memref<?x?x16x2xf16>,
     %B: memref<?x?x32x2xf16>, %C: memref<?x64xf32>, %idx: index) {
   %c0_f16 = arith.constant 0.0 : f16
   %c0_f32 = arith.constant 0.0 : f32
@@ -93,7 +93,7 @@ func.func @transfer_dynamic_outer_dims(%A: memref<?x?x16x2xf16>,
   return
 }
 
-// CHECK-LABEL: @transfer_dynamic_outer_dims(
+// CHECK-LABEL: @transfers_dynamic_outer_dims(
 // CHECK-SAME:    %[[A:.+]]: memref<?x?x16x2xf16>,
 // CHECK-SAME:    %[[B:.+]]: memref<?x?x32x2xf16>,
 // CHECK-SAME:    %[[C:.+]]: memref<?x64xf32>

From 5a8e13470802bd73fac7d590aa952e13d0e8808f Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 19 Aug 2025 13:06:52 +0200
Subject: [PATCH 4/4] Extra op type checks

---
 mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index 730ee842463f4..7b9ed1d8cd21a 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -197,7 +197,7 @@ static Operation *
 loadStoreFromTransfer(PatternRewriter &rewriter,
                       VectorTransferOpInterface xferOp, bool isPacked,
                       TypedValue<amx::TileType> tileToStore = nullptr) {
-  if (!xferOp)
+  if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp))
     return nullptr;
   if (xferOp.hasOutOfBoundsDim() ||
       !xferOp.getPermutationMap().isMinorIdentity())
@@ -279,6 +279,8 @@ loadStoreFromTransfer(PatternRewriter &rewriter,
   } else if (isa<vector::TransferWriteOp>(xferOp)) {
     amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndices,
                                          tileToStore);
+  } else {
+    llvm_unreachable("unsupported vector transfer op");
   }
 
   return amxTileOp;



More information about the Mlir-commits mailing list