[Mlir-commits] [mlir] 68c4c83 - [MLIR][XeGPU] Matrix load/store subgroup distribution (#165008)

llvmlistbot at llvm.org
Mon Nov 3 12:48:31 PST 2025


Author: Artem Kroviakov
Date: 2025-11-03T21:48:27+01:00
New Revision: 68c4c83bcbf9612a02074b946fe6bb73054183ef

URL: https://github.com/llvm/llvm-project/commit/68c4c83bcbf9612a02074b946fe6bb73054183ef
DIFF: https://github.com/llvm/llvm-project/commit/68c4c83bcbf9612a02074b946fe6bb73054183ef.diff

LOG: [MLIR][XeGPU] Matrix load/store subgroup distribution (#165008)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
    mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/invalid.mlir
    mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 1481859e94a92..0c059967bb898 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -30,9 +30,11 @@ class SliceAttr;
 } // namespace xegpu
 } // namespace mlir
 
+// clang-format off
+#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+// clang-format on
 
 #define GET_ATTRDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 40352b44b6441..9c35c07a7e587 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -223,17 +223,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Derive a new layout by dropping InstData",
                     "xegpu::DistributeLayoutAttr",
                     "dropInstData">,
-    InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
-                      indices based on the effective subgroup layout.}],
+    InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
+                      indices based on the effective level of the layout.}],
                     "FailureOr<SmallVector<Value>>",
-                    "delinearizeSubgroupId",
+                    "delinearizeId",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
-    InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
-                      assigned to a subgroup identified by linearId. The shape parameter
-                      represents the workgroup-level problem size. Each subgroup may access
+    InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for the
+                      distribution units assigned to the level identified by linearId. The shape
+                      parameter represents the higher-level problem size. Each level may access
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
-                    "getOffsets",
+                    "computeDistributedCoords",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
     InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
                      to some other layout according to given permutation of (0...n-1).}],
@@ -476,17 +476,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return {};
     }
 
-    /// Delinearizes a linear subgroup ID into its multidimensional indices
-    /// based on the effective subgroup layout.
+    /// Delinearizes a linear ID into its multidimensional indices
+    /// based on the effective level of the layout.
     FailureOr<SmallVector<Value>>
-    delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+    delinearizeId(OpBuilder &builder, Location loc, Value linearId);
 
-    /// Generates instructions to compute multidimensional offsets for blocks
-    /// assigned to a subgroup identified by linearId. The shape parameter
-    /// represents the workgroup-level problem size. Each subgroup may access
+    /// Generates instructions to compute multidimensional coordinates for the
+    /// distribution units assigned to the level identified by linearId. The shape
+    /// parameter represents the higher-level problem size. Each level may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +643,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// Delinearizes a linear subgroup ID into its multidimensional indices
     /// based on the effective subgroup layout.
     FailureOr<SmallVector<Value>>
-    delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+    delinearizeId(OpBuilder &builder, Location loc, Value linearId);
 
-    /// Generates instructions to compute multidimensional offsets for blocks
+    /// Generates instructions to compute multidimensional coordinates for blocks
     /// assigned to a subgroup identified by linearId. The shape parameter
     /// represents the workgroup-level problem size. Each subgroup may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);

diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index b7af5413669c9..eb05628d4772b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
     The pass distributes subgroup level (SIMD) XeGPU ops to work items.
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
-                           "vector::VectorDialect"];
+                           "vector::VectorDialect", "index::IndexDialect"];
 }
 
 def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..de552ce22ef62 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
     if (!valOrResVecTy)
       valOrResVecTy = VectorType::get(1, data.getType());
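+    // Only lane-level (flat) payloads are expected by the lowering below.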
+    if (valOrResVecTy.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 83406c8c75dcf..397107b786c9e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,55 +37,61 @@ void XeGPUDialect::initialize() {
       >();
 }
 
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each of shape
+// `subShapesLayout` x `subShape`. A `delinearizedId` identifies a particular
+// `subShape` within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 in a 4x2 layout. This gives a
+// distribution unit of shape 64x64, and there are 2x4 such distribution
+// units. `delinearizedId` identifies the 16x32 tile of one subgroup within
+// each distribution unit.
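+// For instance, assuming row-major delinearization, the subgroup with
+// delinearizedId (1, 1) gets the local offset (1*16, 1*32) = (16, 32) within
+// every unit, so it covers the 16x32 tiles at (16, 32), (16, 96), (16, 160),
+// (16, 224), (80, 32), (80, 96), (80, 160) and (80, 224) of the WG data.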
 static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
-                         SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
-                         ArrayRef<int64_t> sizePerSg,
-                         ArrayRef<int64_t> sizePerWg) {
-
-  SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+               SmallVector<Value> delinearizedId,
+               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+               ArrayRef<int64_t> srcShape) {
+  SmallVector<SmallVector<Value>> coordinates;
+
+  // A distribution unit must not exceed `srcShape` in any dimension.
+  SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+      llvm::zip_equal(srcShape,
+                      computeElementwiseMul(subShapesLayout, subShape)),
+      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
 
-  // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
-  SmallVector<Value> localOffsets = llvm::map_to_vector(
-      llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+  // Get the offset of `subShape` within a distribution unit.
+  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
         return builder.createOrFold<index::MulOp>(
             loc, std::get<0>(t),
             builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
       });
 
-  // distUnit[i] is the minimum value between sizePerWg[i] and
-  // sgLayout[i] * sizePerSg[i]
-  SmallVector<int64_t> distUnit = llvm::map_to_vector(
-      llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
-      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+  // For each distribution unit:
   for (SmallVector<int64_t> unitOffs :
-       StaticTileOffsetRange(sizePerWg, distUnit)) {
+       StaticTileOffsetRange(srcShape, distUnitShape)) {
+    // Get the distribution unit's offset within `srcShape`.
     SmallVector<Value> base =
         llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
           return arith::ConstantIndexOp::create(builder, loc, d);
         });
-
-    SmallVector<Value> adds = llvm::map_to_vector(
-        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
-          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
-                                                     std::get<1>(t));
-        });
-
+    // Calculate `subShape` offset within `srcShape`.
+    SmallVector<Value> adds =
+        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+                            [&](const auto &t) -> Value {
+                              return builder.createOrFold<arith::AddIOp>(
+                                  loc, std::get<0>(t), std::get<1>(t));
+                            });
+    // Do not go beyond `srcShape` bounds.
     SmallVector<Value> mods = llvm::map_to_vector(
-        llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
           return builder.createOrFold<index::RemUOp>(
               loc, std::get<0>(t),
               arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
         });
 
-    offsets.push_back(mods);
+    coordinates.push_back(mods);
   }
-  return offsets;
+  return coordinates;
 }
 
 // Checks if the given shape can be evenly distributed based on the layout
@@ -272,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 }
 
 FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
-                                  Value linearId) {
-  // delinearizeSubgroupId is only available for
-  // workgroup-level layout attribute
-  if (!isForWorkgroup())
-    return failure();
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
 
   // TODO: handle order attribute
   auto hasDefaultOrder = [&]() {
@@ -287,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   };
   if (!hasDefaultOrder())
     return mlir::emitError(loc, "order attribute is currently not supported.");
-
-  auto dims =
-      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
-        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-      });
+  SmallVector<int64_t> layout;
+  if (isForWorkgroup()) {
+    layout = getEffectiveSgLayoutAsInt();
+  } else if (isForSubgroup()) {
+    layout = getEffectiveLaneLayoutAsInt();
+  } else {
+    return failure();
+  }
+  auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
+    return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+  });
 
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
-                       ArrayRef<int64_t> shape) {
-  if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+                                     Value linearId, ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  if (isForWorkgroup()) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    layout = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  } else {
     return failure();
-
-  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
-  if (sgShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, sgLayout))
-      sgShape = derivedShape.value();
+  }
+  if (subShape.empty()) {
+    if (auto derivedShape = computeShapeRatio(shape, layout))
+      subShape = derivedShape.value();
     else
       return failure();
   }
 
   // delinearize Ids
-  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  auto maybeIds = delinearizeId(builder, loc, linearId);
   if (failed(maybeIds))
     return failure();
-  SmallVector<Value> sgIds = *maybeIds;
+  SmallVector<Value> ids = *maybeIds;
 
-  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
-                                  shape);
+  return genCoordinates(builder, loc, ids, layout, subShape, shape);
 }
 
 //===----------------------------------------------------------------------===//
@@ -375,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {
 }
 
 FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
-                                 Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
   SliceAttr attr = flatten();
   auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-  return parent.delinearizeSubgroupId(builder, loc, linearId);
+  return parent.delinearizeId(builder, loc, linearId);
 }
 
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
-                      ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+                                    Value linearId, ArrayRef<int64_t> shape) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
   if (!isForWorkgroup())
     return failure();
 
-  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
-  if (sgShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, sgLayout))
-      sgShape = derivedShape.value();
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  if (isForWorkgroup()) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    layout = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  } else {
+    return failure();
+  }
+
+  if (subShape.empty()) {
+    if (auto derivedShape = computeShapeRatio(shape, layout))
+      subShape = derivedShape.value();
     else
       return failure();
   }
 
   // delinearize Ids
-  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  auto maybeIds = delinearizeId(builder, loc, linearId);
   if (failed(maybeIds))
     return failure();
 
@@ -412,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
   SmallVector<Value> sgIds =
       XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
 
-  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
-                                  shape);
+  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
 }
 
 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..7b6c4b6c2c813 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
 
 LogicalResult
 IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
-                      UnitAttr subgroup_block_io,
+                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                       function_ref<InFlightDiagnostic()> emitError) {
 
   if (!dataTy) {
     if (subgroup_block_io)
       return emitError() << "subgroup_block_io "
-                            "are only allowed when result is a 1D VectorType.";
+                            "are only allowed when result is a VectorType.";
     else
       return success();
   }
@@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
   ArrayRef<int64_t> dataShape = dataTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
 
+  SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+  ArrayAttr strideAttr = mdescTy.getStrideAttr();
+  SmallVector<int64_t> strides;
+  for (Attribute attr : strideAttr.getValue()) {
+    strides.push_back(cast<IntegerAttr>(attr).getInt());
+  }
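+  // With subgroup_block_io, the lanes must access contiguous, coalesced
+  // memory: all but the innermost lane_data dims must be 1, the mem_desc
+  // block shape must match the lane layout, and every distributed dim must
+  // have unit stride.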
+  if (subgroup_block_io && layout) {
+    auto laneData = layout.getEffectiveLaneDataAsInt();
+    auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+    if (!laneData.empty()) {
+      bool isLaneDataContiguous =
+          std::all_of(laneData.begin(), std::prev(laneData.end()),
+                      [](int x) { return x == 1; });
+      if (!isLaneDataContiguous)
+        return emitError() << "With subgroup_block_io, accessed data must be "
+                              "contiguous and coalesced.";
+      for (size_t i = 0; i < laneData.size(); ++i) {
+        if (laneLayout[i] != blockShape[i])
+          return emitError() << "With subgroup_block_io, the block shape must "
+                                "match the lane layout.";
+        if (laneLayout[i] != 1 && strides[i] != 1)
+          return emitError() << "With subgroup_block_io, the distributed "
+                                "dimensions must be contiguous.";
+      }
+    }
+  }
   if (dataShape.size() == 2) {
-    if (subgroup_block_io)
-      return emitError() << "subgroup_block_io "
-                            "are only allowed when result is a 1D VectorType.";
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
       return emitError() << "data shape must not exceed mem_desc shape.";
   } else {
-    SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
     // if the subgroup_block_io attribute is set,  mdescTy must have block
     // attribute
     if (subgroup_block_io && !blockShape.size())
@@ -1105,7 +1127,7 @@ LogicalResult LoadMatrixOp::verify() {
   MemDescType mdescTy = getMemDesc().getType();
 
   return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
-                               [&]() { return emitError(); });
+                               getLayoutAttr(), [&]() { return emitError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1129,7 +1151,7 @@ LogicalResult StoreMatrixOp::verify() {
   UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
   MemDescType mdescTy = getMemDesc().getType();
   return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
-                               [&]() { return emitError(); });
+                               getLayoutAttr(), [&]() { return emitError(); });
 }
 
 namespace mlir {

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 5a3b27ec6108e..bbd7733e89c29 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -912,6 +913,186 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
+    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newCoords;
+  auto maybeCoords =
+      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+  if (failed(maybeCoords))
+    return {};
+  assert(maybeCoords.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  newCoords = llvm::to_vector(llvm::map_range(
+      ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+  return newCoords;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadMatrixOp");
+    const int operandIdx = producedByLastLoad->getOperandNumber();
+
+    VectorType sgPayloadTy =
+        dyn_cast<VectorType>(matrixOp.getResult().getType());
+    VectorType warpResultTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the load op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> operands = {matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
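+    // The adjusted offsets are passed as values, so mark every constant
+    // offset as dynamic.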
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange currentOffsets =
+        ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    SmallVector<Value> newCoords = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
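+    // With subgroup_block_io, the lanes cooperate on one block access and the
+    // subgroup-level offsets are used as-is; otherwise compute the per-lane
+    // coordinates.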
+    if (!matrixOp.getSubgroupBlockIoAttr()) {
+      newCoords = computeDistributedCoordinatesForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
+    }
+    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        newWarpOp.getResult(operandIdx),
+        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    return success();
+  }
+};
+
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the store op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange currentOffsets =
+        ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    SmallVector<Value> newCoords = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
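+    // As in the load pattern: subgroup_block_io keeps the subgroup-level
+    // offsets; otherwise compute the per-lane coordinates.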
+    if (!matrixOp.getSubgroupBlockIoAttr()) {
+      newCoords = computeDistributedCoordinatesForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
+    }
+
+    xegpu::StoreMatrixOp::create(
+        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+        ValueRange(newCoords), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(matrixOp);
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1443,7 +1624,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
                LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
                GpuBarrierDistribution, VectorMultiReductionDistribution,
                LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
+               VectorBitcastDistribution, LoadMatrixDistribution,
+               StoreMatrixDistribution,
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
@@ -1468,6 +1650,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
 
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..79eea55c8b78a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
   // descriptors to be accessed, based on the layout information.
   ArrayRef<int64_t> wgShape = op.getDataShape();
-  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  auto maybeDescOffsets =
+      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
   if (failed(maybeDescOffsets))
     return failure();
 
@@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       // Get subgroup id
       Value sgId =
           gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+      auto sgOffsets =
+          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
       if (failed(sgOffsets))
         return failure();
 
@@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    auto sgOffsets =
+        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
     if (failed(sgOffsets))
       return failure();
 

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce0ec0d0..92f353717ac59 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error at +1 {{Mask should match value except the chunk size dim}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
@@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
   return
 }
 
-// -----
-func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
-  return
-}
-
-
 // -----
 func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
   // expected-error at +1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -900,16 +892,25 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
 }
 
 // -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>, %arg1: vector<2x16xf32>) {
+  // expected-error at +1 {{With subgroup_block_io, accessed data must be contiguous and coalesced}}
+  xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+        vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>
   return
 }
 
 // -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>, %arg1: vector<16x2xf32>) {
+  // expected-error at +1 {{With subgroup_block_io, the distributed dimensions must be contiguous}}
+  xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} :
+        vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>
   return
 }
 
+// -----
+func.func @simt_store_matrix_block_shape_mismatch(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>, %arg1: vector<16x2xf32>) {
+  // expected-error at +1 {{With subgroup_block_io, the block shape must match the lane layout}}
+  xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+        vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>
+  return
+}

diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..8946d14e80b72 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,66 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
+// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
+// CHECK: %[[DIST_UNIT_WIDTH_X:.*]] = arith.constant 4 : index
+// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
+// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
+// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_WIDTH_X]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+      !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+      vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
+    gpu.return
+  }
+}

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 76d461108b296..93d51441f5b81 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+    auto maybeOffsets =
+        sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape);
     if (failed(maybeOffsets))
       return failure();
 


        

