[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)

Renato Golin llvmlistbot at llvm.org
Tue Aug 19 03:49:11 PDT 2025


================
@@ -184,21 +189,144 @@ static Value collapseLastDim(PatternRewriter &rewriter,
                                          reassocIndices);
 }
 
-/// Loads vector values to an AMX tile.
+/// Attempt to create an AMX tile load/store operation equivalent to the given
+/// vector transfer `xfer` op.
+/// This approach allows to skip longer route through registers and a temporary
+/// buffer otherwise required to move data to/from an AMX tile.
+static Operation *
+loadStoreFromTransfer(PatternRewriter &rewriter,
+                      VectorTransferOpInterface xferOp, bool isPacked,
+                      TypedValue<amx::TileType> tileToStore = nullptr) {
+  if (!xferOp)
+    return nullptr;
+  if (xferOp.hasOutOfBoundsDim() ||
+      !xferOp.getPermutationMap().isMinorIdentity())
+    return nullptr;
+
+  // Extra checks in case of a write op.
+  // Stores must not be packed.
+  if (isa<vector::TransferWriteOp>(xferOp) &&
+      (!tileToStore || isPacked ||
+       tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
+    return nullptr;
+
+  // Check for a memref source buffer.
+  // AMX data transfer requires at least 2D shape to correctly
+  // infer stride between rows.
+  Value base = xferOp.getBase();
+  auto memTy = dyn_cast<MemRefType>(base.getType());
+  int64_t memRank = memTy.getRank();
+  if (!memTy || memRank < 2)
+    return nullptr;
+
+  // Check that the source buffer has enough contiguous elements to load whole
+  // AMX tile row.
+  //
+  // To ensure correctness, the validation is conservative and expects the
+  // buffer's innermost dimensions to be statically known, equal to or larger
+  // than the vector row length, and equal to the VNNI dimension if applicable.
+  //
+  // This check could be relaxed to accept more arbitrarily shaped buffers as
+  // long as there are enough contiguous elements to load a whole row.
+  if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
+    return nullptr;
+  VectorType vecTy = xferOp.getVectorType();
+  ArrayRef<int64_t> vecShape = vecTy.getShape();
+  ArrayRef<int64_t> memShape = memTy.getShape();
+  if (memShape.back() < vecShape.back())
+    return nullptr;
+  if (isPacked &&
+      (memShape.back() != vecShape.back() ||
+       memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
+    return nullptr;
+
+  // Load values directly from the buffer to an AMX tile.
+  PatternRewriter::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(xferOp);
+  Location loc = xferOp.getLoc();
+
+  // Create a subview of the source buffer based on the transfer op to resolve
+  // offsets.
+  SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
+  int64_t vecRank = vecTy.getRank();
+  assert(memRank >= vecRank &&
+         "Expects buffer to be the same or greater rank than vector");
+  SmallVector<int64_t> shape(memRank - vecRank, 1);
+  shape.append(vecShape.begin(), vecShape.end());
+  TypedValue<MemRefType> src =
+      memref::SubViewOp::create(
+          rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
+          getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
+          .getResult();
+
+  // Collapse the VNNI dimension in case of packing.
+  if (isPacked)
+    src = collapseLastDim(rewriter, src);
+  int64_t rows = vecShape[0];
+  int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
+                                 std::multiplies<int64_t>());
+  auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+  Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
+
+  Operation *amxTileOp = nullptr;
+  if (isa<vector::TransferReadOp>(xferOp)) {
+    amxTileOp =
+        amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
+  } else if (isa<vector::TransferWriteOp>(xferOp)) {
+    amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
+                                         tileToStore);
+  }
----------------
rengolin wrote:

If there's a case where you get all the way here and still aren't handling a transfer read/write, we should perhaps assert / mark it unreachable, or make an earlier check.

https://github.com/llvm/llvm-project/pull/154114


More information about the Mlir-commits mailing list