[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)
llvmlistbot at llvm.org
Mon Aug 18 06:25:29 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Adam Siemieniuk (adam-smnk)
<details>
<summary>Changes</summary>
Extends the Vector to AMX conversion to attempt populating AMX tiles directly from memory.
When possible, contraction producers and consumers are replaced by AMX tile data transfer operations. This shortens the data path by skipping intermediate register loads and stores.
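For illustration, here is a minimal before/after sketch of the idea, modeled on the accumulator handling in the new test file; the value names (`%C`, `%idx`, `%vecD`, `%resTile`) and the printed subview type are illustrative assumptions, not literal pass output:

```mlir
// Before: the accumulator travels through registers.
%vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
  {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
// ... vector.contract consumes %vecC and produces %vecD ...
vector.transfer_write %vecD, %C[%idx, %idx]
  {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>

// After: the tile is transferred directly through a subview of %C.
%c0 = arith.constant 0 : index
%subview = memref.subview %C[%idx, %idx] [4, 16] [1, 1]
  : memref<64x64xf32> to memref<4x16xf32, strided<[64, 1], offset: ?>>
%accTile = amx.tile_load %subview[%c0, %c0]
  : memref<4x16xf32, strided<[64, 1], offset: ?>> into !amx.tile<4x16xf32>
// ... amx.tile_mulf consumes %accTile and produces %resTile ...
amx.tile_store %subview[%c0, %c0], %resTile
  : memref<4x16xf32, strided<[64, 1], offset: ?>>, !amx.tile<4x16xf32>
```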
---
Patch is 22.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/154114.diff
2 Files Affected:
- (modified) mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp (+156-14)
- (added) mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir (+278)
``````````diff
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
index a11e9b2624300..23194d0a4359b 100644
--- a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -10,7 +10,6 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -21,6 +20,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/DebugLog.h"
+
#include <numeric>
namespace mlir {
@@ -30,6 +31,8 @@ namespace mlir {
using namespace mlir;
+#define DEBUG_TYPE "vector-to-amx"
+
namespace {
/// Return true if vector shape is compatible with AMX tiles.
@@ -49,8 +52,10 @@ static bool verifyAmxShape(VectorType vec) {
// 3D shape indicates VNNI packed layout.
if (vec.getRank() == 3) {
int64_t vnniFactor = 32 / elemBitWidth;
- if (shape.back() != vnniFactor)
+ if (shape.back() != vnniFactor) {
+ LDBG() << "invalid VNNI packing factor";
return false;
+ }
cols *= vnniFactor;
}
@@ -60,7 +65,7 @@ static bool verifyAmxShape(VectorType vec) {
return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
}
-/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+/// Check if contraction operands are in AMX-compatible packed VNNI layout.
static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
vector::ContractionOp contractOp) {
VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
@@ -172,9 +177,9 @@ static LogicalResult validateOperands(PatternRewriter &rewriter,
return success();
}
-/// Collapses the two innermost dimensions together.
-static Value collapseLastDim(PatternRewriter &rewriter,
- TypedValue<MemRefType> memref) {
+/// Collapse the two innermost dimensions together.
+static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter,
+ TypedValue<MemRefType> memref) {
int64_t rank = memref.getType().getRank();
SmallVector<ReassociationIndices> reassocIndices;
for (auto i : llvm::seq<int64_t>(0, rank - 2))
@@ -184,21 +189,144 @@ static Value collapseLastDim(PatternRewriter &rewriter,
reassocIndices);
}
-/// Loads vector values to an AMX tile.
+/// Attempt to create an AMX tile load/store operation equivalent to the given
+/// vector transfer `xferOp`.
+/// This allows skipping the longer route through registers and a temporary
+/// buffer otherwise required to move data to/from an AMX tile.
+static Operation *
+loadStoreFromTransfer(PatternRewriter &rewriter,
+ VectorTransferOpInterface xferOp, bool isPacked,
+ TypedValue<amx::TileType> tileToStore = nullptr) {
+ if (!xferOp)
+ return nullptr;
+ if (xferOp.hasOutOfBoundsDim() ||
+ !xferOp.getPermutationMap().isMinorIdentity())
+ return nullptr;
+
+ // Extra checks in case of a write op.
+ // Stores must not be packed.
+ if (isa<vector::TransferWriteOp>(xferOp) &&
+ (!tileToStore || isPacked ||
+ tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
+ return nullptr;
+
+  // Check for a memref source buffer.
+  // AMX data transfers require at least a 2D shape to correctly
+  // infer the stride between rows.
+  Value base = xferOp.getBase();
+  auto memTy = dyn_cast<MemRefType>(base.getType());
+  if (!memTy || memTy.getRank() < 2)
+    return nullptr;
+  int64_t memRank = memTy.getRank();
+
+  // Check that the source buffer has enough contiguous elements to load a
+  // whole AMX tile row.
+ //
+ // To ensure correctness, the validation is conservative and expects the
+ // buffer's innermost dimensions to be statically known, equal to or larger
+ // than the vector row length, and equal to the VNNI dimension if applicable.
+ //
+ // This check could be relaxed to accept more arbitrarily shaped buffers as
+ // long as there are enough contiguous elements to load a whole row.
+ if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
+ return nullptr;
+ VectorType vecTy = xferOp.getVectorType();
+ ArrayRef<int64_t> vecShape = vecTy.getShape();
+ ArrayRef<int64_t> memShape = memTy.getShape();
+ if (memShape.back() < vecShape.back())
+ return nullptr;
+ if (isPacked &&
+ (memShape.back() != vecShape.back() ||
+ memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
+ return nullptr;
+
+ // Load values directly from the buffer to an AMX tile.
+ PatternRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(xferOp);
+ Location loc = xferOp.getLoc();
+
+ // Create a subview of the source buffer based on the transfer op to resolve
+ // offsets.
+ SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
+ int64_t vecRank = vecTy.getRank();
+  assert(memRank >= vecRank &&
+         "expected buffer rank to be at least the vector rank");
+ SmallVector<int64_t> shape(memRank - vecRank, 1);
+ shape.append(vecShape.begin(), vecShape.end());
+ TypedValue<MemRefType> src =
+ memref::SubViewOp::create(
+ rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
+ .getResult();
+
+ // Collapse the VNNI dimension in case of packing.
+ if (isPacked)
+ src = collapseLastDim(rewriter, src);
+ int64_t rows = vecShape[0];
+ int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIndices(src.getType().getRank(), zeroIndex);
+
+ Operation *amxTileOp = nullptr;
+ if (isa<vector::TransferReadOp>(xferOp)) {
+ amxTileOp =
+        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndices);
+ } else if (isa<vector::TransferWriteOp>(xferOp)) {
+    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndices,
+ tileToStore);
+ }
+
+ return amxTileOp;
+}
+
+/// Attempt to create an AMX tile load operation equivalent to the given
+/// vector transfer `readOp`.
+/// Returns the loaded AMX tile on success.
+static FailureOr<TypedValue<amx::TileType>>
+loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp,
+ bool isPacked) {
+ amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>(
+ loadStoreFromTransfer(rewriter, readOp, isPacked));
+ if (!loadOp)
+ return failure();
+ return loadOp.getRes();
+}
+
+/// Attempt to create an AMX tile store operation equivalent to the given
+/// vector transfer `writeOp`.
+static LogicalResult storeFromTransfer(PatternRewriter &rewriter,
+ vector::TransferWriteOp writeOp,
+ TypedValue<amx::TileType> tileToStore) {
+ return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false,
+ tileToStore));
+}
+
+/// Load vector values to an AMX tile.
static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
TypedValue<VectorType> vec) {
Location loc = vec.getLoc();
- Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
- // Transfer the vector to a tile through an intermediate buffer.
VectorType vecTy = vec.getType();
+ bool isPacked = vecTy.getRank() == 3;
+
+ // Try to load tile directly from vector producer's buffer.
+ auto readOp = vec.getDefiningOp<vector::TransferReadOp>();
+ FailureOr<TypedValue<amx::TileType>> tile =
+ loadFromTransfer(rewriter, readOp, isPacked);
+ if (succeeded(tile))
+ return *tile;
+
+ // Transfer the vector to a tile through an intermediate buffer.
Value buf = memref::AllocaOp::create(
rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
// Collapse the VNNI dimension in case of packing.
- bool isPacked = vecTy.getRank() == 3;
if (isPacked)
buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
@@ -212,17 +340,17 @@ static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
{zeroIndex, zeroIndex});
}
-/// Stores an AMX tile in a vector.
+/// Store an AMX tile in a vector.
static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
TypedValue<amx::TileType> tile) {
Location loc = tile.getLoc();
- Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
// Transfer the tile to a vector through an intermediate buffer.
amx::TileType tileTy = tile.getType();
Value buf = memref::AllocaOp::create(
rewriter, loc,
MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(2, zeroIndex);
amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
@@ -258,8 +386,22 @@ struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
lhsTile, rhsTile, accTile);
}
- Value res = storeTile(rewriter, tileMul);
- rewriter.replaceOp(contractOp, res);
+ // If the contraction result is only written back to memory, try to replace
+ // the vector op with an AMX store directly.
+ Value res = contractOp.getResult();
+ if (res.hasOneUse()) {
+ auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin());
+ LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul);
+ if (succeeded(storeRes)) {
+ rewriter.eraseOp(writeOp);
+ rewriter.eraseOp(contractOp);
+ return success();
+ }
+ }
+
+ // Load the result back into a vector.
+ Value newResult = storeTile(rewriter, tileMul);
+ rewriter.replaceOp(contractOp, newResult);
return success();
}
diff --git a/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
new file mode 100644
index 0000000000000..de1e0e9fdaeb2
--- /dev/null
+++ b/mlir/test/Conversion/VectorToAMX/transfer-to-amx.mlir
@@ -0,0 +1,278 @@
+// RUN: mlir-opt %s -convert-vector-to-amx -split-input-file | FileCheck %s
+
+/// These test cases validate replacement of vector transfer ops with equivalent
+/// AMX tile data transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfers_into_amx_tiles(%A: memref<64x32x16x2xf16>,
+ %B: memref<64x16x32x2xf16>, %C: memref<64x64xf32>, %idx: index) {
+ %c0_f16 = arith.constant 0.0 : f16
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<64x32x16x2xf16>, vector<4x8x2xf16>
+ %vecB = vector.transfer_read %B[%idx, %idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]} : memref<64x16x32x2xf16>, vector<8x16x2xf16>
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @transfers_into_amx_tiles(
+// CHECK-SAME: %[[A:.+]]: memref<64x32x16x2xf16>,
+// CHECK-SAME: %[[B:.+]]: memref<64x16x32x2xf16>,
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>,
+// CHECK-SAME: %[[IDX:.+]]: index
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+
+/// Load A into an AMX tile
+// CHECK: %[[A_SUBVIEW:.+]] = memref.subview %[[A]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[A_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[A_SUBVIEW]]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x4x8x2xf16{{.*}}into memref<1x4x16xf16
+// CHECK: %[[A_TILE:.+]] = amx.tile_load %[[A_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[A]]
+
+/// Load B into an AMX tile
+// CHECK: %[[B_SUBVIEW:.+]] = memref.subview %[[B]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]], %[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[B_PACKED_DIM_COLLAPSE:.+]] = memref.collapse_shape %[[B_SUBVIEW]]
+// CHECK-SAME: {{\[}}[0], [1], [2, 3]] : memref<1x8x16x2xf16{{.*}}into memref<1x8x32xf16
+// CHECK: %[[B_TILE:.+]] = amx.tile_load %[[B_PACKED_DIM_COLLAPSE]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[B]]
+
+/// Load C into an AMX tile
+// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK-SAME: {{\[}}%[[C0]], %[[C0]]{{\]}}
+// CHECK-NOT: vector.transfer_read %[[C]]
+
+/// Perform tile multiplication
+// CHECK: %[[RES:.+]] = amx.tile_mulf
+// CHECK-SAME: %[[A_TILE]], %[[B_TILE]], %[[C_TILE]]
+
+/// Store the result back
+// CHECK: %[[RES_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK-SAME: {{\[}}%[[IDX]], %[[IDX]]{{\]}}
+// CHECK: amx.tile_store %[[RES_SUBVIEW]]{{\[}}%[[C0]], %[[C0]]{{\]}}, %[[RES]]
+// CHECK-NOT: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// The AMX tile can be loaded directly from the buffer. However, the vector
+/// transfer has to remain due to other users that require data in registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @transfer_read_multiple_users(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecC = vector.transfer_read %C[%idx, %idx], %c0_f32
+ {in_bounds = [true, true]} : memref<64x64xf32>, vector<4x16xf32>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+ return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @transfer_read_multiple_users(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>,
+// CHECK: %[[C_SUBVIEW:.+]] = memref.subview %[[C]]
+// CHECK: %[[C_TILE:.+]] = amx.tile_load %[[C_SUBVIEW]]
+// CHECK: %[[C_VEC:.+]] = vector.transfer_read %[[C]]
+// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf{{.*}}%[[C_TILE]]
+// CHECK: memref.alloca
+// CHECK: amx.tile_store
+// CHECK: vector.transfer_read
+// CHECK: %[[VEC_MUL:.+]] = arith.mulf %[[C_VEC]]
+
+// -----
+
+/// As the contraction has multiple users, the result has to be loaded back
+/// from the AMX tile into registers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_contract_multiple_users(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %vecC: vector<4x16xf32>, %idx: index) -> vector<4x16xf32> {
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ %mul = arith.mulf %vecC, %vecD : vector<4x16xf32>
+ return %mul : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_contract_multiple_users(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
+// CHECK: %[[TILE_MUL:.+]] = amx.tile_mulf
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_out_of_bounds(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %vecC: vector<4x16xf32>, %idx: index) {
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {in_bounds = [true, false]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_out_of_bounds(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_non_identity_map(%C: memref<64x64xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %vecC: vector<4x16xf32>, %idx: index) {
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ vector.transfer_write %vecD, %C[%idx, %idx]
+ {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+ in_bounds = [true, true]} : vector<4x16xf32>, memref<64x64xf32>
+ return
+}
+
+// CHECK-LABEL: @negative_non_identity_map(
+// CHECK-SAME: %[[C:.+]]: memref<64x64xf32>
+// CHECK: vector.transfer_write{{.*}}%[[C]]
+
+// -----
+
+/// AMX tile transfers require row elements to be contiguous.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_non_contiguous_row(
+ %A: memref<8x128x2xf16, strided<[256, 4, 1]>>,
+ %vecB: vector<8x16x2xf16>, %vecC: vector<4x16xf32>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f16 = arith.constant 0.0 : f16
+ %vecA = vector.transfer_read %A[%idx, %idx, %idx], %c0_f16
+ {in_bounds = [true, true, true]}
+ : memref<8x128x2xf16, strided<[256, 4, 1]>>, vector<4x8x2xf16>
+ %vecD = vector.contract
+ {kind = #vector.kind<add>,
+ indexing_maps = [#map, #map1, #map2],
+ iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
+ %vecA, %vecB, %vecC : vector<4x8x2xf16>, vector<8x16x2xf16> into vector<4x16xf32>
+ return %vecD : vector<4x16xf32>
+}
+
+// CHECK-LABEL: @negative_non_contiguous_row(
+// CHECK-SAME: %[[A:.+]]: memref<8x128x2xf16, strided<[256, 4, 1]>>
+// CHECK: vector.transfer_read %[[A]]
+
+// -----
+
+/// Buffer shape checks are conservative to avoid problems with deriving
+/// the stride for AMX tile rows.
+/// When in doubt, vector operations are left to perform the initial transfers.
+/// Afterwards, data can be placed in a contiguous temporary buffer which
+/// ensures a correct layout for AMX transfers.
+
+#map = affine_map<(m, n, k, vnni) -> (m, k, vnni)>
+#map1 = affine_map<(m, n, k, vnni) -> (k, n, vnni)>
+#map2 = affine_map<(m, n, k, vnni) -> (m, n)>
+func.func @negative_1D_buffer(%C: memref<512xf32>,
+ %vecA: vector<4x8x2xf16>, %vecB: vector<8x16x2xf16>,
+ %idx: index) -> vector<4x16xf32> {
+ %c0_f32 = arith.constant 0.0 : f32
+ %vecC = vector.transfer_read ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/154114
More information about the Mlir-commits mailing list