[Mlir-commits] [mlir] 9c4e571 - [mlir][xegpu] Add definitions of MemDescType and related ops. (#153273)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 15 16:02:16 PDT 2025


Author: Chao Chen
Date: 2025-08-15T18:02:13-05:00
New Revision: 9c4e571ae83d86aa81c556d62400c61b3f53c805

URL: https://github.com/llvm/llvm-project/commit/9c4e571ae83d86aa81c556d62400c61b3f53c805
DIFF: https://github.com/llvm/llvm-project/commit/9c4e571ae83d86aa81c556d62400c61b3f53c805.diff

LOG: [mlir][xegpu] Add definitions of MemDescType and related ops.  (#153273)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
    mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/invalid.mlir
    mlir/test/Dialect/XeGPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1f420c13ebae0..a94987885c9e0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -527,4 +527,34 @@ def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
   let genVerifyDecl = 1;
 }
 
+def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
+  let summary = [{Specifies memory layouts with named attributes.}];
+
+  let description = [{
+    This attribute stores a collection of named attributes that describe
+    memory layout properties such as stride, block, etc.
+  }];
+
+  let parameters = (ins "DictionaryAttr": $attrs);
+  let hasCustomAssemblyFormat = 1;
+
+  let extraClassDeclaration = [{
+    /// Get a specific attribute by name
+    Attribute getAttr(StringRef name) const {
+      return getAttrs().get(name);
+    }
+
+    /// Check if a specific attribute exists
+    bool hasAttr(StringRef name) const {
+      return getAttrs().contains(name);
+    }
+
+    ArrayAttr getStrides() {
+      return getAttrs().getAs<ArrayAttr>("stride");
+    }
+
+  }];
+
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 480b43e740736..abc291c81a76c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1097,4 +1097,152 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
     let hasCanonicalizer = 1;
 }
 
