[Mlir-commits] [mlir] 9c4e571 - [mlir][xegpu] Add definitions of MemDescType and related ops. (#153273)
llvmlistbot at llvm.org
Fri Aug 15 16:02:16 PDT 2025
Author: Chao Chen
Date: 2025-08-15T18:02:13-05:00
New Revision: 9c4e571ae83d86aa81c556d62400c61b3f53c805
URL: https://github.com/llvm/llvm-project/commit/9c4e571ae83d86aa81c556d62400c61b3f53c805
DIFF: https://github.com/llvm/llvm-project/commit/9c4e571ae83d86aa81c556d62400c61b3f53c805.diff
LOG: [mlir][xegpu] Add definitions of MemDescType and related ops. (#153273)
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
mlir/test/Dialect/XeGPU/invalid.mlir
mlir/test/Dialect/XeGPU/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1f420c13ebae0..a94987885c9e0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -527,4 +527,34 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
let genVerifyDecl = 1;
}
+def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
+ let summary = [{Specifies memory layouts with named attributes.}];
+
+ let description = [{
+ This attribute stores a collection of named attributes that describe
+    memory layout properties such as `stride` and `block`.
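+
+    For example, a column-major layout with 8x8 blocking (matching the
+    MemDescType examples in this commit) is written as:
+    ```mlir
+    #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>
+    ```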
+ }];
+
+ let parameters = (ins "DictionaryAttr": $attrs);
+ let hasCustomAssemblyFormat = 1;
+
+ let extraClassDeclaration = [{
+ /// Get a specific attribute by name
+ Attribute getAttr(StringRef name) const {
+ return getAttrs().get(name);
+ }
+
+ /// Check if a specific attribute exists
+ bool hasAttr(StringRef name) const {
+ return getAttrs().contains(name);
+ }
+
+ ArrayAttr getStrides() {
+ return getAttrs().getAs<ArrayAttr>("stride");
+ }
+
+ }];
+
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 480b43e740736..abc291c81a76c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1097,4 +1097,152 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
let hasCanonicalizer = 1;
}
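+// Predicate and constraint identifying memrefs that live in shared
+// (workgroup) memory; see isSharedMemory in XeGPUOps.cpp for the accepted
+// address spaces.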
+def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
+class StaticShared1DMemRefOf<list<Type> allowedTypes> :
+ ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
+ "statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
+ "mlir::MemRefType">;
+
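+// SizeInBits computes the total size in bits of a shaped value
+// (numElements * elementTypeBitWidth); AllMemSizesMatch requires all named
+// values to have the same size in bits.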
+class SizeInBits<string name> :
+ StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
+ "*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
+class AllMemSizesMatch<list<string> names> :
+ AllMatchSameOperatorTrait<names, SizeInBits<"_self">.result,
+ "size in bits">;
+
+def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
+ AllMemSizesMatch<["source", "mem_desc"]>]> {
+ let summary = "Create a memory descriptor.";
+ let description = [{
+    Creates a memory descriptor from a shared local memory (SLM) buffer and an
+    XeGPU-specific memory layout. The resulting memory descriptor must have the
+    same size in bits as the underlying SLM buffer.
+
+ Arguments:
+ - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
+ Results:
+ - `mem_desc` : the memory descriptor.
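+
+    Example (mirroring the ops.mlir test added in this commit):
+    ```mlir
+    %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+    %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+    ```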
+ }];
+ let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+ let results = (outs XeGPU_MemDesc:$mem_desc);
+ let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
+}
+
+def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
+ AllElementTypesMatch<["mem_desc", "res"]>,
+ AllRanksMatch<["mem_desc", "res"]>]> {
+ let arguments = (ins XeGPU_MemDesc:$mem_desc,
+ Variadic<Index>: $offsets,
+ DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<LayoutTrait>:$layout
+ );
+ let results = (outs XeGPU_ValueType:$res);
+ let assemblyFormat = [{
+ $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `` `:` type(operands) `->` type(results)
+ }];
+
+ let description = [{
+ This operation loads a 2D block of data from shared local memory (SLM) as specified
+ by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
+ subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+
+ Arguments:
+ - `mem_desc`: the memory descriptor identifying the SLM region.
+ - `offsets`: the coordinates within the matrix to read from.
+    - `layout`: [optional] an attribute guiding the distribution among
+              subgroups and/or work-items; it currently accepts either
+              LayoutAttr or SliceAttr.
+ Results:
+ - `res`: the matrix elements loaded from SLM.
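+
+    Example (mirroring the ops.mlir test added in this commit):
+    ```mlir
+    %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+    ```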
+ }];
+
+ let builders = [
+ OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+ ];
+ let extraClassDeclaration = [{
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
+ AllElementTypesMatch<["mem_desc", "data"]>,
+ AllRanksMatch<["mem_desc", "data"]>]> {
+ let arguments = (ins
+ XeGPU_ValueType:$data,
+ XeGPU_MemDesc:$mem_desc,
+ Variadic<Index>: $offsets,
+ DenseI64ArrayAttr: $const_offsets,
+ OptionalAttr<LayoutTrait>:$layout
+ );
+ let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
+ prop-dict attr-dict `` `:` type(operands)}];
+ let description = [{
+ This operation stores a 2D `data` fragment into the shared local memory region
+ specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
+ subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+
+ Arguments:
+ - `mem_desc`: the memory descriptor specifying the SLM region.
+ - `offsets`: the coordinates within the matrix where the data will be written.
+ - `data`: the values to be stored in the matrix.
+    - `layout`: [optional] an attribute guiding the distribution among
+              subgroups and/or work-items; it currently accepts either
+              LayoutAttr or SliceAttr.
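+
+    Example (mirroring the ops.mlir test added in this commit):
+    ```mlir
+    xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+    ```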
+ }];
+ let builders = [
+ OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+ ];
+ let extraClassDeclaration = [{
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
+def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
+ [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
+ let description = [{
+    Creates a subview of a memory descriptor. The resulting memory descriptor may have
+    a lower rank than the source; in that case, the result dimensions correspond to the
+    trailing (innermost) dimensions of the source memory descriptor.
+
+ Arguments:
+ - `src` : a memory descriptor.
+    - `offsets` : the coordinates within the source memory descriptor at which the subview starts.
+
+    Results:
+    - `res` : a memory descriptor covering the requested subregion, with an equal or smaller shape and possibly a lower rank.
+
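+    Example (mirroring the ops.mlir test added in this commit):
+    ```mlir
+    %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+    ```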
+ }];
+ let arguments = (ins XeGPU_MemDesc:$src,
+ Variadic<Index>:$offsets,
+ DenseI64ArrayAttr:$const_offsets);
+ let results = (outs XeGPU_MemDesc:$res);
+ let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
+ attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
+ let builders = [
+ OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
+ ];
+
+ let extraClassDeclaration = [{
+ mlir::Value getViewSource() { return getSrc(); }
+
+ SmallVector<OpFoldResult> getMixedOffsets() {
+ return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+ }
+ }];
+
+ let hasVerifier = 1;
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index b268cabb5d266..f8b371db498e8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -201,4 +201,53 @@ def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
}];
}
+def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "mlir::Type"> {
+ let summary = "MemDesc describing the data in SLM";
+ let description = [{
+ MemDesc represents a block of data stored in shared local memory.
+    Unless a layout attribute is provided, the data is stored contiguously
+    in row-major order within the region.
+
+ Examples:
+ ```mlir
+ // A multi-dimensional array stored in column-major order.
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128]>>
+
+ // A multi-dimensional array stored in a blocked layout. Elements within the same block
+ // are stored contiguously in memory. Blocks are stored in row-major order.
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<block = [8, 8]>>
+
+ // A multi-dimensional array stored in column-major order with blocked layout.
+ !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>>
+ ```
+ }];
+ let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
+ "mlir::Type": $elementType,
+ OptionalParameter<"MemLayoutAttr">: $mem_layout);
+
+ let extraClassDeclaration = [{
+ bool hasRank() const { return true; }
+
+ MemDescType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, Type elementType) const {
+ return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
+ }
+
+ ArrayAttr getStrides() {
+ auto layout = getMemLayout();
+ if (layout && layout.hasAttr("stride")) {
+ return layout.getStrides();
+ }
+
+      // Derive default row-major strides: each dimension's stride is the
+      // product of all later dimensions (suffix product of the shape).
+      llvm::ArrayRef<int64_t> shape = getShape();
+      SmallVector<int64_t> defaultStrides(shape.size(), 1);
+      for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
+        defaultStrides[i] = defaultStrides[i + 1] * shape[i + 1];
+      Builder builder(getContext());
+      return builder.getI64ArrayAttr(defaultStrides);
+ }
+ }];
+
+ let hasCustomAssemblyFormat = true;
+}
+
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 7c6a4f37db9af..7869a28dfed57 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -17,6 +17,8 @@ add_mlir_dialect_library(MLIRXeGPUDialect
MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRIR
MLIRViewLikeInterface
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d997296a22c20..1b26542ff65a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -427,7 +427,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
-mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+mlir::Type TensorDescType::parse(AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
@@ -477,7 +477,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
layout.value_or(mlir::Attribute()));
}
-void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+void TensorDescType::print(AsmPrinter &printer) const {
printer << "<";
auto shape = getShape();
@@ -522,10 +522,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, layout);
}
-LogicalResult TensorDescType::verify(
- llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
- llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
- mlir::Attribute encoding, mlir::Attribute layout) {
+LogicalResult
+TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
if (rank == 0)
@@ -591,6 +591,119 @@ LogicalResult TensorDescType::verify(
return success();
}
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
+mlir::Type MemDescType::parse(AsmParser &parser) {
+ llvm::SmallVector<int64_t> shape;
+ mlir::Type elementType;
+ mlir::FailureOr<MemLayoutAttr> layout;
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ auto shapeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
+ parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+ return {};
+ }
+
+ auto elemTypeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseType(elementType))) {
+ parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+ return {};
+ }
+
+ // parse optional attributes
+ if (mlir::succeeded(parser.parseOptionalComma())) {
+ MemLayoutAttr attr;
+ ParseResult res = parser.parseAttribute(attr);
+ if (mlir::failed(res))
+ return {};
+ layout = attr;
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ MLIRContext *ctxt = parser.getContext();
+ return MemDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+ elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+ printer << "<";
+
+ printer.printDimensionList(getShape());
+ printer << 'x';
+ printer << getElementType();
+
+ if (auto layout = getMemLayout())
+ printer << ", " << layout;
+
+ printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemLayoutAttr
+//===----------------------------------------------------------------------===//
+
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+
+ auto context = parser.getContext();
+ llvm::SMLoc loc = parser.getCurrentLocation();
+
+ llvm::SmallDenseSet<StringRef> seenKeys;
+ SmallVector<NamedAttribute> attributes;
+
+ auto parseElt = [&]() -> ParseResult {
+ StringRef nameId;
+ if (failed(parser.parseKeyword(&nameId)))
+ return parser.emitError(loc, "expected valid attribute name");
+
+ if (!seenKeys.insert(nameId).second)
+ return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in mem layout attribute";
+
+ if (failed(parser.parseEqual()))
+ return failure();
+
+ Attribute attr;
+ if (failed(parser.parseAttribute(attr)))
+ return failure();
+ attributes.emplace_back(nameId, attr);
+ return success();
+ };
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ if (failed(parser.parseCommaSeparatedList(parseElt)))
+ return {};
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ return parser.getChecked<MemLayoutAttr>(
+ loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+ printer << "<";
+ ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
+ for (size_t i = 0; i < attrs.size(); i++) {
+ printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
+ if (i < attrs.size() - 1)
+ printer << ", ";
+ }
+ printer << ">";
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7b7ce19e6937b..eee0fdc7160de 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
namespace mlir {
namespace xegpu {
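+/// Returns true if the memref resides in shared local memory: an integer
+/// memory space of 3, the xegpu SLM memory space, the xevm SHARED address
+/// space, or the GPU dialect's workgroup address space.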
+bool isSharedMemory(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 3;
+ if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+ return memrefSpace.getValue() == MemorySpace::SLM;
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
@@ -919,6 +932,101 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<FoldConvertLayoutOp>(context);
}
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadMatrixOp
+//===----------------------------------------------------------------------===//
+void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ LayoutTrait layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult LoadMatrixOp::verify() {
+ VectorType resTy = getRes().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> valueShape = resTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed mem_desc shape.");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreMatrixOp
+//===----------------------------------------------------------------------===//
+void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ LayoutTrait layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult StoreMatrixOp::verify() {
+ VectorType dataTy = getData().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("data shape must not exceed mem_desc shape.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescSubviewOp
+//===----------------------------------------------------------------------===//
+
+void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
+ Type resTy, Value src,
+ llvm::ArrayRef<OpFoldResult> offsets) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
+}
+
+LogicalResult MemDescSubviewOp::verify() {
+ MemDescType srcTy = getSrc().getType();
+ MemDescType resTy = getRes().getType();
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> resShape = resTy.getShape();
+
+ if (srcTy.getRank() < resTy.getRank())
+ return emitOpError("result rank must not exceed source rank.");
+
+ if (llvm::any_of(
+ llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed source shape.");
+
+ if (srcTy.getStrides() != resTy.getStrides())
+ return emitOpError("result must inherit the source strides.");
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 270d71aaa7273..46ff03745a220 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -524,8 +524,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
// is lowered to:
// #a = #xegpu.layout<inst_data = [16, 16]>
// #b = #xegpu.layout<inst_data = [8, 16]>
-// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
// clang-format on
struct WgToSgConvertLayoutOp
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 5b251517d2ef0..93a5a055b08c6 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -762,3 +762,89 @@ func.func @slice_attr_repeat_dim() {
return
}
+// -----
+func.func @create_mem_desc_non_slm() {
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
+  // expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
+ return
+}
+
+// -----
+func.func @create_mem_desc_mismatch_sizes() {
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  // expected-error@+1 {{failed to verify that all of {source, mem_desc} have same size in bits}}
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x32xf16>
+ return
+}
+
+// -----
+func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{failed to verify that all of {mem_desc, res} have same element type}}
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf32>
+ return
+}
+
+// -----
+func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{result shape must not exceed mem_desc shape}}
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
+ return
+}
+
+// -----
+func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
+  // expected-error@+1 {{mem_desc must be 2D}}
+ %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
+ return
+}
+
+// -----
+func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
+  // expected-error@+1 {{failed to verify that all of {mem_desc, data} have same element type}}
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+
+// -----
+func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<32x32xf16>) {
+  // expected-error@+1 {{data shape must not exceed mem_desc shape}}
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<32x32xf16>, !xegpu.mem_desc<16x64xf16>
+ return
+}
+
+// -----
+func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
+  // expected-error@+1 {{mem_desc must be 2D.}}
+ xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{result shape must not exceed source shape}}
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+  // expected-error@+1 {{result must inherit the source strides}}
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{failed to verify that all of {src, res} have same element type}}
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride = [64, 1]>>
+ return
+}
+
+// -----
+func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{result rank must not exceed source rank}}
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
+ return
+}
+
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 67c00f5a9cc2f..35342eca1354c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -751,4 +751,72 @@ gpu.func @fence() {
gpu.return
}
+// CHECK-LABEL: gpu.func @create_mem_desc({{.*}}) {
+gpu.func @create_mem_desc() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride({{.*}}) {
+gpu.func @create_mem_desc_with_stride() {
+ //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+ //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @load_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
+ %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
+ //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+ //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+ gpu.return
+}
+
}