[Mlir-commits] [mlir] [MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix (PR #162780)
Jianhui Li
llvmlistbot at llvm.org
Wed Oct 15 16:36:57 PDT 2025
https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/162780
>From 4c58d3d6a627f23425528668ffff92bcca8f1461 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 7 Oct 2025 22:08:30 +0000
Subject: [PATCH 01/12] pass basic lowering test
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 22 ++
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 +-
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 50 +++-
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 124 ++++++++++
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 216 ++++++++++++++++++
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 20 +-
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 6 +-
.../XeGPUToXeVM/loadstore_matrix.mlir | 40 ++++
8 files changed, 466 insertions(+), 18 deletions(-)
create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..601e966b49890 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().getAs<ArrayAttr>("stride");
}
+ ArrayAttr getBlockAttr() {
+ return getAttrs().getAs<ArrayAttr>("block");
+ }
+
}];
}
+def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
+def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
+def MatrixAccessDirection :
+ I32EnumAttr<"MatrixAccessDirection",
+ "Matrix elements/vectors can have row or column direction", [
+ RowOriented, ColOriented
+]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+def MatrixAccessDirectionAttr :
+ EnumAttr<XeGPU_Dialect,
+ MatrixAccessDirection,
+ "matrix_access_direction">{
+  let summary = [{Describes the direction of memory access for load_matrix and store_matrix.}];
+ let assemblyFormat = "`<` $value `>`";
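+  // E.g., this attribute prints as #xegpu.matrix_access_direction<col> in the IR.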
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..32d21bae8cd34 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,8 +1298,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}
def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
- AllElementTypesMatch<["mem_desc", "res"]>,
- AllRanksMatch<["mem_desc", "res"]>]> {
+ AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
@@ -1344,8 +1343,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- AllElementTypesMatch<["mem_desc", "data"]>,
- AllRanksMatch<["mem_desc", "data"]>]> {
+ AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
XeGPU_ValueType:$data,
XeGPU_MemDesc:$mem_desc,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2039643..c261fbb576642 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}
- ArrayAttr getStrides() {
+ ArrayAttr getStridesAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}
+
+ /// Heuristic to determine if the MemDesc uses column-major layout,
+ /// based on the rank and the value of the first stride dimension.
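+    /// E.g., strides [1, 16] on a rank-2 mem_desc is treated as column-major.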
+ bool isColMajor() {
+ auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+ return getRank() == 2 && dim0 && dim0.getInt() == 1;
+ }
+
+    // Get the block shape for a MemDescType, which is represented
+    // as an attribute in MemDescType. By default it is the shape
+    // of the MemDescType itself.
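+    // E.g., mem_desc<256x32xf16, @block=[8, 16]> gives [8, 16]; without a
+    // block attribute the full shape [256, 32] is returned.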
+ SmallVector<int64_t> getBlockSize() {
+ SmallVector<int64_t> size(getShape());
+ MemLayoutAttr layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ ArrayAttr attr = layout.getBlockAttr();
+ size.clear();
+ llvm::for_each(attr, [&](Attribute elem) {
+ if (auto intElem = dyn_cast<IntegerAttr>(elem))
+ size.push_back(intElem.getInt());
+ });
+ }
+ return size;
+ }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the returned strides are
+    // blocked strides.
+    //
+    // The blocking is applied against the original matrix shape
+    // so that the linear offset is not impacted by the subview.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then computes the number of blocks in each dimension of the original
+    // shape, then computes the outer block shape and strides, and finally
+    // combines the inner and outer block shapes and strides.
+ // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+ // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+ // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
+ // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+ SmallVector<int64_t> getStrides();
+
+    /// Generates instructions to compute the linearized offset.
+    /// If the memory descriptor is blocked, the linearized offset is computed
+    /// based on the blocked layout. The strides of the memory descriptor are
+    /// always taken into account, whether blocked or not.
+ Value getLinearOffsets(OpBuilder &builder,
+ Location loc, ArrayRef<OpFoldResult> offsets);
+
+
}];
let hasCustomAssemblyFormat = true;
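To make the comments above concrete, here is a minimal standalone C++ sketch of the index arithmetic (not part of the patch), using the documented default-stride example mem_desc<256x32xf16, @block=[8, 16]>; the shape/stride constants below come from that comment rather than from generated code:

// Blocked linear-offset math for mem_desc<256x32xf16, @block=[8, 16]>,
// i.e. blocked shape [32, 2, 8, 16] and blocked strides [256, 128, 16, 1].
#include <cstdint>

int64_t linearOffset(int64_t row, int64_t col) {
  // Block the offsets: [row, col] -> [row/8, col/16, row%8, col%16].
  const int64_t off[4] = {row / 8, col / 16, row % 8, col % 16};
  const int64_t strides[4] = {256, 128, 16, 1};
  int64_t linear = 0;
  for (int i = 0; i < 4; ++i)
    linear += off[i] * strides[i];
  return linear;
}

For offsets (10, 20) this gives 1*256 + 1*128 + 2*16 + 4 = 420, i.e. block (1, 1) plus in-block offset (2, 4).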
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d89069d6..666df293bb8be 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -32,6 +32,8 @@
#include <numeric>
+#define DEBUG_TYPE "xegpu-to-xevm"
+
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -60,6 +62,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
return static_cast<int>(xevm::AddrSpace::GLOBAL);
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
+ default:
+ llvm_unreachable("Unknown XeGPU memory space");
+ return static_cast<int>(xevm::AddrSpace::GLOBAL);
}
}
@@ -366,6 +371,7 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
Value baseAddr, Value offset, int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -503,6 +509,113 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
}
};
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
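+// E.g., a create_mem_desc producing !xegpu.mem_desc<32x32xf32> from a
+// memref<4096xi8, 3> source becomes a memref.view yielding memref<1024xf32, 3>.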
+class CreateMemDescOpPattern final
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+ using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // DEBUG: Print operation and types
+ LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
+ TypedValue<MemRefType> src = op.getSource();
+ auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+ // Create the result MemRefType with the same shape, element type, and memory space
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
+ ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
+ return success();
+ }
+};
+
+class MemDescSubviewOpPattern final
+ : public OpConversionPattern<xegpu::MemDescSubviewOp> {
+public:
+ using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ return rewriter.notifyMatchFailure(
+ op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
+ }
+};
+
+
+template <typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+ Value basePtrStruct = adaptor.getMemDesc();
+ Value mdescVal = op.getMemDesc();
+    // The load result or store value type can be a vector or a scalar.
+ Value data;
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+ data = op.getResult();
+ else
+ data = adaptor.getData();
+ VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+
+ int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ // Element type must be multiple of 8 bits.
+ if (elemBitWidth % 8 != 0)
+ return rewriter.notifyMatchFailure(
+ op, "Expected element type bit width to be multiple of 8.");
+ int64_t elemByteSize = elemBitWidth / 8;
+
+ // Default memory space is SLM.
+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+ auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct);
+
+ // Convert base pointer (ptr) to i64
+ Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+
+ Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+
+ // convert base pointer (i64) to LLVM pointer type
+ basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+
+ Value loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
+};
+
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -785,6 +898,13 @@ struct ConvertXeGPUToXeVMPass
auto i32Type = IntegerType::get(&getContext(), 32);
return VectorType::get(8, i32Type);
});
+ // Convert MemDescType into flattened MemRefType for SLM
+ typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+ Type elemTy = type.getElementType();
+ int numElems = type.getNumElements();
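+      // E.g., !xegpu.mem_desc<32x32xf32> is flattened to memref<1024xf32, 3> (SLM).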
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+ });
+
typeConverter.addConversion([&](MemRefType type) -> Type {
// Convert MemRefType to i64 type.
return IntegerType::get(&getContext(), 64);
@@ -919,6 +1039,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
typeConverter, patterns.getContext());
+ patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+ LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+ CreateMemDescOpPattern, MemDescSubviewOpPattern>(
+ typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..c64699c12cf4a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,6 +37,8 @@ void XeGPUDialect::initialize() {
>();
}
+#define DEBUG_TYPE "xegpu"
+
/// Generates instructions to compute offsets for a subgroup identified by
/// its multidimensional indices (sgId), using the specified subgroup layout
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
@@ -726,6 +728,220 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// A helper utility to perform a binary arith operation on two OpFoldResults.
+// Each operand is materialized as a Value (a constant op is created when the
+// operand is an attribute), and the corresponding arith op is generated on
+// the materialized values.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+ return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// A helper utility to perform a division operation on an OpFoldResult and an
+// int64_t.
+#define div(a, b) \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a remainder operation on an OpFoldResult and an
+// int64_t.
+#define rem(a, b) \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a multiplication operation on an OpFoldResult
+// and an int64_t.
+#define mul(a, b) \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform an addition operation on two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// block the given offsets according to the block shape
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
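+// E.g., offsets [10, 20] with block shape [8, 16] become [1, 1, 2, 4].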
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ return blockedOffsets;
+}
+
+// Get strides as a vector of integers for MemDesc.
+SmallVector<int64_t> MemDescType::getStrides() {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(),
+ getShape().end());
+
+ ArrayAttr strideAttr = getStridesAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+
+ llvm::dbgs() << "DEBUG: matrixShape = [";
+ for (size_t i = 0; i < matrixShape.size(); ++i) {
+ llvm::dbgs() << matrixShape[i];
+ if (i < matrixShape.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ llvm::dbgs() << "DEBUG: strides = [";
+ for (size_t i = 0; i < strides.size(); ++i) {
+ llvm::dbgs() << strides[i];
+ if (i < strides.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ SmallVector<int64_t> innerBlkShape = getBlockSize();
+ llvm::dbgs() << "DEBUG: innerBlkShape = [";
+ for (size_t i = 0; i < innerBlkShape.size(); ++i) {
+ llvm::dbgs() << innerBlkShape[i];
+ if (i < innerBlkShape.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ if (innerBlkShape.empty())
+ return strides;
+
+ SmallVector<int, 4> perm = llvm::to_vector<4>(
+ llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+ llvm::dbgs() << "DEBUG: perm = [";
+ for (size_t i = 0; i < perm.size(); ++i) {
+ llvm::dbgs() << perm[i];
+ if (i < perm.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+
+ SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
+
+ llvm::dbgs() << "DEBUG: innerBlkStride = [";
+ for (size_t i = 0; i < innerBlkStride.size(); ++i) {
+ llvm::dbgs() << innerBlkStride[i];
+ if (i < innerBlkStride.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ // compute the original matrix shape using the stride info
+ // and compute the number of blocks in each dimension
+ // The shape of highest dim can't be derived from stride info,
+ // but doesn't impact the stride computation for blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+ SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+ BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
+
+ llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
+ for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
+ llvm::dbgs() << matrixShapeOrig[i];
+ if (i < matrixShapeOrig.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
+ for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
+ llvm::dbgs() << BlkShapeOrig[i];
+ if (i < BlkShapeOrig.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
+
+ llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
+
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
+ outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+ }
+
+ llvm::dbgs() << "DEBUG: outerBlkStride = [";
+ for (size_t i = 0; i < outerBlkStride.size(); ++i) {
+ llvm::dbgs() << outerBlkStride[i];
+ if (i < outerBlkStride.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+ llvm::dbgs() << "DEBUG: blockedStrides = [";
+ for (size_t i = 0; i < blockedStrides.size(); ++i) {
+ llvm::dbgs() << blockedStrides[i];
+ if (i < blockedStrides.size() - 1) llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ return blockedStrides;
+ }
+
+// Calculate the linear offset using the blocked offsets and stride
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets) {
+
+ SmallVector<int64_t> blockShape = getBlockSize();
+ SmallVector<int64_t> strides = getStrides();
+
+ LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
+ llvm::interleaveComma(blockShape, llvm::dbgs());
+ llvm::dbgs() << "], strides=[";
+ llvm::interleaveComma(strides, llvm::dbgs());
+ llvm::dbgs() << "]\n");
+
+ if (!blockShape.empty()) {
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ // say the original offset is [y, x], and the block shape is [By, Bx],
+ // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ offsets = blockedOffsets;
+ LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
+ << offsets.size() << "\n");
+ }
+
+ // Start with initial value as matrix descriptor's base offset.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
+ << linearOffset << "\n");
+
+ return linearOffset;
+}
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788d0b9b4..23e487787652d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1062,9 +1062,12 @@ LogicalResult LoadMatrixOp::verify() {
ArrayRef<int64_t> valueShape = resTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed mem_desc shape.");
+
+ if (valueShape.size() != 1) {
+ if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed mem_desc shape.");
+ }
return success();
}
@@ -1092,10 +1095,11 @@ LogicalResult StoreMatrixOp::verify() {
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("data shape must not exceed mem_desc shape.");
-
+ if (dataShape.size() != 1) {
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("data shape must not exceed mem_desc shape.");
+ }
return success();
}
@@ -1127,7 +1131,7 @@ LogicalResult MemDescSubviewOp::verify() {
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitOpError("result shape must not exceed source shape.");
- if (srcTy.getStrides() != resTy.getStrides())
+ if (srcTy.getStridesAttr() != resTy.getStridesAttr())
return emitOpError("result must inherit the source strides.");
return success();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index e6f22f0a9acbb..bbf313bf4fb60 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,10 +1,6 @@
// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
-#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-
-gpu.module @load_store_check {
+gpu.module @test_kernel {
// CHECK-LABEL: func.func @dpas(
// CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
new file mode 100644
index 0000000000000..30d6274c9dccf
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
+
+gpu.module @test_kernel {
+ //CHECK-LABEL: load_store_matrix_1
+ gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+ %tid_x = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
+ gpu.return %1: vector<8xf32>
+ }
+
+ // e.g. for mem_desc<32x32xf16, @block=[16, 16], @strides=[1, 16]>
+ // its memory layout tuple is ([2,2,16,16],[256,512,1,16])
+
+ //CHECK-LABEL: load_store_matrix_2
+ gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<stride = [1, 16], block = [16, 16]>>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+ %tid_x = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
+ gpu.return %1: vector<8xf32>
+ }
+
+ // e.g. for mem_desc<32x32xf16, @block=[16, 16]>
+ // its memory layout tuple is ([2,2,16,16],[512,256,16,1])
+ //CHECK-LABEL: load_store_matrix_3
+ gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<block = [16, 16]>>
+ //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32>
+ //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+ %tid_x = gpu.thread_id x
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
+ gpu.return %1: vector<8xf32>
+ }
+
+}
>From 554b95edf3079fee2ac91ccd22078886244724f0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 7 Oct 2025 23:17:20 +0000
Subject: [PATCH 02/12] add attributes
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 3 +
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 63 +++--
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 262 +++++++++---------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 7 +-
.../XeGPUToXeVM/loadstore_matrix.mlir | 47 ++--
mlir/test/Dialect/XeGPU/ops.mlir | 57 +++-
6 files changed, 253 insertions(+), 186 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 32d21bae8cd34..a0a8669baf90d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1302,6 +1302,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<I32Attr>:$vec_length,
+ OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+ OptionalAttr<UnitAttr>:$subgroupBlockIO,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let results = (outs XeGPU_ValueType:$res);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 666df293bb8be..97deca167204a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -371,7 +371,8 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
Value baseAddr, Value offset, int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemByteSize);
- offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset);
+ offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+ offset);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -513,29 +514,36 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
// 32 bits will be converted to 32 bits.
class CreateMemDescOpPattern final
- : public OpConversionPattern<xegpu::CreateMemDescOp> {
+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
public:
using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- // DEBUG: Print operation and types
- LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
- TypedValue<MemRefType> src = op.getSource();
- auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
-
- // Create the result MemRefType with the same shape, element type, and memory space
- auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-
- LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
- LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
- LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
- Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
- auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
- ValueRange());
- rewriter.replaceOp(op, viewOp);
- LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
- return success();
+ ConversionPatternRewriter &rewriter) const override {
+ // DEBUG: Print operation and types
+ LLVM_DEBUG(llvm::dbgs()
+ << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
+ TypedValue<MemRefType> src = op.getSource();
+ auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+ // Create the result MemRefType with the same shape, element type, and
+ // memory space
+ auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
+ LLVM_DEBUG(llvm::dbgs()
+ << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
+ Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+ auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+ Value(src), zero, ValueRange());
+ rewriter.replaceOp(op, viewOp);
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
+ return success();
}
};
@@ -551,7 +559,6 @@ class MemDescSubviewOpPattern final
}
};
-
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -577,7 +584,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
data = adaptor.getData();
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
- int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+ int64_t elemBitWidth =
+ valOrResVecTy.getElementType().getIntOrFloatBitWidth();
// Element type must be multiple of 8 bits.
if (elemBitWidth % 8 != 0)
return rewriter.notifyMatchFailure(
@@ -589,14 +597,17 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
-
- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct);
+
+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, basePtrStruct);
// Convert base pointer (ptr) to i64
- Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+ Value basePtrI64 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
- basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+ basePtrI64 =
+ addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
// convert base pointer (i64) to LLVM pointer type
basePtrLLVM =
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index c64699c12cf4a..3cbb39ee9b144 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -777,168 +777,176 @@ SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
return blockedOffsets;
}
-// Get strides as vector of integer for MemDesc.
+// Get strides as a vector of integers for MemDesc.
SmallVector<int64_t> MemDescType::getStrides() {
- SmallVector<int64_t> matrixShape(getShape().begin(),
- getShape().end());
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
- ArrayAttr strideAttr = getStridesAttr();
- SmallVector<int64_t> strides;
- for (Attribute attr : strideAttr.getValue()) {
- strides.push_back(cast<IntegerAttr>(attr).getInt());
- }
+ ArrayAttr strideAttr = getStridesAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
- llvm::dbgs() << "DEBUG: matrixShape = [";
- for (size_t i = 0; i < matrixShape.size(); ++i) {
- llvm::dbgs() << matrixShape[i];
- if (i < matrixShape.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ llvm::dbgs() << "DEBUG: matrixShape = [";
+ for (size_t i = 0; i < matrixShape.size(); ++i) {
+ llvm::dbgs() << matrixShape[i];
+ if (i < matrixShape.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
- llvm::dbgs() << "DEBUG: strides = [";
- for (size_t i = 0; i < strides.size(); ++i) {
- llvm::dbgs() << strides[i];
- if (i < strides.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ llvm::dbgs() << "DEBUG: strides = [";
+ for (size_t i = 0; i < strides.size(); ++i) {
+ llvm::dbgs() << strides[i];
+ if (i < strides.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ SmallVector<int64_t> innerBlkShape = getBlockSize();
+ llvm::dbgs() << "DEBUG: innerBlkShape = [";
+ for (size_t i = 0; i < innerBlkShape.size(); ++i) {
+ llvm::dbgs() << innerBlkShape[i];
+ if (i < innerBlkShape.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
- SmallVector<int64_t> innerBlkShape = getBlockSize();
- llvm::dbgs() << "DEBUG: innerBlkShape = [";
- for (size_t i = 0; i < innerBlkShape.size(); ++i) {
- llvm::dbgs() << innerBlkShape[i];
- if (i < innerBlkShape.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ if (innerBlkShape.empty())
+ return strides;
- if (innerBlkShape.empty())
- return strides;
+ SmallVector<int, 4> perm =
+ llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
- SmallVector<int, 4> perm = llvm::to_vector<4>(
- llvm::seq<int>(0, strides.size()));
- llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+ llvm::dbgs() << "DEBUG: perm = [";
+ for (size_t i = 0; i < perm.size(); ++i) {
+ llvm::dbgs() << perm[i];
+ if (i < perm.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
- llvm::dbgs() << "DEBUG: perm = [";
- for (size_t i = 0; i < perm.size(); ++i) {
- llvm::dbgs() << perm[i];
- if (i < perm.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
- assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+ SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
- SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
-
- llvm::dbgs() << "DEBUG: innerBlkStride = [";
- for (size_t i = 0; i < innerBlkStride.size(); ++i) {
- llvm::dbgs() << innerBlkStride[i];
- if (i < innerBlkStride.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
- // compute the original matrix shape using the stride info
- // and compute the number of blocks in each dimension
- // The shape of highest dim can't be derived from stride info,
- // but doesn't impact the stride computation for blocked layout.
- SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
- SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
- for (size_t i = 0; i < perm.size() - 1; ++i) {
- matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
- BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
- }
+ llvm::dbgs() << "DEBUG: innerBlkStride = [";
+ for (size_t i = 0; i < innerBlkStride.size(); ++i) {
+ llvm::dbgs() << innerBlkStride[i];
+ if (i < innerBlkStride.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ // compute the original matrix shape using the stride info
+ // and compute the number of blocks in each dimension
+ // The shape of highest dim can't be derived from stride info,
+ // but doesn't impact the stride computation for blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+ SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+ BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
- llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
- for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
- llvm::dbgs() << matrixShapeOrig[i];
- if (i < matrixShapeOrig.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
+ for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
+ llvm::dbgs() << matrixShapeOrig[i];
+ if (i < matrixShapeOrig.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
- llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
- for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
- llvm::dbgs() << BlkShapeOrig[i];
- if (i < BlkShapeOrig.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
+ for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
+ llvm::dbgs() << BlkShapeOrig[i];
+ if (i < BlkShapeOrig.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
- int64_t innerBlkSize = 1;
- for (auto s : innerBlkShape)
- innerBlkSize *= s;
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
- llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
+ llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
- SmallVector<int64_t> outerBlkStride(matrixShape.size());
- outerBlkStride[perm[0]] = innerBlkSize;
- for (size_t i = 0; i < perm.size() - 1; ++i) {
- outerBlkStride[perm[i + 1]] =
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
- }
-
- llvm::dbgs() << "DEBUG: outerBlkStride = [";
- for (size_t i = 0; i < outerBlkStride.size(); ++i) {
- llvm::dbgs() << outerBlkStride[i];
- if (i < outerBlkStride.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
- // combine the inner and outer strides
- SmallVector<int64_t> blockedStrides;
- blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
- blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+ }
- llvm::dbgs() << "DEBUG: blockedStrides = [";
- for (size_t i = 0; i < blockedStrides.size(); ++i) {
- llvm::dbgs() << blockedStrides[i];
- if (i < blockedStrides.size() - 1) llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ llvm::dbgs() << "DEBUG: outerBlkStride = [";
+ for (size_t i = 0; i < outerBlkStride.size(); ++i) {
+ llvm::dbgs() << outerBlkStride[i];
+ if (i < outerBlkStride.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+ llvm::dbgs() << "DEBUG: blockedStrides = [";
+ for (size_t i = 0; i < blockedStrides.size(); ++i) {
+ llvm::dbgs() << blockedStrides[i];
+ if (i < blockedStrides.size() - 1)
+ llvm::dbgs() << ", ";
+ }
+ llvm::dbgs() << "]\n";
- return blockedStrides;
- }
+ return blockedStrides;
+}
// Calculate the linear offset using the blocked offsets and stride
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
- ArrayRef<OpFoldResult> offsets) {
+ ArrayRef<OpFoldResult> offsets) {
SmallVector<int64_t> blockShape = getBlockSize();
SmallVector<int64_t> strides = getStrides();
-
+
LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
- llvm::interleaveComma(blockShape, llvm::dbgs());
- llvm::dbgs() << "], strides=[";
- llvm::interleaveComma(strides, llvm::dbgs());
- llvm::dbgs() << "]\n");
-
- if (!blockShape.empty()) {
- assert(offsets.size() == blockShape.size() &&
- "offsets and blockShape must have the same size");
- // say the original offset is [y, x], and the block shape is [By, Bx],
- // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
- SmallVector<OpFoldResult> blockedOffsets;
- SmallVector<OpFoldResult> divs, rems;
+ llvm::interleaveComma(blockShape, llvm::dbgs());
+ llvm::dbgs() << "], strides=[";
+ llvm::interleaveComma(strides, llvm::dbgs());
+ llvm::dbgs() << "]\n");
- for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
- divs.push_back(div(offset, block));
- rems.push_back(rem(offset, block));
- }
- blockedOffsets.append(divs.begin(), divs.end());
- blockedOffsets.append(rems.begin(), rems.end());
+ if (!blockShape.empty()) {
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ // say the original offset is [y, x], and the block shape is [By, Bx],
+ // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
- offsets = blockedOffsets;
- LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
- << offsets.size() << "\n");
+ offsets = blockedOffsets;
+ LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
+ << offsets.size() << "\n");
}
// Start with initial value as matrix descriptor's base offset.
Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
for (size_t i = 0; i < offsets.size(); ++i) {
- OpFoldResult mulResult = mul(offsets[i], strides[i]);
- Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
- linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
}
- LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
- << linearOffset << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
+ << linearOffset << "\n");
return linearOffset;
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 23e487787652d..c40d5a42fb6e5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1049,8 +1049,11 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
llvm::SmallVector<int64_t> staticOffsets;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ // Call the generated builder with all parameters (including optional ones as
+ // nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
+ /*subgroupBlockIO=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
@@ -1097,7 +1100,7 @@ LogicalResult StoreMatrixOp::verify() {
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
if (dataShape.size() != 1) {
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitOpError("data shape must not exceed mem_desc shape.");
}
return success();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 30d6274c9dccf..7b87f32b876fe 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,40 +1,41 @@
// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
gpu.module @test_kernel {
+
+  // e.g. for mem_desc<32x32xf32> with the default @strides=[32, 1]
+  // its memory layout tuple is (blocked shape = [1,1,32,32], strides = [1024,1024,32,1])
//CHECK-LABEL: load_store_matrix_1
- gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+ gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<1xf32> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf32>
%tid_x = gpu.thread_id x
%c0 = arith.constant 0 : index
- %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
- gpu.return %1: vector<8xf32>
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<1xf32>
+ gpu.return %1: vector<1xf32>
}
- // e.g. for mem_desc<32x32xf16, @block=[16, 16], @strides=[1, 16]>
- // its memory layout tuple is ([2,2,16,16],[256,512,1,16])
-
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+ // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_2
- gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
- %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<stride = [1, 16], block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+ gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
%tid_x = gpu.thread_id x
- %c0 = arith.constant 0 : index
- %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
- gpu.return %1: vector<8xf32>
+ %c13 = arith.constant 13 : index
+ %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<1xf16>
+ gpu.return %1: vector<1xf16>
}
- // e.g. for mem_desc<32x32xf16, @block=[16, 16]>
- // its memory layout tuple is ([2,2,16,16],[512,256,16,1])
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+ // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_3
- gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
- %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<block = [16, 16]>>
- //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32>
- //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+ gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
%tid_x = gpu.thread_id x
- %c0 = arith.constant 0 : index
- %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
- gpu.return %1: vector<8xf32>
+ %c17 = arith.constant 17 : index
+ %1 = xegpu.load_matrix %0[%c17, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<1xf16>
+ gpu.return %1: vector<1xf16>
}
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index bb379024a34d7..47aa05763ee99 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -825,35 +825,76 @@ gpu.func @create_mem_desc_with_stride() {
gpu.return
}
-// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
gpu.return
}
-// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
gpu.return
}
+// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>)
+gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+ %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
+gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroupBlockIO}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 16] {subgroupBlockIO}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ gpu.return
+}
-// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 8]{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
gpu.return
}
-// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
// CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
+// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16>
+ xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
+gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] {vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 8] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
>From 446b951f2ed0bffd8be64955b7c4e5a94d5e2eb7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 8 Oct 2025 22:49:43 +0000
Subject: [PATCH 03/12] add tests and refactoring
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 19 +++-
.../lib/Conversion/XeGPUToXeVM/CMakeLists.txt | 1 +
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 80 ++++++++++++++--
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 12 ++-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 92 +++++++++++++------
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 4 +-
.../Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
mlir/test/Conversion/XeGPUToXeVM/dpas.mlir | 2 +-
.../XeGPUToXeVM/loadstore_matrix.mlir | 54 ++++++++---
mlir/test/Dialect/XeGPU/ops.mlir | 22 ++---
10 files changed, 211 insertions(+), 77 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a0a8669baf90d..044a8ef22d891 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1304,10 +1304,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<I32Attr>:$vec_length,
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
- OptionalAttr<UnitAttr>:$subgroupBlockIO,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
- let results = (outs XeGPU_ValueType:$res);
+ let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1338,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}
ArrayRef<int64_t> getDataShape() {
- return getRes().getType().getShape();
+ auto resTy = getRes().getType();
+ if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+ return vecTy.getShape();
+ return {};
}
}];
@@ -1348,10 +1351,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
- XeGPU_ValueType:$data,
+ AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<I32Attr>:$vec_length,
+ OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+ OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1379,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}
ArrayRef<int64_t> getDataShape() {
- return getData().getType().getShape();
+    auto dataTy = getData().getType();
+    if (auto vecTy = llvm::dyn_cast<VectorType>(dataTy))
+ return vecTy.getShape();
+ return {};
}
}];
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b25809f1ed0..dd9edc43a1657 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 97deca167204a..f4f0a46c54089 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/FormatVariadic.h"
@@ -371,8 +372,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
Value baseAddr, Value offset, int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
rewriter, loc, rewriter.getI64Type(), elemByteSize);
- offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
- offset);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -583,6 +582,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
else
data = adaptor.getData();
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+ if (!valOrResVecTy)
+ valOrResVecTy = VectorType::get(1, data.getType());
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -606,6 +607,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+ linearOffset = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI64Type(), linearOffset);
basePtrI64 =
addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
@@ -613,15 +616,72 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
basePtrLLVM =
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-
- Value loadOp =
- LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
- rewriter.replaceOp(op, loadOp);
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we need
+ // to handle this case separately.
+ if (valOrResVecTy.getNumElements() == 1) {
+ Type scalarTy = valOrResVecTy.getElementType();
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+ basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ return success();
} else {
- auto storeOp =
- LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
- rewriter.eraseOp(op);
+ // if the attribute 'subgroup_block_io' is set to true, it lowers to
+ // xevm.blockload
+ auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
+ bool subgroup_block_io =
+ subgroupBlockIoAttr && cast<BoolAttr>(subgroupBlockIoAttr).getValue();
+ if (subgroup_block_io) {
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy,
+ basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM,
+ adaptor.getData(), nullptr);
+ rewriter.eraseOp(op);
+ }
+ } else {
+        // For a 1D result, the vector direction must match the memory
+        // descriptor's majorness: COL requires a column-major mem_desc,
+        // ROW a row-major one.
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
+ }
+ xegpu::MatrixAccessDirectionAttr vecDirection =
+ op.getVecDirectionAttr();
+ if (vecDirection &&
+ vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
+ !mdescTy.isColMajor())
+ return rewriter.notifyMatchFailure(
+ op, "mem_desc should be column major when "
+ "vec_direction is COLUMN for 1D result.");
+ if (vecDirection &&
+ vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
+ mdescTy.isColMajor())
+ return rewriter.notifyMatchFailure(
+ op, "mem_desc should be row major when "
+ "vec_direction is ROW for 1D result.");
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+ basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
+ }
}
return success();
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3cbb39ee9b144..26f2f691ab860 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -813,9 +813,8 @@ SmallVector<int64_t> MemDescType::getStrides() {
}
llvm::dbgs() << "]\n";
- if (innerBlkShape.empty())
- return strides;
-
+  // Get the permutation from the fastest- to the slowest-varying dim:
+  // perm[i] is the dim with the i-th smallest stride.
SmallVector<int, 4> perm =
llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
@@ -908,6 +907,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> offsets) {
+ SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
SmallVector<int64_t> blockShape = getBlockSize();
SmallVector<int64_t> strides = getStrides();
@@ -917,7 +917,11 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
llvm::interleaveComma(strides, llvm::dbgs());
llvm::dbgs() << "]\n");
- if (!blockShape.empty()) {
+ // blockshape equal to matrixshape means no blocking
+ if (llvm::equal(blockShape, matrixShape)) {
+ // remove the outer dims from strides
+ strides.erase(strides.begin(), strides.begin() + matrixShape.size());
+ } else {
assert(offsets.size() == blockShape.size() &&
"offsets and blockShape must have the same size");
// say the original offset is [y, x], and the block shape is [By, Bx],
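The hunk above changes getLinearOffsets so that blocking is skipped when the block shape equals the matrix shape; otherwise a [y, x] offset is split into block indices and intra-block remainders and linearized against the blocked strides. A minimal standalone C++ sketch of that index arithmetic, using the row-major blocked layout from the tests (the function name and constants here are illustrative only, not the MLIR builder code):

#include <array>
#include <cassert>
#include <cstdint>

// Decompose [y, x] into [y/By, x/Bx, y%By, x%Bx] and linearize against the
// blocked strides, mirroring the arith ops that getLinearOffsets emits.
static int64_t linearizeBlocked(int64_t y, int64_t x,
                                std::array<int64_t, 2> block,     // [By, Bx]
                                std::array<int64_t, 4> strides) { // blocked strides
  std::array<int64_t, 4> blocked = {y / block[0], x / block[1],
                                    y % block[0], x % block[1]};
  int64_t linear = 0;
  for (size_t i = 0; i < blocked.size(); ++i)
    linear += blocked[i] * strides[i];
  return linear;
}

int main() {
  // mem_desc<32x64xf16, block=[16,16]> (row major): blocked shape [2,4,16,16],
  // strides [1024,256,16,1], as quoted in the test comments.
  assert(linearizeBlocked(19, 0, {16, 16}, {1024, 256, 16, 1}) == 1072);
  return 0;
}

For offsets [19, tid_x] with tid_x = 0 this yields element 1072, which is the value the generated divsi/remsi/muli/addi chain computes at runtime.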
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index c40d5a42fb6e5..0bc7b3f06ec53 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -173,6 +173,51 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
return success();
}
+LogicalResult IsValidStoreMatrixParams(
+ VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
+ MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!dataTy)
+ if (subgroup_block_io || vecDirection || vecLength)
+ return emitError() << "vec_length, vec_direction and subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ else
+ return success();
+
+ if (mdescTy.getRank() != 2)
+ return emitError() << "mem_desc must be 2D.";
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+
+ if (dataShape.size() == 2) {
+ if (subgroup_block_io || vecDirection || vecLength)
+ return emitError() << "vec_length, vec_direction and subgroup_block_io "
+ "are only allowed when result is a 1D VectorType.";
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitError() << "data shape must not exceed mem_desc shape.";
+ } else if (dataShape.size() == 1) {
+
+ SmallVector<int64_t> blockSize = mdescTy.getBlockSize();
+ // if the subgroup_block_io attribute is set, mdescTy must have block
+ // attribute
+ if (subgroup_block_io && !blockSize.size())
+ return emitError() << "mem_desc must have block attribute when "
+ "subgroup_block_io is set.";
+ // if the subgroup_block_io attribute is set, the memdesc should be row
+ // major
+ if (subgroup_block_io && mdescTy.isColMajor())
+ return emitError() << "mem_desc should be row major when "
+ "subgroup_block_io is set.";
+ } else if (dataShape.size() == 0) {
+ return emitError() << "result shape must not be empty.";
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -1053,25 +1098,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
// nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
/*vec_length=*/nullptr, /*vec_direction=*/nullptr,
- /*subgroupBlockIO=*/nullptr, layout);
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult LoadMatrixOp::verify() {
- VectorType resTy = getRes().getType();
- MemDescType mdescTy = getMemDesc().getType();
-
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
- ArrayRef<int64_t> valueShape = resTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ auto resTy = dyn_cast<VectorType>(getRes().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
+ IntegerAttr vecLength = getVecLengthAttr();
+ MemDescType mdescTy = getMemDesc().getType();
- if (valueShape.size() != 1) {
- if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed mem_desc shape.");
- }
- return success();
+ return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
+ vecDirection, vecLength,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1086,24 +1126,20 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
- layout);
+ /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
+ /*subgroup_block_io=*/nullptr, layout);
}
LogicalResult StoreMatrixOp::verify() {
- VectorType dataTy = getData().getType();
- MemDescType mdescTy = getMemDesc().getType();
- if (mdescTy.getRank() != 2)
- return emitOpError("mem_desc must be 2D.");
-
- ArrayRef<int64_t> dataShape = dataTy.getShape();
- ArrayRef<int64_t> mdescShape = mdescTy.getShape();
- if (dataShape.size() != 1) {
- if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("data shape must not exceed mem_desc shape.");
- }
- return success();
+ auto dataTy = dyn_cast<VectorType>(getData().getType());
+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+ MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
+ IntegerAttr vecLength = getVecLengthAttr();
+ MemDescType mdescTy = getMemDesc().getType();
+ return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
+ vecDirection, vecLength,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..6d17b27849a43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -941,7 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- VectorType valueTy = op.getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
return failure();
@@ -984,7 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
return failure();
Location loc = op.getLoc();
- VectorType valueTy = op.getData().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
ArrayRef<int64_t> shape = valueTy.getShape();
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..d57289a6b21e9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -867,7 +867,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
return failure();
ArrayRef<int64_t> wgShape = op.getDataShape();
- VectorType valueTy = op.getRes().getType();
+ VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
Type elemTy = valueTy.getElementType();
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index bbf313bf4fb60..a9ab0be00722c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -7,7 +7,7 @@ gpu.module @test_kernel {
// Loads are checked in a separate test.
// CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
// CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
- %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+ %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded
: vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
return %d : vector<8xf32>
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 7b87f32b876fe..372f477219817 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,41 +1,65 @@
// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
// e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
//CHECK-LABEL: load_store_matrix_1
- gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<1xf32> {
+ gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf32>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
%tid_x = gpu.thread_id x
%c0 = arith.constant 0 : index
- %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<1xf32>
- gpu.return %1: vector<1xf32>
+ %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+ gpu.return %1: f32
}
- // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_2
- gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+ gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
%tid_x = gpu.thread_id x
%c13 = arith.constant 13 : index
- %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<1xf16>
- gpu.return %1: vector<1xf16>
+ %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+ gpu.return %1: f16
}
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_3
- gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+ gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+ %tid_x = gpu.thread_id x
+ %c19 = arith.constant 19: index
+ %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+ gpu.return %1: f16
+ }
+
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+ // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
+ //CHECK-LABEL: load_store_matrix_4
+ gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
%tid_x = gpu.thread_id x
- %c17 = arith.constant 17 : index
- %1 = xegpu.load_matrix %0[%c17, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<1xf16>
- gpu.return %1: vector<1xf16>
+ %c16 = arith.constant 16 : index
+ %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+ gpu.return %1: vector<8xf16>
+ }
+
+ // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+ // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
+ //CHECK-LABEL: load_store_matrix_5
+ gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+ %c16 = arith.constant 16 : index
+ %c48 = arith.constant 48 : index
+ %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+ gpu.return %1: vector<8xf16>
}
}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 47aa05763ee99..eb5d653be8b9c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -846,17 +846,17 @@ gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
gpu.return
}
-// CHECK: gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
-gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
- // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroupBlockIO}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
- %data = xegpu.load_matrix %arg0[8, 16] {subgroupBlockIO}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
+gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
gpu.return
}
// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
// CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
- %data = xegpu.load_matrix %arg0[8, 8]{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
gpu.return
}
@@ -881,17 +881,17 @@ gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf
gpu.return
}
-// CHECK: gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
-gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
- // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- xegpu.store_matrix %arg1, %arg0[8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
+gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
gpu.return
}
// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
- // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] {vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
- xegpu.store_matrix %arg1, %arg0[8, 8] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
>From 9f9744cecbd30fea7b63c47768b323879222d105 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 9 Oct 2025 02:00:49 +0000
Subject: [PATCH 04/12] bug fixes
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 43 +++++----
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 92 +------------------
.../XeGPUToXeVM/loadstore_matrix.mlir | 2 +-
mlir/test/Dialect/XeGPU/invalid.mlir | 2 +-
4 files changed, 30 insertions(+), 109 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index f4f0a46c54089..67e8246e5536a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -33,8 +33,6 @@
#include <numeric>
-#define DEBUG_TYPE "xegpu-to-xevm"
-
namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -519,9 +517,6 @@ class CreateMemDescOpPattern final
LogicalResult
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- // DEBUG: Print operation and types
- LLVM_DEBUG(llvm::dbgs()
- << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
TypedValue<MemRefType> src = op.getSource();
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
@@ -529,19 +524,10 @@ class CreateMemDescOpPattern final
// memory space
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
- LLVM_DEBUG(llvm::dbgs()
- << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
- LLVM_DEBUG(llvm::dbgs()
- << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
- LLVM_DEBUG(llvm::dbgs()
- << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
Value(src), zero, ValueRange());
rewriter.replaceOp(op, viewOp);
- LLVM_DEBUG(
- llvm::dbgs()
- << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
return success();
}
};
@@ -635,16 +621,33 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
// if the attribute 'subgroup_block_io' is set to true, it lowers to
// xevm.blockload
auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
- bool subgroup_block_io =
- subgroupBlockIoAttr && cast<BoolAttr>(subgroupBlockIoAttr).getValue();
+ bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
+
+      // BlockLoadOp only supports integer element types, so bitcast to/from
+      // an integer type with matching bit width.
+ Type elemTy = valOrResVecTy.getElementType();
+ int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
+ Type intElemTy = rewriter.getIntegerType(bitWidth);
+ VectorType intVecTy =
+ VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
if (subgroup_block_io) {
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
- Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy,
- basePtrLLVM);
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+ }
rewriter.replaceOp(op, loadOp);
} else {
- xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM,
- adaptor.getData(), nullptr);
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+ }
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
rewriter.eraseOp(op);
}
} else {
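Because xevm.blockload and xevm.blockstore only accept integer element types, the pattern above round-trips the data through a same-width integer vector with a vector.bitcast in each direction. The reinterpretation is lossless; a tiny standalone C++ analogue using 32-bit float/int, since standard C++ has no portable 16-bit float type (purely illustrative of the bit-pattern preservation, not the actual vector<8xf16>/vector<8xi16> pair emitted here):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  float f = 1.5f;
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));    // analogue of loading as integers
  float back;
  std::memcpy(&back, &bits, sizeof(back)); // analogue of bitcasting back
  assert(back == f);                       // the bit pattern is unchanged
  return 0;
}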
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 26f2f691ab860..cccc8fab4adbc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,8 +37,6 @@ void XeGPUDialect::initialize() {
>();
}
-#define DEBUG_TYPE "xegpu"
-
/// Generates instructions to compute offsets for a subgroup identified by
/// its multidimensional indices (sgId), using the specified subgroup layout
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
@@ -788,30 +786,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
strides.push_back(cast<IntegerAttr>(attr).getInt());
}
- llvm::dbgs() << "DEBUG: matrixShape = [";
- for (size_t i = 0; i < matrixShape.size(); ++i) {
- llvm::dbgs() << matrixShape[i];
- if (i < matrixShape.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
- llvm::dbgs() << "DEBUG: strides = [";
- for (size_t i = 0; i < strides.size(); ++i) {
- llvm::dbgs() << strides[i];
- if (i < strides.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
SmallVector<int64_t> innerBlkShape = getBlockSize();
- llvm::dbgs() << "DEBUG: innerBlkShape = [";
- for (size_t i = 0; i < innerBlkShape.size(); ++i) {
- llvm::dbgs() << innerBlkShape[i];
- if (i < innerBlkShape.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
   // Get the permutation from the fastest- to the slowest-varying dim:
   // perm[i] is the dim with the i-th smallest stride.
@@ -819,25 +794,13 @@ SmallVector<int64_t> MemDescType::getStrides() {
llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
- llvm::dbgs() << "DEBUG: perm = [";
- for (size_t i = 0; i < perm.size(); ++i) {
- llvm::dbgs() << perm[i];
- if (i < perm.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
- SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
-
- llvm::dbgs() << "DEBUG: innerBlkStride = [";
- for (size_t i = 0; i < innerBlkStride.size(); ++i) {
- llvm::dbgs() << innerBlkStride[i];
- if (i < innerBlkStride.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
+ SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+ innerBlkStride[perm[0]] = 1;
+ for (size_t i = 1; i < perm.size(); ++i)
+ innerBlkStride[perm[i]] =
+ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
// compute the original matrix shape using the stride info
// and compute the number of blocks in each dimension
@@ -850,28 +813,10 @@ SmallVector<int64_t> MemDescType::getStrides() {
BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
}
- llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
- for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
- llvm::dbgs() << matrixShapeOrig[i];
- if (i < matrixShapeOrig.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
- llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
- for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
- llvm::dbgs() << BlkShapeOrig[i];
- if (i < BlkShapeOrig.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
int64_t innerBlkSize = 1;
for (auto s : innerBlkShape)
innerBlkSize *= s;
- llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
-
SmallVector<int64_t> outerBlkStride(matrixShape.size());
outerBlkStride[perm[0]] = innerBlkSize;
for (size_t i = 0; i < perm.size() - 1; ++i) {
@@ -879,27 +824,11 @@ SmallVector<int64_t> MemDescType::getStrides() {
outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
}
- llvm::dbgs() << "DEBUG: outerBlkStride = [";
- for (size_t i = 0; i < outerBlkStride.size(); ++i) {
- llvm::dbgs() << outerBlkStride[i];
- if (i < outerBlkStride.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
// combine the inner and outer strides
SmallVector<int64_t> blockedStrides;
blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
- llvm::dbgs() << "DEBUG: blockedStrides = [";
- for (size_t i = 0; i < blockedStrides.size(); ++i) {
- llvm::dbgs() << blockedStrides[i];
- if (i < blockedStrides.size() - 1)
- llvm::dbgs() << ", ";
- }
- llvm::dbgs() << "]\n";
-
return blockedStrides;
}
@@ -911,12 +840,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
SmallVector<int64_t> blockShape = getBlockSize();
SmallVector<int64_t> strides = getStrides();
- LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
- llvm::interleaveComma(blockShape, llvm::dbgs());
- llvm::dbgs() << "], strides=[";
- llvm::interleaveComma(strides, llvm::dbgs());
- llvm::dbgs() << "]\n");
-
// blockshape equal to matrixshape means no blocking
if (llvm::equal(blockShape, matrixShape)) {
// remove the outer dims from strides
@@ -937,8 +860,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
blockedOffsets.append(rems.begin(), rems.end());
offsets = blockedOffsets;
- LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
- << offsets.size() << "\n");
}
// Start with initial value as matrix descriptor's base offset.
@@ -949,9 +870,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
}
- LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
- << linearOffset << "\n");
-
return linearOffset;
}
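The de-noised getStrides above builds the blocked strides from a permutation of the dimensions ordered by increasing stride: inner (within-block) strides first, starting at 1, then outer (block-grid) strides starting at the block size. A small standalone C++ sketch that reproduces the stride tuples quoted in the loadstore_matrix.mlir comments (the helper name and the direct use of shape/block, instead of recovering the original shape from the strides, are simplifications and not the MLIR implementation):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Strides of the blocked layout [outer dims..., inner dims...] for a matrix
// with the given shape, element strides, and inner block shape.
static std::vector<int64_t> blockedStrides(std::vector<int64_t> shape,
                                           std::vector<int64_t> strides,
                                           std::vector<int64_t> block) {
  size_t rank = shape.size();
  // perm[i] = dimension with the i-th smallest stride (fastest varying first).
  std::vector<int> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](int a, int b) { return strides[a] < strides[b]; });

  // Inner strides: contiguous within one block, same dimension order.
  std::vector<int64_t> inner(rank);
  inner[perm[0]] = 1;
  for (size_t i = 1; i < rank; ++i)
    inner[perm[i]] = inner[perm[i - 1]] * block[perm[i - 1]];

  // Outer strides: whole blocks, starting at the block size.
  int64_t blockSize = 1;
  for (int64_t b : block)
    blockSize *= b;
  std::vector<int64_t> numBlocks(rank), outer(rank);
  for (size_t i = 0; i < rank; ++i)
    numBlocks[i] = shape[i] / block[i];
  outer[perm[0]] = blockSize;
  for (size_t i = 0; i + 1 < rank; ++i)
    outer[perm[i + 1]] = outer[perm[i]] * numBlocks[perm[i]];

  std::vector<int64_t> result = outer;
  result.insert(result.end(), inner.begin(), inner.end());
  return result;
}

int main() {
  // mem_desc<32x64xf16, block=[16,16]> (row major): expect 1024 256 16 1.
  for (int64_t s : blockedStrides({32, 64}, {64, 1}, {16, 16}))
    std::printf("%lld ", (long long)s);
  std::printf("\n");
  // mem_desc<32x64xf16, stride=[1,32], block=[16,16]>: expect 256 512 1 16.
  for (int64_t s : blockedStrides({32, 64}, {1, 32}, {16, 16}))
    std::printf("%lld ", (long long)s);
  std::printf("\n");
  return 0;
}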
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 372f477219817..3713635a1cc71 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -55,7 +55,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK-LABEL: load_store_matrix_5
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+ //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16>
%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 228ef69d9a478..bef45438c944e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>
// -----
func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{result shape must not exceed mem_desc shape}}
+ // expected-error at +1 {{data shape must not exceed mem_desc shape}}
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
return
}
>From bbd43d089096c8c66507c59c7df0c42d2806bcc0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 10 Oct 2025 00:20:47 +0000
Subject: [PATCH 05/12] polish tests
---
.../XeGPUToXeVM/loadstore_matrix.mlir | 154 +++++++++++++++++-
mlir/test/Dialect/XeGPU/invalid.mlir | 29 ++++
2 files changed, 174 insertions(+), 9 deletions(-)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 3713635a1cc71..6302758195e51 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,64 +1,200 @@
-// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
- // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
+ // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
// its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
//CHECK-LABEL: load_store_matrix_1
gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+
+ //CHECK: %[[TID:.*]] = gpu.thread_id x
+ //CHECK: %[[C1:.*]] = arith.constant 1 : index
+ //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i64
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
%tid_x = gpu.thread_id x
%c0 = arith.constant 0 : index
%1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+ //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
gpu.return %1: f32
}
- // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_2
gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK: %[[c13:.*]] = arith.constant 13 : index
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
+
+
%tid_x = gpu.thread_id x
%c13 = arith.constant 13 : index
%1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+
+ xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
gpu.return %1: f16
}
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_3
gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+ //CHECK: %[[c19:.*]] = arith.constant 19 : index
%tid_x = gpu.thread_id x
%c19 = arith.constant 19: index
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
%1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+ xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
+ //CHECK: gpu.return %[[loaded]] : f16
gpu.return %1: f16
}
-
-  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
// its memory layout tuple is ([2,4,16,16],[256,512,1,16])
//CHECK-LABEL: load_store_matrix_4
gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
- //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+ //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c512:.*]] = arith.constant 512 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+ //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+ //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
+
%tid_x = gpu.thread_id x
%c16 = arith.constant 16 : index
%1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+
+ //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
+ xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
gpu.return %1: vector<8xf16>
}
+
// e.g. for mem_desc<32x64xf16, @block=[16, 16]>
// its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
//CHECK-LABEL: load_store_matrix_5
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+ //CHECK: %[[c0:.*]] = arith.constant 0 : index
+ //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
- //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16>
+
+ //CHECK: %[[c16:.*]] = arith.constant 16 : index
+ //CHECK: %[[c48:.*]] = arith.constant 48 : index
+
%c16 = arith.constant 16 : index
%c48 = arith.constant 48 : index
+
+ //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+ //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+ //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+ //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+ //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+ //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+ //CHECK: %[[c256:.*]] = arith.constant 256 : index
+ //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+ //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+ //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+ //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+ //CHECK: %[[c1:.*]] = arith.constant 1 : index
+ //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+ //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+ //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
+ //CHECK: %[[c2:.*]] = arith.constant 2 : i64
+ //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
+ //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+ //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+
+ //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
+ //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+
+ xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
gpu.return %1: vector<8xf16>
}
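The CHECK sequence in load_store_matrix_5 spells out the full address computation for the constant offsets [16, 48]: blocked indices multiplied by the [1024, 256, 16, 1] strides, an index_castui to i64, and a multiply by the 2-byte f16 element size before the inttoptr. A short C++ check of that arithmetic (constants copied from the test above; nothing here is MLIR API):

#include <cassert>
#include <cstdint>

int main() {
  // Offsets [16, 48] into mem_desc<32x64xf16, block=[16,16]> (row major).
  int64_t blocked[4] = {16 / 16, 48 / 16, 16 % 16, 48 % 16}; // [1, 3, 0, 0]
  int64_t strides[4] = {1024, 256, 16, 1};
  int64_t linear = 0;
  for (int i = 0; i < 4; ++i)
    linear += blocked[i] * strides[i];
  assert(linear == 1792);     // element offset before the index_castui
  assert(linear * 2 == 3584); // byte offset added to the base pointer (f16)
  return 0;
}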
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index bef45438c944e..fee3136195e1d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -870,6 +870,21 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
return
}
+// -----
+func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
+ return
+}
+
+// -----
+func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
+ return
+}
+
+
// -----
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
// expected-error at +1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -891,6 +906,20 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
return
}
+// -----
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+ // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+
// -----
func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
// expected-error at +1 {{result shape must not exceed source shape}}
>From 034476186425dde826929584ae36f95fa7263fd8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 10 Oct 2025 06:19:21 +0000
Subject: [PATCH 06/12] fix minor issues
---
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 +++++--------
1 file changed, 5 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 67e8246e5536a..05f26354e5a2a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -61,10 +61,8 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
return static_cast<int>(xevm::AddrSpace::GLOBAL);
case xegpu::MemorySpace::SLM:
return static_cast<int>(xevm::AddrSpace::SHARED);
- default:
- llvm_unreachable("Unknown XeGPU memory space");
- return static_cast<int>(xevm::AddrSpace::GLOBAL);
}
+ llvm_unreachable("Unknown XeGPU memory space");
}
// Get same bitwidth flat vector type of new element type.
@@ -186,8 +184,9 @@ class CreateNdDescToXeVMPattern
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
- if (rank != 2)
+ if (rank != 2) {
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+ }
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -612,8 +611,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
rewriter.replaceOp(op, loadOp);
} else {
- auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
- basePtrLLVM);
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
rewriter.eraseOp(op);
}
return success();
@@ -680,8 +678,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
rewriter.replaceOp(op, loadOp);
} else {
- auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
- basePtrLLVM);
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
rewriter.eraseOp(op);
}
}
>From 966525b19652cb75c20722dfa9c22bb74d43a87b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Oct 2025 18:04:56 +0000
Subject: [PATCH 07/12] remove vector direction and length attributes
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 18 ------------
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 4 ---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 ++----------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 29 +++++++------------
.../XeGPUToXeVM/loadstore_matrix.mlir | 4 +--
mlir/test/Dialect/XeGPU/invalid.mlir | 13 ++-------
mlir/test/Dialect/XeGPU/ops.mlir | 8 ++---
7 files changed, 22 insertions(+), 72 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 601e966b49890..2efd575a652db 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -724,22 +724,4 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
}
-def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
-def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
-def MatrixAccessDirection :
- I32EnumAttr<"MatrixAccessDirection",
- "Matrix elements/vectors can have row or column direction", [
- RowOriented, ColOriented
-]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::xegpu";
-}
-def MatrixAccessDirectionAttr :
- EnumAttr<XeGPU_Dialect,
- MatrixAccessDirection,
- "matrix_access_direction">{
- let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
- let assemblyFormat = "`<` $value `>`";
-}
-
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 044a8ef22d891..f41f9e887cff7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1302,8 +1302,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
- OptionalAttr<I32Attr>:$vec_length,
- OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
@@ -1355,8 +1353,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
- OptionalAttr<I32Attr>:$vec_length,
- OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 05f26354e5a2a..2ff2c98d291d2 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -184,9 +184,9 @@ class CreateNdDescToXeVMPattern
SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
// Descriptor shape is expected to be 2D.
int64_t rank = mixedSizes.size();
- if (rank != 2) {
+ if (rank != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
- }
+
auto sourceTy = source.getType();
auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// If source is a memref, we need to extract the aligned pointer as index.
@@ -658,20 +658,6 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
return rewriter.notifyMatchFailure(
op, "The lowering is specific to pvc or bmg.");
}
- xegpu::MatrixAccessDirectionAttr vecDirection =
- op.getVecDirectionAttr();
- if (vecDirection &&
- vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
- !mdescTy.isColMajor())
- return rewriter.notifyMatchFailure(
- op, "mem_desc should be column major when "
- "vec_direction is COLUMN for 1D result.");
- if (vecDirection &&
- vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
- mdescTy.isColMajor())
- return rewriter.notifyMatchFailure(
- op, "mem_desc should be row major when "
- "vec_direction is ROW for 1D result.");
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
Value loadOp =
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0bc7b3f06ec53..8d86e78fcbf4f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -173,17 +173,18 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
return success();
}
-LogicalResult IsValidStoreMatrixParams(
- VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
- MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
- function_ref<InFlightDiagnostic()> emitError) {
-
- if (!dataTy)
- if (subgroup_block_io || vecDirection || vecLength)
- return emitError() << "vec_length, vec_direction and subgroup_block_io "
+LogicalResult
+IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
+ UnitAttr subgroup_block_io,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!dataTy) {
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
"are only allowed when result is a 1D VectorType.";
else
return success();
+ }
if (mdescTy.getRank() != 2)
return emitError() << "mem_desc must be 2D.";
@@ -192,8 +193,8 @@ LogicalResult IsValidStoreMatrixParams(
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
if (dataShape.size() == 2) {
- if (subgroup_block_io || vecDirection || vecLength)
- return emitError() << "vec_length, vec_direction and subgroup_block_io "
+ if (subgroup_block_io)
+ return emitError() << "subgroup_block_io "
"are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -1097,7 +1098,6 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
// Call the generated builder with all parameters (including optional ones as
// nullptr/empty)
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
- /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
/*subgroup_block_io=*/nullptr, layout);
}
@@ -1105,12 +1105,9 @@ LogicalResult LoadMatrixOp::verify() {
auto resTy = dyn_cast<VectorType>(getRes().getType());
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
- MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
- IntegerAttr vecLength = getVecLengthAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
- vecDirection, vecLength,
[&]() { return emitError(); });
}
@@ -1126,7 +1123,6 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
- /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
/*subgroup_block_io=*/nullptr, layout);
}
@@ -1134,11 +1130,8 @@ LogicalResult StoreMatrixOp::verify() {
auto dataTy = dyn_cast<VectorType>(getData().getType());
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
- MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
- IntegerAttr vecLength = getVecLengthAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
- vecDirection, vecLength,
[&]() { return emitError(); });
}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 6302758195e51..ebb3c2b2b6a83 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -139,10 +139,10 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%tid_x = gpu.thread_id x
%c16 = arith.constant 16 : index
- %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+ %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
//CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
- xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+ xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
gpu.return %1: vector<8xf16>
}
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index fee3136195e1d..6062eba709b88 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -870,16 +870,9 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
return
}
-// -----
-func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
- %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
- return
-}
-
// -----
func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
%data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
return
}
@@ -908,14 +901,14 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
// -----
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
// -----
func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+ // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
return
}
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index eb5d653be8b9c..f1f5f86d33bc0 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -855,8 +855,8 @@ gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #
// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
- // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
- %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+ %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
gpu.return
}
@@ -890,8 +890,8 @@ gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16,
// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
- // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
- xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
gpu.return
}
>From 272f51213290a1784ac8a44124fbb38b67c9b1c3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Oct 2025 23:15:41 +0000
Subject: [PATCH 08/12] address comments
---
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 26 ++++++++++++-------
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 8 +++---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 4 +--
.../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 3 +++
.../Transforms/XeGPUWgToSgDistribute.cpp | 1 +
.../XeGPUToXeVM/loadstore_matrix.mlir | 2 +-
6 files changed, 27 insertions(+), 17 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index c261fbb576642..99526159cd2e7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -242,7 +242,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
}
-
// derive and return default strides
SmallVector<int64_t> defaultStrides;
llvm::append_range(defaultStrides, getShape().drop_front());
@@ -251,6 +250,15 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return builder.getI64ArrayAttr(defaultStrides);
}
+ ArrayAttr getBlockAttr() {
+ auto layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ return layout.getBlockAttr();
+ }
+ Builder builder(getContext());
+ return builder.getI64ArrayAttr({});
+ }
+
/// Heuristic to determine if the MemDesc uses column-major layout,
/// based on the rank and the value of the first stride dimension.
bool isColMajor() {
@@ -261,16 +269,14 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
// get the Blocking shape for a MemDescType, Which is represented
// as an attribute in MemDescType. By default it is the shape
// of the mdescTy
- SmallVector<int64_t> getBlockSize() {
+ SmallVector<int64_t> getBlockShape() {
SmallVector<int64_t> size(getShape());
- MemLayoutAttr layout = getMemLayout();
- if (layout && layout.hasAttr("block")) {
- ArrayAttr attr = layout.getBlockAttr();
+ ArrayAttr blockAttr = getBlockAttr();
+ if (!blockAttr.empty()) {
size.clear();
- llvm::for_each(attr, [&](Attribute elem) {
- if (auto intElem = dyn_cast<IntegerAttr>(elem))
- size.push_back(intElem.getInt());
- });
+ for (auto attr : blockAttr.getValue()) {
+ size.push_back(cast<IntegerAttr>(attr).getInt());
+ }
}
return size;
}
@@ -289,7 +295,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
// for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
- SmallVector<int64_t> getStrides();
+ SmallVector<int64_t> getStrideShape();
/// Generates instructions to compute the linearize offset
// if the memory descriptor is blocked, it returns linearize offset based on the blocked layout
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index cccc8fab4adbc..78eee0102ba85 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -776,7 +776,7 @@ SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
}
// Get strides as vector of integer for MemDesc.
-SmallVector<int64_t> MemDescType::getStrides() {
+SmallVector<int64_t> MemDescType::getStrideShape() {
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
@@ -786,7 +786,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
strides.push_back(cast<IntegerAttr>(attr).getInt());
}
- SmallVector<int64_t> innerBlkShape = getBlockSize();
+ SmallVector<int64_t> innerBlkShape = getBlockShape();
// get perm from FCD to LCD
// perm[i] = the dim with i-th smallest stride
@@ -837,8 +837,8 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> offsets) {
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
- SmallVector<int64_t> blockShape = getBlockSize();
- SmallVector<int64_t> strides = getStrides();
+ SmallVector<int64_t> blockShape = getBlockShape();
+ SmallVector<int64_t> strides = getStrideShape();
// blockshape equal to matrixshape means no blocking
if (llvm::equal(blockShape, matrixShape)) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8d86e78fcbf4f..8c7a686b8ce0d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -201,10 +201,10 @@ IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
return emitError() << "data shape must not exceed mem_desc shape.";
} else if (dataShape.size() == 1) {
- SmallVector<int64_t> blockSize = mdescTy.getBlockSize();
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
- if (subgroup_block_io && !blockSize.size())
+ if (subgroup_block_io && !blockShape.size())
return emitError() << "mem_desc must have block attribute when "
"subgroup_block_io is set.";
// if the subgroup_block_io attribute is set, the memdesc should be row
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 6d17b27849a43..aafa1b7deb84b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -942,6 +942,8 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+ assert(valueTy && "the value type must be vector type!");
+
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
return failure();
@@ -985,6 +987,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
Location loc = op.getLoc();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
+ assert(valueTy && "the value type must be vector type!");
ArrayRef<int64_t> shape = valueTy.getShape();
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index baee57c512ddf..31a967dcd04c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -992,6 +992,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
ArrayRef<int64_t> wgShape = op.getDataShape();
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
+ assert(valueTy && "the value type must be vector type!");
Type elemTy = valueTy.getElementType();
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ebb3c2b2b6a83..df1433e7b98ae 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -198,4 +198,4 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
gpu.return %1: vector<8xf16>
}
-}
\ No newline at end of file
+}
>From b1857a275d7e30a55ac9b17b335f61f556b2e695 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Oct 2025 01:05:09 +0000
Subject: [PATCH 09/12] address more comments
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 134 ++++++++----------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 14 +-
.../XeGPUToXeVM/loadstore_matrix.mlir | 18 +--
3 files changed, 77 insertions(+), 89 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 2ff2c98d291d2..e5e797a1fa1c3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -365,10 +365,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
// Add a builder that creates
// offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
- Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+ Location loc, Value baseAddr, Value offset,
+ int64_t elemByteSize) {
Value byteSize = arith::ConstantIntOp::create(
- rewriter, loc, rewriter.getI64Type(), elemByteSize);
+ rewriter, loc, baseAddr.getType(), elemByteSize);
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
return newAddr;
@@ -443,7 +444,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
// If offset is provided, we add them to the base pointer.
// Offset is in number of elements, we need to multiply by
// element byte size.
- basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+ basePtrI64 =
+ addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
}
// Convert base pointer (i64) to LLVM pointer type.
Value basePtrLLVM =
@@ -516,7 +518,7 @@ class CreateMemDescOpPattern final
LogicalResult
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- TypedValue<MemRefType> src = op.getSource();
+
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
// Create the result MemRefType with the same shape, element type, and
@@ -525,7 +527,7 @@ class CreateMemDescOpPattern final
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
- Value(src), zero, ValueRange());
+ op.getSource(), zero, ValueRange());
rewriter.replaceOp(op, viewOp);
return success();
}
@@ -587,88 +589,74 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
rewriter, loc, basePtrStruct);
- // Convert base pointer (ptr) to i64
- Value basePtrI64 = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+ // Convert base pointer (ptr) to i32
+ Value basePtrI32 = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
linearOffset = arith::IndexCastUIOp::create(
- rewriter, loc, rewriter.getI64Type(), linearOffset);
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+ rewriter, loc, rewriter.getI32Type(), linearOffset);
+ basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+ elemByteSize);
- // convert base pointer (i64) to LLVM pointer type
+ // convert base pointer (i32) to LLVM pointer type
basePtrLLVM =
- LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
- // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
- // operation. LLVM load/store does not support vector of size 1, so we need
- // to handle this case separately.
- if (valOrResVecTy.getNumElements() == 1) {
- Type scalarTy = valOrResVecTy.getElementType();
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
- Value loadOp =
- LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
- rewriter.replaceOp(op, loadOp);
- } else {
- LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
- rewriter.eraseOp(op);
- }
- return success();
- } else {
+ if (op.getSubgroupBlockIoAttr()) {
// if the attribute 'subgroup_block_io' is set to true, it lowers to
// xevm.blockload
- auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
- bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
-
- // BlockLoadOp only supports integer types, so we need to bitcast
- // Get integer type with matching bit width
- Type elemTy = valOrResVecTy.getElementType();
- int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
- Type intElemTy = rewriter.getIntegerType(bitWidth);
+
+ Type intElemTy = rewriter.getIntegerType(elemBitWidth);
VectorType intVecTy =
VectorType::get(valOrResVecTy.getShape(), intElemTy);
- if (subgroup_block_io) {
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
- Value loadOp =
- xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
- if (intVecTy != valOrResVecTy) {
- loadOp =
- vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
- }
- rewriter.replaceOp(op, loadOp);
- } else {
- Value dataToStore = adaptor.getData();
- if (valOrResVecTy != intVecTy) {
- dataToStore =
- vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
- }
- xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
- nullptr);
- rewriter.eraseOp(op);
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ Value loadOp =
+ xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+ if (intVecTy != valOrResVecTy) {
+ loadOp =
+ vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
}
+ rewriter.replaceOp(op, loadOp);
} else {
- // if the result is 1D vector, if the vector direction is Column, then
- // the
- // memory descriptor should be treated as column major
- auto chipOpt = xegpu::getChipStr(op);
- if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
- // the lowering only works for pvc and bmg
- return rewriter.notifyMatchFailure(
- op, "The lowering is specific to pvc or bmg.");
+ Value dataToStore = adaptor.getData();
+ if (valOrResVecTy != intVecTy) {
+ dataToStore =
+ vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
}
+ xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+ nullptr);
+ rewriter.eraseOp(op);
+ }
+ return success();
+ }
- if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
- Value loadOp =
- LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
- rewriter.replaceOp(op, loadOp);
- } else {
- LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
- rewriter.eraseOp(op);
- }
+ if (valOrResVecTy.getNumElements() >= 1) {
+ auto chipOpt = xegpu::getChipStr(op);
+ if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+ // the lowering for chunk load only works for pvc and bmg
+ return rewriter.notifyMatchFailure(
+ op, "The lowering is specific to pvc or bmg.");
}
}
+
+ if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+ // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+ // operation. LLVM load/store does not support vector of size 1, so we
+ // need to handle this case separately.
+ auto scalarTy = valOrResVecTy.getElementType();
+ LLVM::LoadOp loadOp;
+ if (valOrResVecTy.getNumElements() == 1)
+ loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+ else
+ loadOp =
+ LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+ rewriter.replaceOp(op, loadOp);
+ } else {
+ LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+ rewriter.eraseOp(op);
+ }
return success();
}
};
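To make the two code paths in the restructured LoadStoreMatrixToXeVMPattern concrete, here is a rough sketch of what the pattern emits, modeled on the CHECK lines in loadstore_matrix.mlir further below (the SSA names are illustrative, not taken from the tests): with subgroup_block_io the load becomes an integer block load plus a bitcast back to the element type; otherwise it becomes a plain llvm.load, degenerating to a scalar load when the result vector has a single element.

  // subgroup_block_io path
  %raw = xevm.blockload %ptr : (!llvm.ptr<3>) -> vector<8xi16>
  %val = vector.bitcast %raw : vector<8xi16> to vector<8xf16>
  // ordinary path
  %vec = llvm.load %ptr : !llvm.ptr<3> -> vector<8xf16>
  // single-element result
  %scl = llvm.load %ptr : !llvm.ptr<3> -> f32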
@@ -715,8 +703,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
op, "Expected element type bit width to be multiple of 8.");
elemByteSize = elemBitWidth / 8;
}
- basePtrI64 =
- addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+ basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+ elemByteSize);
}
}
// Default memory space is global.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8c7a686b8ce0d..7108afffe99d5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -174,9 +174,9 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
}
LogicalResult
-IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
- UnitAttr subgroup_block_io,
- function_ref<InFlightDiagnostic()> emitError) {
+IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
+ UnitAttr subgroup_block_io,
+ function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
@@ -1107,8 +1107,8 @@ LogicalResult LoadMatrixOp::verify() {
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
- return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1131,8 +1131,8 @@ LogicalResult StoreMatrixOp::verify() {
auto dataTy = dyn_cast<VectorType>(getData().getType());
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
- return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+ [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index df1433e7b98ae..d4cb493271d0d 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -11,8 +11,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[TID:.*]] = gpu.thread_id x
//CHECK: %[[C1:.*]] = arith.constant 1 : index
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
- //CHECK: %[[C4:.*]] = arith.constant 4 : i64
- //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
+ //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+ //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
%tid_x = gpu.thread_id x
@@ -80,7 +80,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%c19 = arith.constant 19: index
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[c16:.*]] = arith.constant 16 : index
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
@@ -164,7 +164,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
%c48 = arith.constant 48 : index
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
- //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+ //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
@@ -180,11 +180,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
//CHECK: %[[c1:.*]] = arith.constant 1 : index
//CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
//CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
- //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
- //CHECK: %[[c2:.*]] = arith.constant 2 : i64
- //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
- //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
- //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
+ //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
+ //CHECK: %[[c2:.*]] = arith.constant 2 : i32
+ //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
+ //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
+ //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
>From 7a63d93d076d8b90ff27e3e4f88b008780078f75 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Oct 2025 21:24:07 +0000
Subject: [PATCH 10/12] address more feedback
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 2 +-
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 37 -------------------
.../mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 6 +--
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 15 +-------
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 34 -----------------
mlir/test/Dialect/XeGPU/invalid.mlir | 28 --------------
mlir/test/Dialect/XeGPU/ops.mlir | 21 -----------
8 files changed, 6 insertions(+), 139 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 2efd575a652db..19a52317956d2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -712,7 +712,7 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().contains(name);
}
- ArrayAttr getStrides() {
+ ArrayAttr getStrideAttr() {
return getAttrs().getAs<ArrayAttr>("stride");
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index f41f9e887cff7..73b70da9642e4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1392,41 +1392,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
let hasVerifier = 1;
}
-def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
- [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
- let description = [{
- Creates a subview of a memory descriptor. The resulting memory descriptor can have
- a lower rank than the source; in this case, the result dimensions correspond to the
- higher-order dimensions of the source memory descriptor.
-
- Arguments:
- - `src` : a memory descriptor.
- - `offsets` : the coordinates within the matrix the subview will be created from.
-
- Results:
- - `res` : a memory descriptor with smaller size.
-
- }];
- let arguments = (ins XeGPU_MemDesc:$src,
- Variadic<Index>:$offsets,
- DenseI64ArrayAttr:$const_offsets);
- let results = (outs XeGPU_MemDesc:$res);
- let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
- attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
- let builders = [
- OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
- ];
-
- let extraClassDeclaration = [{
- mlir::Value getViewSource() { return getSrc(); }
-
- SmallVector<OpFoldResult> getMixedOffsets() {
- return getMixedValues(getConstOffsets(), getOffsets(), getContext());
- }
- }];
-
- let hasVerifier = 1;
-}
-
-
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 99526159cd2e7..024ca2023c811 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,10 +237,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}
- ArrayAttr getStridesAttr() {
+ ArrayAttr getStrideAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
- return layout.getStrides();
+ return layout.getStrideAttr();
}
// derive and return default strides
SmallVector<int64_t> defaultStrides;
@@ -262,7 +262,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
/// Heuristic to determine if the MemDesc uses column-major layout,
/// based on the rank and the value of the first stride dimension.
bool isColMajor() {
- auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+ auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
return getRank() == 2 && dim0 && dim0.getInt() == 1;
}
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fdd29dd96cd55..9cf963e101816 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -534,18 +534,6 @@ class CreateMemDescOpPattern final
}
};
-class MemDescSubviewOpPattern final
- : public OpConversionPattern<xegpu::MemDescSubviewOp> {
-public:
- using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
- LogicalResult
- matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- return rewriter.notifyMatchFailure(
- op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
- }
-};
-
template <typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -1085,8 +1073,7 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
typeConverter, patterns.getContext());
patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
- CreateMemDescOpPattern, MemDescSubviewOpPattern>(
- typeConverter, patterns.getContext());
+ CreateMemDescOpPattern>(typeConverter, patterns.getContext());
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
patterns.getContext());
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index dc880308e6b3a..1cfae28f31188 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -781,7 +781,7 @@ SmallVector<int64_t> MemDescType::getStrideShape() {
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
- ArrayAttr strideAttr = getStridesAttr();
+ ArrayAttr strideAttr = getStrideAttr();
SmallVector<int64_t> strides;
for (Attribute attr : strideAttr.getValue()) {
strides.push_back(cast<IntegerAttr>(attr).getInt());
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7108afffe99d5..f2d1b85fa1737 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1135,40 +1135,6 @@ LogicalResult StoreMatrixOp::verify() {
[&]() { return emitError(); });
}
-//===----------------------------------------------------------------------===//
-// XeGPU_MemDescSubviewOp
-//===----------------------------------------------------------------------===//
-
-void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
- Type resTy, Value src,
- llvm::ArrayRef<OpFoldResult> offsets) {
- llvm::SmallVector<Value> dynamicOffsets;
- llvm::SmallVector<int64_t> staticOffsets;
- dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
- auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
- build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
-}
-
-LogicalResult MemDescSubviewOp::verify() {
- MemDescType srcTy = getSrc().getType();
- MemDescType resTy = getRes().getType();
- ArrayRef<int64_t> srcShape = srcTy.getShape();
- ArrayRef<int64_t> resShape = resTy.getShape();
-
- if (srcTy.getRank() < resTy.getRank())
- return emitOpError("result rank must not exceed source rank.");
-
- if (llvm::any_of(
- llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
- [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
- return emitOpError("result shape must not exceed source shape.");
-
- if (srcTy.getStridesAttr() != resTy.getStridesAttr())
- return emitOpError("result must inherit the source strides.");
-
- return success();
-}
-
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 6062eba709b88..ebbe3ce0ec0d0 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -913,31 +913,3 @@ func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data
return
}
-// -----
-func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{result shape must not exceed source shape}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
- return
-}
-
-// -----
-func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
- // expected-error at +1 {{result must inherit the source strides}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
- return
-}
-
-// -----
-func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{failed to verify that all of {src, res} have same element type}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
- return
-}
-
-// -----
-func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{result rank must not exceed source rank}}
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
- return
-}
-
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index f1f5f86d33bc0..0a10f6814ae96 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -895,25 +895,4 @@ gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_
gpu.return
}
-// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- gpu.return
-}
-
-// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
- gpu.return
-}
-
-// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
- //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
- %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
- gpu.return
-}
-
}
>From de87d094a9a309b66138d2a357d1ac73b8270c2b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Oct 2025 22:05:14 +0000
Subject: [PATCH 11/12] address minor comments
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f2d1b85fa1737..464a9e2d2a806 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -199,8 +199,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
- } else if (dataShape.size() == 1) {
-
+ } else {
SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
@@ -212,8 +211,6 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
if (subgroup_block_io && mdescTy.isColMajor())
return emitError() << "mem_desc should be row major when "
"subgroup_block_io is set.";
- } else if (dataShape.size() == 0) {
- return emitError() << "result shape must not be empty.";
}
return success();
>From faa0bfb3eb6004dcaf33269b9e161051a96baa79 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 15 Oct 2025 23:35:57 +0000
Subject: [PATCH 12/12] address comments
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 6 ++++++
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 14 ++++++++------
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +-
3 files changed, 15 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73b70da9642e4..426377fcf598f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
Arguments:
- `mem_desc`: the memory descriptor identifying the SLM region.
- `offsets`: the coordinates within the matrix to read from.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block load. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
@@ -1367,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- `mem_desc`: the memory descriptor specifying the SLM region.
- `offsets`: the coordinates within the matrix where the data will be written.
- `data`: the values to be stored in the matrix.
+ - `subgroup_block_io`: [optional] An attribute indicating that the operation can be
+ lowered to a subgroup block store. When this attribute is present,
+ the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
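A minimal usage sketch of the subgroup_block_io attribute documented above, assuming a row-major SLM mem_desc that carries a block attribute (shapes and names here are illustrative, not copied from the in-tree tests):

  gpu.func @subgroup_block_io_sketch(%md: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
    // Offsets are subgroup-uniform, so the op may lower to xevm.blockload/blockstore.
    %v = xegpu.load_matrix %md[16, 0] <{subgroup_block_io}> : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
    xegpu.store_matrix %v, %md[16, 0] <{subgroup_block_io}> : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
    gpu.return
  }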
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 024ca2023c811..b1196fbe9c66a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -263,10 +263,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
/// based on the rank and the value of the first stride dimension.
bool isColMajor() {
auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
- return getRank() == 2 && dim0 && dim0.getInt() == 1;
+ return getRank() == 2 && dim0.getInt() == 1;
}
- // get the Blocking shape for a MemDescType, Which is represented
+ // Get the Blocking shape for a MemDescType, Which is represented
// as an attribute in MemDescType. By default it is the shape
// of the mdescTy
SmallVector<int64_t> getBlockShape() {
@@ -284,16 +284,18 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
// Get strides as vector of integer.
// If it contains block attribute, the strides are blocked strides.
//
- // The blocking is applied against the original matrix shape
- // so that the linear offset is not impacted by the subview.
+ // The blocking is applied to the base matrix shape derived from the
+ // memory descriptor's stride information. If the matrix described by
+ // the memory descriptor is not contiguous, it is assumed that the base
+ // matrix is contiguous and follows the same memory layout.
//
// It first computes the original matrix shape using the stride info,
// then computes the number of blocks in each dimension of original shape,
// then compute the outer block shape and stride,
// then combines the inner and outer block shape and stride
- // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+ // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
- // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
+ // for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1]
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
SmallVector<int64_t> getStrideShape();
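Working the first example in the comment above through by hand (a sketch, assuming the strides describe a contiguous base matrix): for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> the per-dimension block counts are [32/16, 256/8] = [2, 32], which combined with the block shape [16, 8] gives the shape tuple [2, 32, 16, 8]. Elements inside a block keep the base order (dim 0 is the stride-1 dimension), so the inner strides are [1, 16]; one block holds 16*8 = 128 elements, so the outer strides are [128, 2*128] = [128, 256]. Concatenated, that yields the ([2,32,16,8],[128,256,1,16]) tuple quoted in the comment.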
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9cf963e101816..9ee384e46ef33 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -520,7 +520,7 @@ class CreateMemDescOpPattern final
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+ auto resTy = op.getMemDesc();
// Create the result MemRefType with the same shape, element type, and
// memory space