[Mlir-commits] [mlir] [mlir][amx] Direct AMX data transfers (PR #154114)
Renato Golin
llvmlistbot at llvm.org
Tue Aug 19 03:49:11 PDT 2025
================
@@ -184,21 +189,144 @@ static Value collapseLastDim(PatternRewriter &rewriter,
reassocIndices);
}
-/// Loads vector values to an AMX tile.
+/// Attempt to create an AMX tile load/store operation equivalent to the given
+/// vector transfer `xfer` op.
+/// This approach allows to skip longer route through registers and a temporary
+/// buffer otherwise required to move data to/from an AMX tile.
+static Operation *
+loadStoreFromTransfer(PatternRewriter &rewriter,
+ VectorTransferOpInterface xferOp, bool isPacked,
+ TypedValue<amx::TileType> tileToStore = nullptr) {
+ if (!xferOp)
+ return nullptr;
+ if (xferOp.hasOutOfBoundsDim() ||
+ !xferOp.getPermutationMap().isMinorIdentity())
+ return nullptr;
+
+ // Extra checks in case of a write op.
+ // Stores must not be packed.
+ if (isa<vector::TransferWriteOp>(xferOp) &&
+ (!tileToStore || isPacked ||
+ tileToStore.getType().getShape() != xferOp.getVectorType().getShape()))
+ return nullptr;
+
+ // Check for a memref source buffer.
+ // AMX data transfer requires at least 2D shape to correctly
+ // infer stride between rows.
+ Value base = xferOp.getBase();
+ auto memTy = dyn_cast<MemRefType>(base.getType());
+ int64_t memRank = memTy.getRank();
+ if (!memTy || memRank < 2)
+ return nullptr;
+
+ // Check that the source buffer has enough contiguous elements to load whole
+ // AMX tile row.
+ //
+ // To ensure correctness, the validation is conservative and expects the
+ // buffer's innermost dimensions to be statically known, equal to or larger
+ // than the vector row length, and equal to the VNNI dimension if applicable.
+ //
+ // This check could be relaxed to accept more arbitrarily shaped buffers as
+ // long as there are enough contiguous elements to load a whole row.
+ if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1))
+ return nullptr;
+ VectorType vecTy = xferOp.getVectorType();
+ ArrayRef<int64_t> vecShape = vecTy.getShape();
+ ArrayRef<int64_t> memShape = memTy.getShape();
+ if (memShape.back() < vecShape.back())
+ return nullptr;
+ if (isPacked &&
+ (memShape.back() != vecShape.back() ||
+ memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2]))
+ return nullptr;
+
+ // Load values directly from the buffer to an AMX tile.
+ PatternRewriter::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(xferOp);
+ Location loc = xferOp.getLoc();
+
+ // Create a subview of the source buffer based on the transfer op to resolve
+ // offsets.
+ SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1));
+ int64_t vecRank = vecTy.getRank();
+ assert(memRank >= vecRank &&
+ "Expects buffer to be the same or greater rank than vector");
+ SmallVector<int64_t> shape(memRank - vecRank, 1);
+ shape.append(vecShape.begin(), vecShape.end());
+ TypedValue<MemRefType> src =
+ memref::SubViewOp::create(
+ rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()),
+ getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides)
+ .getResult();
+
+ // Collapse the VNNI dimension in case of packing.
+ if (isPacked)
+ src = collapseLastDim(rewriter, src);
+ int64_t rows = vecShape[0];
+ int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex);
+
+ Operation *amxTileOp = nullptr;
+ if (isa<vector::TransferReadOp>(xferOp)) {
+ amxTileOp =
+ amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides);
+ } else if (isa<vector::TransferWriteOp>(xferOp)) {
+ amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides,
+ tileToStore);
+ }
----------------
rengolin wrote:
If there's a case where you can get all the way here and the op is still neither a transfer read nor a transfer write, we should perhaps assert / mark it `llvm_unreachable`, or make an earlier check.
https://github.com/llvm/llvm-project/pull/154114
More information about the Mlir-commits
mailing list