+def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
+class StaticShared1DMemRefOf<list<Type> allowedTypes> :
+  ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
+     "statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
+     "mlir::MemRefType">;
+
+class SizeInBits<string name> :
+  StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
+          "*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
+class AllMemSizesMatch<list<string> names> :
+    AllMatchSameOperatorTrait<names, SizeInBits<"_self">.result,
+                              "size in bits">;
+
+def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
+      AllMemSizesMatch<["source", "mem_desc"]>]>  {
+  let summary = "Create a memory descriptor.";
+  let description = [{
+    Creates a memory descriptor from a shared local memory (SLM) buffer, and xegpu
+    specific memory layout. The resulting memory descriptor has to have the same size
+    as the underlying shared local memory.
+
+    Arguments:
+     - `source` : a 1D statically shaped memref with element type i8, representing the raw SLM buffer.
+    Results:
+     - `mem_desc` : the memory descriptor.
+  }];
+  let arguments = (ins StaticShared1DMemRefOf<[I8]>:$source);
+  let results = (outs XeGPU_MemDesc:$mem_desc);
+  let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
+}
+
+def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
+                              AllElementTypesMatch<["mem_desc", "res"]>,
+                              AllRanksMatch<["mem_desc", "res"]>]>  {
+  let arguments = (ins XeGPU_MemDesc:$mem_desc,
+    Variadic<Index>: $offsets,
+    DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<LayoutTrait>:$layout
+  );
+  let results = (outs XeGPU_ValueType:$res);
+  let assemblyFormat = [{
+    $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
+    prop-dict attr-dict `` `:` type(operands) `->` type(results)
+  }];
+
+  let description = [{
+    This operation loads a 2D block of data from shared local memory (SLM) as specified
+    by the provided 2D `mem_desc`. Only 2D memory descriptors are supported; use the
+    subview operation to obtain a compatible 2D `mem_desc` from a higher-rank descriptor if needed.
+
+    Arguments:
+     - `mem_desc`: the memory descriptor identifying the SLM region.
+     - `offsets`: the coordinates within the matrix to read from.
+     - `layout`: [optional] An attribute for guiding distributions among
+                 subgroups and/or work-items. It currently can accept either
+                 LayoutAttr or SliceAttr.
+    Results:
+     - `res`: the matrix elements loaded from SLM.
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
+                    "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+  ];
+  let extraClassDeclaration = [{
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
+                              AllElementTypesMatch<["mem_desc", "data"]>,
+                              AllRanksMatch<["mem_desc", "data"]>]> {
+  let arguments = (ins
+    XeGPU_ValueType:$data,
+    XeGPU_MemDesc:$mem_desc,
+    Variadic<Index>: $offsets,
+    DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<LayoutTrait>:$layout
+  );
+  let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
+                          prop-dict attr-dict `` `:` type(operands)}];
+  let description = [{
+    This operation stores a 2D `data` fragment into the shared local memory region
+    specified by a 2D `mem_desc`. Only 2D memory descriptors are supported; use the
+    subview operation to obtain a 2D `mem_desc` from a higher-rank descriptor if needed.
+
+    Arguments:
+     - `mem_desc`: the memory descriptor specifying the SLM region.
+     - `offsets`: the coordinates within the matrix where the data will be written.
+     - `data`: the values to be stored in the matrix.
+     - `layout`: [optional] An attribute for guiding distributions among
+                 subgroups and/or work-items. It currently can accept either
+                 LayoutAttr or SliceAttr.
+  }];
+  let builders = [
+    OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+  ];
+  let extraClassDeclaration = [{
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
+          [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
+  let description = [{
+    Creates a subview of a memory descriptor. The resulting memory descriptor can have
+    a lower rank than the source; in this case, the result dimensions correspond to the
+    higher-order dimensions of the source memory descriptor.
+
+    Arguments:
+     - `src` : a memory descriptor.
+     - `offsets` : the coordinates within the matrix the subview will be created from.
+
+    Results:
+    - `res` : a memory descriptor with smaller size.
+
+  }];
+  let arguments = (ins XeGPU_MemDesc:$src,
+                       Variadic<Index>:$offsets,
+                       DenseI64ArrayAttr:$const_offsets);
+  let results = (outs XeGPU_MemDesc:$res);
+  let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
+                         attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
+  let builders = [
+    OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
+  ];
+
+  let extraClassDeclaration = [{
+    mlir::Value getViewSource() { return getSrc(); }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index b268cabb5d266..f8b371db498e8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -201,4 +201,53 @@ def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   }];
 }
 
+def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "mlir::Type"> {
+  let summary = "MemDesc describing the data in SLM";
+  let description = [{
+    MemDesc represents a block of data stored in shared local memory.
+    By default, unless a layout attribute is provided, the data is stored
+    contiguously in row-major order within the region.
+
+    Examples:
+    ```mlir
+    // A multi-dimensional array stored in column-major order.
+    !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128]>>
+
+    // A multi-dimensional array stored in a blocked layout. Elements within the same block
+    // are stored contiguously in memory. Blocks are stored in row-major order.
+    !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<block = [8, 8]>>
+
+    // A multi-dimensional array stored in column-major order with blocked layout.
+    !xegpu.mem_desc<128x128xf16, #xegpu.mem_layout<stride = [1, 128], block = [8, 8]>>
+    ```
+  }];
+  let parameters = (ins ArrayRefParameter<"int64_t">: $shape,
+                        "mlir::Type": $elementType,
+                        OptionalParameter<"MemLayoutAttr">: $mem_layout);
+
+  let extraClassDeclaration = [{
+    bool hasRank() const { return true; }
+
+    MemDescType cloneWith(std::optional<llvm::ArrayRef<int64_t>> shape, Type elementType) const {
+      return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
+    }
+
+    ArrayAttr getStrides() {
+      auto layout = getMemLayout();
+      if (layout && layout.hasAttr("stride")) {
+        return layout.getStrides();
+      }
+
+      // Derive row-major defaults: stride[i] = product of shape[i+1..rank-1].
+      ArrayRef<int64_t> shape = getShape();
+      SmallVector<int64_t> defaultStrides(shape.size(), 1);
+      for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
+        defaultStrides[i] = defaultStrides[i + 1] * shape[i + 1];
+      return Builder(getContext()).getI64ArrayAttr(defaultStrides);
+    }
+  }];
+
+  let hasCustomAssemblyFormat = true;
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUTYPES_TD

diff  --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 7c6a4f37db9af..7869a28dfed57 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -17,6 +17,8 @@ add_mlir_dialect_library(MLIRXeGPUDialect
   MLIRAffineUtils
   MLIRArithUtils
   MLIRDialectUtils
+  MLIRGPUDialect
+  MLIRXeVMDialect
   MLIRIR
   MLIRViewLikeInterface
   MLIRVectorDialect

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d997296a22c20..1b26542ff65a3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -427,7 +427,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 // XeGPU_TensorDescType
 //===----------------------------------------------------------------------===//
 
-mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+mlir::Type TensorDescType::parse(AsmParser &parser) {
   llvm::SmallVector<int64_t> shape;
   mlir::Type elementType;
   mlir::FailureOr<mlir::Attribute> encoding;
@@ -477,7 +477,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
       layout.value_or(mlir::Attribute()));
 }
 
-void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+void TensorDescType::print(AsmPrinter &printer) const {
   printer << "<";
 
   auto shape = getShape();
@@ -522,10 +522,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
   return Base::get(context, shape, elementType, attr, layout);
 }
 
-LogicalResult TensorDescType::verify(
-    llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
-    llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
-    mlir::Attribute encoding, mlir::Attribute layout) {
+LogicalResult
+TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+                       llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+                       mlir::Attribute encoding, mlir::Attribute layout) {
   size_t rank = shape.size();
 
   if (rank == 0)
@@ -591,6 +591,119 @@ LogicalResult TensorDescType::verify(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
+mlir::Type MemDescType::parse(AsmParser &parser) {
+  llvm::SmallVector<int64_t> shape;
+  mlir::Type elementType;
+  mlir::FailureOr<MemLayoutAttr> layout;
+
+  // Parse literal '<'
+  if (parser.parseLess())
+    return {};
+
+  auto shapeLoc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
+    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+    return {};
+  }
+
+  auto elemTypeLoc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseType(elementType))) {
+    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+    return {};
+  }
+
+  // parse optional attributes
+  if (mlir::succeeded(parser.parseOptionalComma())) {
+    MemLayoutAttr attr;
+    ParseResult res = parser.parseAttribute(attr);
+    if (mlir::failed(res))
+      return {};
+    layout = attr;
+  }
+
+  // Parse literal '>'
+  if (parser.parseGreater())
+    return {};
+
+  MLIRContext *ctxt = parser.getContext();
+  return MemDescType::getChecked(
+      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+      elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+  printer << "<";
+
+  printer.printDimensionList(getShape());
+  printer << 'x';
+  printer << getElementType();
+
+  if (auto layout = getMemLayout())
+    printer << ", " << layout;
+
+  printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemLayoutAttr
+//===----------------------------------------------------------------------===//
+
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+  // Parses `<key = attr, ...>` into a MemLayoutAttr wrapping a DictionaryAttr.
+  auto context = parser.getContext();
+  llvm::SMLoc loc = parser.getCurrentLocation();
+
+  llvm::SmallDenseSet<StringRef> seenKeys;
+  SmallVector<NamedAttribute> attributes;
+
+  auto parseElt = [&]() -> ParseResult {
+    StringRef nameId;
+    if (failed(parser.parseKeyword(&nameId)))
+      return parser.emitError(loc, "expected valid attribute name");
+
+    if (!seenKeys.insert(nameId).second)
+      return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in mem layout attribute";
+
+    if (failed(parser.parseEqual()))
+      return failure();
+
+    Attribute attr;
+    if (failed(parser.parseAttribute(attr)))
+      return failure();
+    attributes.emplace_back(nameId, attr);
+    return success();
+  };
+
+  // Parse literal '<'
+  if (parser.parseLess())
+    return {};
+
+  if (failed(parser.parseCommaSeparatedList(parseElt)))
+    return {};
+
+  // Parse literal '>'
+  if (parser.parseGreater())
+    return {};
+
+  return parser.getChecked<MemLayoutAttr>(
+      loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
+  for (size_t i = 0; i < attrs.size(); i++) {
+    printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
+    if (i < attrs.size() - 1)
+      printer << ", ";
+  }
+  printer << ">";
+}
+
 } // namespace xegpu
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7b7ce19e6937b..eee0fdc7160de 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
 namespace mlir {
 namespace xegpu {
 
+bool isSharedMemory(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace(); // null for the default space
+  if (auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr))
+    return intAttr.getInt() == 3;
+  if (auto memrefSpace = llvm::dyn_cast_if_present<MemorySpaceAttr>(attr))
+    return memrefSpace.getValue() == MemorySpace::SLM;
+  if (auto xevmSpace = llvm::dyn_cast_if_present<xevm::AddrSpaceAttr>(attr))
+    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+  return attr && gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
 template <typename T>
 static std::string makeString(T array, bool breakline = false) {
   std::string buf;
@@ -919,6 +932,101 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<FoldConvertLayoutOp>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadMatrixOp
+//===----------------------------------------------------------------------===//
+void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
+                         TypedValue<MemDescType> memDesc,
+                         llvm::ArrayRef<OpFoldResult> offsets,
+                         LayoutTrait layout) {
+  llvm::SmallVector<Value> dynamicOffsets;
+  llvm::SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+  build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
+        layout);
+}
+
+LogicalResult LoadMatrixOp::verify() {
+  VectorType resTy = getRes().getType();
+  MemDescType mdescTy = getMemDesc().getType();
+
+  if (mdescTy.getRank() != 2)
+    return emitOpError("mem_desc must be 2D.");
+
+  ArrayRef<int64_t> valueShape = resTy.getShape();
+  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+  if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+    return emitOpError("result shape must not exceed mem_desc shape.");
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreMatrixOp
+//===----------------------------------------------------------------------===//
+void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
+                          TypedValue<MemDescType> memDesc,
+                          llvm::ArrayRef<OpFoldResult> offsets,
+                          LayoutTrait layout) {
+  llvm::SmallVector<Value> dynamicOffsets;
+  llvm::SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+  build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
+        layout);
+}
+
+LogicalResult StoreMatrixOp::verify() {
+  VectorType dataTy = getData().getType();
+  MemDescType mdescTy = getMemDesc().getType();
+
+  if (mdescTy.getRank() != 2)
+    return emitOpError("mem_desc must be 2D.");
+
+  ArrayRef<int64_t> dataShape = dataTy.getShape();
+  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+  if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+    return emitOpError("data shape must not exceed mem_desc shape.");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescSubviewOp
+//===----------------------------------------------------------------------===//
+
+void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
+                             Type resTy, Value src,
+                             llvm::ArrayRef<OpFoldResult> offsets) {
+  llvm::SmallVector<Value> dynamicOffsets;
+  llvm::SmallVector<int64_t> staticOffsets;
+  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+  build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
+}
+
+LogicalResult MemDescSubviewOp::verify() {
+  MemDescType srcTy = getSrc().getType();
+  MemDescType resTy = getRes().getType();
+  ArrayRef<int64_t> srcShape = srcTy.getShape();
+  ArrayRef<int64_t> resShape = resTy.getShape();
+
+  if (srcTy.getRank() < resTy.getRank())
+    return emitOpError("result rank must not exceed source rank.");
+
+  if (llvm::any_of(
+          llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
+          [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+    return emitOpError("result shape must not exceed source shape.");
+
+  if (srcTy.getStrides() != resTy.getStrides())
+    return emitOpError("result must inherit the source strides.");
+
+  return success();
+}
+
 } // namespace xegpu
 } // namespace mlir
 

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 270d71aaa7273..46ff03745a220 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -524,8 +524,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 // is lowered to:
 //   #a = #xegpu.layout<inst_data = [16, 16]>
 //   #b = #xegpu.layout<inst_data = [8, 16]>
-//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-//   %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+//   %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
 //   xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
 // clang-format on
 struct WgToSgConvertLayoutOp

diff  --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 5b251517d2ef0..93a5a055b08c6 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -762,3 +762,89 @@ func.func @slice_attr_repeat_dim() {
   return
 }
 
+// -----
+func.func @create_mem_desc_non_slm() {
+  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
+  // expected-error @+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
+  %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
+  return
+}
+
+// -----
+func.func @create_mem_desc_mismatch_sizes() {
+  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  // expected-error @+1 {{failed to verify that all of {source, mem_desc} have same size in bits}}
+  %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x32xf16>
+  return
+}
+
+// -----
+func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error @+1 {{failed to verify that all of {mem_desc, res} have same element type}}
+  %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf32>
+  return
+}
+
+// -----
+func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error @+1 {{result shape must not exceed mem_desc shape}}
+  %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
+  return
+}
+
+// -----
+func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
+  // expected-error @+1 {{mem_desc must be 2D}}
+  %data = xegpu.load_matrix %arg0[16]: !xegpu.mem_desc<64xf16> -> vector<16xf16>
+  return
+}
+
+// -----
+func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
+  // expected-error @+1 {{failed to verify that all of {mem_desc, data} have same element type}}
+  xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf32>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+
+// -----
+func.func @store_mem_desc_invalid_data_size(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<32x32xf16>) {
+  // expected-error @+1 {{data shape must not exceed mem_desc shape}}
+  xegpu.store_matrix %arg1, %arg0[8, 8] : vector<32x32xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+
+// -----
+func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: vector<32xf16>) {
+  // expected-error @+1 {{mem_desc must be 2D.}}
+  xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
+  return
+}
+
+// -----
+func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error @+1 {{result shape must not exceed source shape}}
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
+  return
+}
+
+// -----
+func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
+  // expected-error @+1 {{result must inherit the source strides}}
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
+  return
+}
+
+// -----
+func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error @+1 {{failed to verify that all of {src, res} have same element type}}
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
+  return
+}
+
+// -----
+func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error @+1 {{result rank must not exceed source rank}}
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
+  return
+}
+

diff  --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 67c00f5a9cc2f..35342eca1354c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -751,4 +751,72 @@ gpu.func @fence() {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @create_mem_desc({{.*}}) {
+gpu.func @create_mem_desc() {
+  //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+  //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_mem_desc_with_stride({{.*}}) {
+gpu.func @create_mem_desc_with_stride() {
+  //CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+  //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+  %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
+  %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
+  gpu.return
+}
+
+
+// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+  //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  gpu.return
+}
+
 }


        


More information about the Mlir-commits mailing list