[Mlir-commits] [mlir] [MLIR][XeGPU] Matrix load/store subgroup distribution (PR #165008)

Artem Kroviakov llvmlistbot at llvm.org
Wed Oct 29 02:33:20 PDT 2025


https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/165008

>From 887f9781ea3b62cd990d9df7066f28ec049f603b Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 24 Oct 2025 16:08:00 +0000
Subject: [PATCH 1/6] [MLIR][XeGPU] Matrix load/store subgroup distribution

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 124 ++++++++++++++++--
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  15 +++
 2 files changed, 131 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc196c0bf7..fe059bb86eba2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -906,6 +906,119 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distributes an xegpu.load_matrix / xegpu.store_matrix op that is the last
+/// operation before the terminator of a gpu.warp_execute_on_lane_0 region.
+/// The payload vector type is distributed according to the op's lane layout;
+/// the mem_desc, the offsets (and, for stores, the data) are yielded from the
+/// warp op so that the op can be recreated after it with per-lane types.
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+    int operandIdx{-1};
+
+    VectorType payloadTy;
+    VectorType warpResultTy;
+    if constexpr (isLoad) {
+      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+      });
+      if (!producedByLastLoad)
+        return rewriter.notifyMatchFailure(
+            warpOp, "The last op is not xegpu::LoadMatrixOp");
+      operandIdx = producedByLastLoad->getOperandNumber();
+      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+    } else {
+      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    }
+    if (!payloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the matrix op must have offsets");
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp,
+          "failed to distribute the matrix op payload by its lane layout");
+
+    // Materialize the offsets only after every match check has passed:
+    // getAsValues may create constant ops, and a pattern must not modify the
+    // IR and then return failure.
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    SmallVector<Value> operands;
+    if constexpr (isLoad)
+      operands = {matrixOp.getMemDesc()};
+    else
+      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    if constexpr (!isLoad)
+      operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+    newOperands[operandIdxToModify] = arith::AddIOp::create(
+        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+        newWarpOp.getLaneid());
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    if constexpr (isLoad) {
+      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+          newOperands[0], newOffsets, newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+      // Resolve the output type and replace all uses.
+      rewriter.replaceAllUsesWith(
+          newWarpOp.getResult(operandIdx),
+          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    } else {
+      xegpu::StoreMatrixOp::create(
+          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+          xegpu::DistributeLayoutAttr{});
+      rewriter.eraseOp(matrixOp);
+    }
+    return success();
+  }
+};
+
 /// Distribute a scattered load op. The logic and requirements are the same as
 /// for the scattered store distribution. The warpOp's payload vector is
 /// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
-               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
-               GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution,
-               MemrefExtractAlignedPointerAsIndexDistribution>(
-      patterns.getContext(),
-      /*pattern benefit=*/regularPatternBenefit);
+  patterns
+      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+           VectorMultiReductionDistribution, LoadDistribution,
+           StoreDistribution, VectorTransposeDistribution,
+           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+           MatrixOpDistribution<xegpu::StoreMatrixOp>,
+           MemrefExtractAlignedPointerAsIndexDistribution>(
+          patterns.getContext(),
+          /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
       // Layouts are needed for vector type only.
       if (!isa<VectorType>(operand.get().getType()))
         continue;
+      if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+        continue;
 
       auto layout = xegpu::getDistributeLayoutAttr(operand.get());
       if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..3fcc747217c9d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+    gpu.return 
+  }
+}

>From f80ee32a523ddda05eaf789358ef300efd1208d3 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sat, 25 Oct 2025 10:51:52 +0000
Subject: [PATCH 2/6] Add offset calculation

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h    |   4 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  47 ++++--
 .../mlir/Dialect/XeGPU/Transforms/Passes.td   |   2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 154 ++++++++++--------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  44 +++--
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  10 +-
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  15 +-
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |   3 +-
 8 files changed, 166 insertions(+), 113 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 1481859e94a92..0c059967bb898 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -30,9 +30,11 @@ class SliceAttr;
 } // namespace xegpu
 } // namespace mlir
 
+// clang-format off
+#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+// clang-format on
 
 #define GET_ATTRDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 19a52317956d2..1b515b11658c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -167,6 +167,16 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
   let cppNamespace = "::mlir::xegpu";
 }
 
+def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
+def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
+def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
+def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
+      "The enumeration for the distribution level: workgroup, subgroup, or work item.",
+      [XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
+}
+
 def XeGPU_FenceScopeAttr:
   EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
     let summary = [{Describes the scope of fence.
@@ -223,18 +233,18 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Derive a new layout by dropping InstData",
                     "xegpu::DistributeLayoutAttr",
                     "dropInstData">,
-    InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
-                      indices based on the effective subgroup layout.}],
+    InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
+                      indices based on the effective `level` layout.}],
                     "FailureOr<SmallVector<Value>>",
-                    "delinearizeSubgroupId",
-                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
-    InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
-                      assigned to a subgroup identified by linearId. The shape parameter
-                      represents the workgroup-level problem size. Each subgroup may access
+                    "delinearizeId",
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
+    InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
+                      assigned to a `level` identified by linearId. The shape parameter
+                      represents the higher-level problem size. Each `level` may access
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
-                    "getOffsets",
-                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+                    "computeDistributedCoords",
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
     InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
                      to some other layout according to given permutation of (0...n-1).}],
                     /*retTy=*/"bool",
@@ -476,17 +486,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return {};
     }
 
-    /// Delinearizes a linear subgroup ID into its multidimensional indices
-    /// based on the effective subgroup layout.
+    /// Delinearizes a linear ID into its multidimensional indices
+    /// based on the effective `level` layout.
     FailureOr<SmallVector<Value>>
-    delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+    delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
 
-    /// Generates instructions to compute multidimensional offsets for blocks
-    /// assigned to a subgroup identified by linearId. The shape parameter
-    /// represents the workgroup-level problem size. Each subgroup may access
+    /// Generates instructions to compute multidimensional offsets for dist units
+    /// assigned to a `level` identified by linearId. The shape parameter
+    /// represents the higher-level problem size. Each `level` may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +653,14 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
-    /// Delinearizes a linear subgroup ID into its multidimensional indices
-    /// based on the effective subgroup layout.
+    /// Delinearizes a linear ID into its multidimensional indices
+    /// based on the effective `level` layout.
     FailureOr<SmallVector<Value>>
-    delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+    delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
 
-    /// Generates instructions to compute multidimensional offsets for blocks
-    /// assigned to a subgroup identified by linearId. The shape parameter
-    /// represents the workgroup-level problem size. Each subgroup may access
+    /// Generates instructions to compute multidimensional offsets for dist units
+    /// assigned to a `level` identified by linearId. The shape parameter
+    /// represents the higher-level problem size. Each `level` may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..5f803233041ab 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
     The pass distributes subgroup level (SIMD) XeGPU ops to work items.
   }];
   let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
-                           "vector::VectorDialect"];
+                           "vector::VectorDialect", "index::IndexDialect"];
 }
 
 def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e909548fe0b..cbe459bfcbb48 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -38,47 +38,47 @@ void XeGPUDialect::initialize() {
       >();
 }
 
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
+// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
+// within each distribution unit.
 static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
-                         SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
-                         ArrayRef<int64_t> sizePerSg,
-                         ArrayRef<int64_t> sizePerWg) {
-
+genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
+           ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+           ArrayRef<int64_t> srcShape) {
   SmallVector<SmallVector<Value>> offsets;
 
-  // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
-  SmallVector<Value> localOffsets = llvm::map_to_vector(
-      llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+  // A distribution unit must be less than or equal to `srcShape`
+  SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+      llvm::zip_equal(srcShape,
+                      computeElementwiseMul(subShapesLayout, subShape)),
+      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+  // Get the offset of `subShape` within a distribution unit.
+  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
         return builder.createOrFold<index::MulOp>(
             loc, std::get<0>(t),
             builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
       });
 
-  // distUnit[i] is the minimum value between sizePerWg[i] and
-  // sgLayout[i] * sizePerSg[i]
-  SmallVector<int64_t> distUnit = llvm::map_to_vector(
-      llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
-      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+  // For each dist unit
   for (SmallVector<int64_t> unitOffs :
-       StaticTileOffsetRange(sizePerWg, distUnit)) {
+       StaticTileOffsetRange(srcShape, distUnitShape)) {
+    // Get dist unit offset within `srcShape`.
     SmallVector<Value> base =
         llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
           return arith::ConstantIndexOp::create(builder, loc, d);
         });
-
-    SmallVector<Value> adds = llvm::map_to_vector(
-        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
-          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
-                                                     std::get<1>(t));
-        });
-
+    // Calculate `subShape` offset within `srcShape`.
+    SmallVector<Value> adds =
+        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+                            [&](const auto &t) -> Value {
+                              return builder.createOrFold<arith::AddIOp>(
+                                  loc, std::get<0>(t), std::get<1>(t));
+                            });
+    // Do not go beyond `srcShape` bounds.
     SmallVector<Value> mods = llvm::map_to_vector(
-        llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
           return builder.createOrFold<index::RemUOp>(
               loc, std::get<0>(t),
               arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
@@ -268,12 +268,8 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 }
 
 FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
-                                  Value linearId) {
-  // delinearizeSubgroupId is only available for
-  // workgroup-level layout attribute
-  if (!isForWorkgroup())
-    return failure();
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
+                          xegpu::DistributionLevel idLevel) {
 
   // TODO: handle order attribute
   auto hasDefaultOrder = [&]() {
@@ -283,41 +279,61 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   };
   if (!hasDefaultOrder())
     return mlir::emitError(loc, "order attribute is currently not supported.");
-
-  auto dims =
-      llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
-        return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
-      });
+  SmallVector<int64_t> layout;
+  if (idLevel == xegpu::DistributionLevel::SG) {
+    layout = getEffectiveSgLayoutAsInt();
+  } else if (idLevel == xegpu::DistributionLevel::WI) {
+    layout = getEffectiveLaneLayoutAsInt();
+  } else {
+    return failure();
+  }
+  // An empty layout (presumably: no sg_layout / lane_layout for `idLevel`)
+  // has no dimensions to delinearize into; keep the old failure contract.
+  if (layout.empty())
+    return failure();
+  auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
+    return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+  });
 
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
-                       ArrayRef<int64_t> shape) {
-  if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+                                     Value linearId, ArrayRef<int64_t> shape,
+                                     xegpu::DistributionLevel targetLevel) {
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  if (targetLevel == DistributionLevel::SG) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (targetLevel == DistributionLevel::WI) {
+    layout = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  } else {
     return failure();
-
-  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
-  if (sgShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, sgLayout))
-      sgShape = derivedShape.value();
+  }
+  // Level-agnostic replacement for the removed isForWorkgroup() guard: with
+  // no layout for `targetLevel`, the rank-matched helpers below cannot work.
+  if (layout.empty())
+    return failure();
+  if (subShape.empty()) {
+    if (auto derivedShape = computeShapeRatio(shape, layout))
+      subShape = derivedShape.value();
     else
       return failure();
   }
 
   // delinearize Ids
-  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
   if (failed(maybeIds))
     return failure();
-  SmallVector<Value> sgIds = *maybeIds;
+  SmallVector<Value> ids = *maybeIds;
 
-  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
-                                  shape);
+  return genOffsets(builder, loc, ids, layout, subShape, shape);
 }
 
 //===----------------------------------------------------------------------===//
@@ -371,34 +379,50 @@ SliceAttr SliceAttr::flatten() const {
 }
 
 FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
-                                 Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
+                         xegpu::DistributionLevel level) {
   SliceAttr attr = flatten();
   auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-  return parent.delinearizeSubgroupId(builder, loc, linearId);
+  return parent.delinearizeId(builder, loc, linearId, level);
 }
 
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
-                      ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+                                    Value linearId, ArrayRef<int64_t> shape,
+                                    xegpu::DistributionLevel targetLevel) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
-  if (!isForWorkgroup())
-    return failure();
-
-  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
-  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
-  if (sgShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, sgLayout))
-      sgShape = derivedShape.value();
+  // Only subgroup-level distribution needs a workgroup-level layout; an
+  // unconditional check here would make lane-level (WI) distribution fail.
+  if (targetLevel == DistributionLevel::SG && !isForWorkgroup())
+    return failure();
+
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  if (targetLevel == DistributionLevel::SG) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (targetLevel == DistributionLevel::WI) {
+    layout = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  } else {
+    return failure();
+  }
+  // Nothing to distribute if the layout has no entry for `targetLevel`.
+  if (layout.empty())
+    return failure();
+
+  if (subShape.empty()) {
+    if (auto derivedShape = computeShapeRatio(shape, layout))
+      subShape = derivedShape.value();
     else
       return failure();
   }
 
   // delinearize Ids
-  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
   if (failed(maybeIds))
     return failure();
 
@@ -408,8 +427,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
   SmallVector<Value> sgIds =
       XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
 
-  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
-                                  shape);
+  return genOffsets(builder, loc, sgIds, layout, subShape, shape);
 }
 
 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index fe059bb86eba2..b02290d4b251b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -919,7 +920,7 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
     constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
     int operandIdx{-1};
 
-    VectorType payloadTy;
+    VectorType sgPayloadTy;
     VectorType warpResultTy;
     if constexpr (isLoad) {
       OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
@@ -929,12 +930,12 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
         return rewriter.notifyMatchFailure(
             warpOp, "The last op is not xegpu::LoadMatrixOp");
       operandIdx = producedByLastLoad->getOperandNumber();
-      payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
       warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
     } else {
-      payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
     }
-    if (!payloadTy)
+    if (!sgPayloadTy)
       return rewriter.notifyMatchFailure(
           matrixOp, "the matrix op payload must be a vector type");
 
@@ -952,7 +953,7 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
           matrixOp, "the matrix operation lacks layout attribute");
 
     FailureOr<VectorType> distPayloadByWarpOpOrFailure =
-        getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
     if (failed(distPayloadByWarpOpOrFailure))
       return rewriter.notifyMatchFailure(
           matrixOp,
@@ -977,23 +978,44 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
     SmallVector<Value> newOperands = llvm::map_to_vector(
         newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
 
-    rewriter.setInsertionPointAfter(newWarpOp);
-    unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
-    newOperands[operandIdxToModify] = arith::AddIOp::create(
-        rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
-        newWarpOp.getLaneid());
-
     SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
     std::fill(newConstOffsets.begin(), newConstOffsets.end(),
               ShapedType::kDynamic);
     DenseI64ArrayAttr newConstOffsetsAttr =
         rewriter.getDenseI64ArrayAttr(newConstOffsets);
-    ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+    // Offsets as yielded by the new warp op, i.e. usable after it.
+    ValueRange currentOffsets =
+        ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    rewriter.setInsertionPointAfter(newWarpOp);
+    SmallVector<Value> newOffsets = currentOffsets;
+    if (!matrixOp.getSubgroupBlockIoAttr()) {
+      auto maybeDescOffsets = layout.computeDistributedCoords(
+          rewriter, loc, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          xegpu::DistributionLevel::WI);
+      // NOTE(review): newWarpOp has already been created at this point, so
+      // this failure leaves modified IR behind; consider validating earlier.
+      if (failed(maybeDescOffsets))
+        return failure();
+      assert(maybeDescOffsets.value().size() == 1 &&
+             "Expected same number of offset sets as number of accessed "
+             "sub-tensors or sub-memory descriptors.");
+      // Add this lane's coordinates to the offsets yielded by the warp op.
+      // `offsets` (the original op's operands) must not be used here: they may
+      // be defined inside the warp region and would not dominate this point.
+      // Values are wrapped directly (getAsOpFoldResult may fold constants into
+      // attributes, which the `cast<Value>` below would reject).
+      SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+          rewriter, loc,
+          llvm::to_vector_of<OpFoldResult>(maybeDescOffsets.value()[0]),
+          llvm::to_vector_of<OpFoldResult>(currentOffsets));
+      newOffsets = llvm::map_to_vector(
+          ofrVec, [](OpFoldResult ofr) -> Value { return cast<Value>(ofr); });
+    }
 
     if constexpr (isLoad) {
       xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
           rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
-          newOperands[0], newOffsets, newConstOffsetsAttr,
+          newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
           matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
       // Resolve the output type and replace all uses.
       rewriter.replaceAllUsesWith(
@@ -1002,8 +1016,8 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
     } else {
       xegpu::StoreMatrixOp::create(
           rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
-          newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
-          xegpu::DistributeLayoutAttr{});
+          ValueRange(newOffsets), newConstOffsetsAttr,
+          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
       rewriter.eraseOp(matrixOp);
     }
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..93e23cea9c7dd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
   // descriptors to be accessed, based on the layout information.
   ArrayRef<int64_t> wgShape = op.getDataShape();
-  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  auto maybeDescOffsets = layout.computeDistributedCoords(
+      rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
   if (failed(maybeDescOffsets))
     return failure();
 
@@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       // Get subgroup id
       Value sgId =
           gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-      auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+      auto sgOffsets = layout.computeDistributedCoords(
+          rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
       if (failed(sgOffsets))
         return failure();
 
@@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    auto sgOffsets = layout.computeDistributedCoords(
+        rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
     if (failed(sgOffsets))
       return failure();
 
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3fcc747217c9d..b69c661f8cfd5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -268,15 +268,20 @@ gpu.module @xevm_module{
 
 // -----
 // CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
-// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
-// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
+// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
 gpu.module @xevm_module{
   gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
     %c0 = arith.constant 0 : index
     %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
-
     xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
-
-    gpu.return 
+    gpu.return
   }
 }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 76d461108b296..4408e827a97fc 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+    auto maybeOffsets = sliceAttr.computeDistributedCoords(
+        rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
     if (failed(maybeOffsets))
       return failure();
 

>From b4f5a4d325a3069ad658362c701fe6c8ad9b8a81 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 27 Oct 2025 16:56:11 +0000
Subject: [PATCH 3/6] Relax `subgroup_block_io` dimensionality restriction

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    |  2 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  5 +--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 25 +----------
 .../Dialect/XeGPU/subgroup-distribute.mlir    | 43 ++++++++++++++++++-
 4 files changed, 45 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..53b8c4f0bbd59 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
     if (!valOrResVecTy)
       valOrResVecTy = VectorType::get(1, data.getType());
+    if (valOrResVecTy.getShape().size() != 1)
+      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..68f49d648e738 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -181,7 +181,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
   if (!dataTy) {
     if (subgroup_block_io)
       return emitError() << "subgroup_block_io "
-                            "are only allowed when result is a 1D VectorType.";
+                            "are only allowed when result is a VectorType.";
     else
       return success();
   }
@@ -193,9 +193,6 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
 
   if (dataShape.size() == 2) {
-    if (subgroup_block_io)
-      return emitError() << "subgroup_block_io "
-                            "are only allowed when result is a 1D VectorType.";
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
       return emitError() << "data shape must not exceed mem_desc shape.";
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce0ec0d0..0b0ef27e39233 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
   // expected-error at +1 {{Mask should match value except the chunk size dim}}
-  xegpu.store %val, %src[%offsets], %mask 
+  xegpu.store %val, %src[%offsets], %mask
         : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
@@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
   return
 }
 
-// -----
-func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
-  return
-}
-
-
 // -----
 func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
   // expected-error at +1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -898,18 +890,3 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
   xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
   return
 }
-
-// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
-  return
-}
-
-// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
-  return
-}
-
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index b69c661f8cfd5..fe129428dc189 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -271,8 +271,8 @@ gpu.module @xevm_module{
 // CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
 // CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
 // CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%0]
-// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
 // CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
 // CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
 // CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
@@ -285,3 +285,42 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
+// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
+// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
+// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<2x1xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index
+gpu.module @xevm_module{
+  gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+      !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+      vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+    gpu.return
+  }
+}

>From 3c4a5aa8e0a7bf66da85f22552e86615bdc8d1d9 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 28 Oct 2025 16:11:34 +0000
Subject: [PATCH 4/6] Address feedback

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |   8 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  21 +-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 203 ++++++++++++------
 .../Transforms/XeGPUWgToSgDistribute.cpp      |   6 +-
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  14 +-
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |   2 +-
 6 files changed, 163 insertions(+), 91 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1b515b11658c0..794a84c839548 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -171,7 +171,7 @@ def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
 def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
 def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
 def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
-      "The enumeration for the scope of fence operation.",
+      "Specify target level for offsets distribution utility.",
       [XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
@@ -243,7 +243,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       represents the higher-level problem size. Each `level` may access
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
-                    "computeDistributedCoords",
+                    "computeDistributedOffsets",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
     InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
                      to some other layout according to given permutation of (0...n-1).}],
@@ -496,7 +496,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     /// represents the higher-level problem size. Each `level` may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -661,7 +661,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// multiple blocks according to round-robin distribution rules.
 
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index cbe459bfcbb48..e335efefb608f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -41,6 +41,11 @@ void XeGPUDialect::initialize() {
 // A `srcShape` consists of N distribution units, each being `subShapesLayout` x
 // `subShape`. A `delinearizedId` is used to identify a particular `subShape`
 // within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 with a 4x2 SG layout. This gives
+// a distribution unit of shape 64x64, so there are 2x4 such units.
+// `delinearizedId` identifies the 16x32 block owned by a subgroup within
+// each distribution unit.
 static SmallVector<SmallVector<Value>>
 genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
            ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
@@ -294,13 +299,13 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+/// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
-                                     Value linearId, ArrayRef<int64_t> shape,
-                                     xegpu::DistributionLevel targetLevel) {
+LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
+                                      Value linearId, ArrayRef<int64_t> shape,
+                                      xegpu::DistributionLevel targetLevel) {
   SmallVector<int64_t> layout;
   SmallVector<int64_t> subShape;
   if (targetLevel == DistributionLevel::SG) {
@@ -386,13 +391,13 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
   return parent.delinearizeId(builder, loc, linearId, level);
 }
 
-// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
 // instructions for computing multi-dimensional offsets when distributed by
 // LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
-                                    Value linearId, ArrayRef<int64_t> shape,
-                                    xegpu::DistributionLevel targetLevel) {
+SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
+                                     Value linearId, ArrayRef<int64_t> shape,
+                                     xegpu::DistributionLevel targetLevel) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
   if (!isForWorkgroup())
     return failure();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b02290d4b251b..c576172683f68 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -907,34 +907,48 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-template <class MatrixOp>
-struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
+    PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+    Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newOffsets;
+  ;
+  auto maybeDescOffsets = layout.computeDistributedOffsets(
+      rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
+  if (failed(maybeDescOffsets))
+    return {};
+  assert(maybeDescOffsets.value().size() == 1 &&
+         "Expected one set of distributed offsets");
+  SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+      rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+      getAsOpFoldResult(origOffsets));
+  newOffsets = llvm::to_vector(llvm::map_range(
+      ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+  return newOffsets;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
     gpu::YieldOp yield = warpOp.getTerminator();
     Operation *lastNode = yield->getPrevNode();
-    auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+    auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
     if (!matrixOp)
       return failure();
-    constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
-    int operandIdx{-1};
-
-    VectorType sgPayloadTy;
-    VectorType warpResultTy;
-    if constexpr (isLoad) {
-      OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
-        return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
-      });
-      if (!producedByLastLoad)
-        return rewriter.notifyMatchFailure(
-            warpOp, "The last op is not xegpu::LoadMatrixOp");
-      operandIdx = producedByLastLoad->getOperandNumber();
-      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
-      warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
-    } else {
-      sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
-    }
+
+    OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+      return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+    });
+    if (!producedByLastLoad)
+      return rewriter.notifyMatchFailure(
+          warpOp, "The last op is not xegpu::LoadMatrixOp");
+    const int operandIdx = producedByLastLoad->getOperandNumber();
+
+    VectorType sgPayloadTy =
+        dyn_cast<VectorType>(matrixOp.getResult().getType());
+    VectorType warpResultTy =
+        cast<VectorType>(warpOp.getResult(operandIdx).getType());
     if (!sgPayloadTy)
       return rewriter.notifyMatchFailure(
           matrixOp, "the matrix op payload must be a vector type");
@@ -956,21 +970,14 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
     if (failed(distPayloadByWarpOpOrFailure))
       return rewriter.notifyMatchFailure(
-          matrixOp,
-          "The matrix op payload has no layouts, using defaults instead.");
-
-    SmallVector<Value> operands;
-    if constexpr (isLoad)
-      operands = {matrixOp.getMemDesc()};
-    else
-      operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+          matrixOp, "The matrix op payload has no layout.");
+
+    SmallVector<Value> operands = {matrixOp.getMemDesc()};
     const unsigned offsetsStartIdx = operands.size();
     operands.append(offsetsAsValues);
 
     SmallVector<Type> operandTypes = llvm::to_vector(
         llvm::map_range(operands, [](Value v) { return v.getType(); }));
-    if constexpr (!isLoad)
-      operandTypes[0] = *distPayloadByWarpOpOrFailure;
 
     SmallVector<size_t> newRetIndices;
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -986,40 +993,97 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
     ValueRange currentOffsets =
         ValueRange(newOperands).drop_front(offsetsStartIdx);
 
-    rewriter.setInsertionPointAfter(newWarpOp);
     SmallVector<Value> newOffsets = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
     if (!matrixOp.getSubgroupBlockIoAttr()) {
-      auto maybeDescOffsets = layout.computeDistributedCoords(
-          rewriter, loc, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
-          xegpu::DistributionLevel::WI);
-      if (failed(maybeDescOffsets))
-        return failure();
-      assert(maybeDescOffsets.value().size() == 1 &&
-             "Expected same number of offset sets as number of accessed "
-             "sub-tensors or sub-memory descriptors.");
-      SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
-          rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
-          offsets);
-      newOffsets = llvm::to_vector(llvm::map_range(
-          ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+      newOffsets = computeDistributedOffsetsForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
     }
+    xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+        rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+        newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    // Resolve the output type and replace all uses.
+    rewriter.replaceAllUsesWith(
+        newWarpOp.getResult(operandIdx),
+        resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+    return success();
+  }
+};
 
-    if constexpr (isLoad) {
-      xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
-          rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
-          newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
-          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-      // Resolve the output type and replace all uses.
-      rewriter.replaceAllUsesWith(
-          newWarpOp.getResult(operandIdx),
-          resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
-    } else {
-      xegpu::StoreMatrixOp::create(
-          rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
-          ValueRange(newOffsets), newConstOffsetsAttr,
-          matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
-      rewriter.eraseOp(matrixOp);
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    gpu::YieldOp yield = warpOp.getTerminator();
+    Operation *lastNode = yield->getPrevNode();
+    auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+    if (!matrixOp)
+      return failure();
+
+    VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+    if (!sgPayloadTy)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix op payload must be a vector type");
+
+    auto loc = matrixOp.getLoc();
+    auto offsets = matrixOp.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(matrixOp,
+                                         "the store op must have offsets");
+    SmallVector<Value> offsetsAsValues =
+        vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+    auto layout = matrixOp.getLayoutAttr();
+    if (!layout)
+      return rewriter.notifyMatchFailure(
+          matrixOp, "the matrix operation lacks layout attribute");
+
+    FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+    if (failed(distPayloadByWarpOpOrFailure))
+      return rewriter.notifyMatchFailure(
+          matrixOp, "The matrix op payload has no layout.");
+
+    SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+    const unsigned offsetsStartIdx = operands.size();
+    operands.append(offsetsAsValues);
+
+    SmallVector<Type> operandTypes = llvm::to_vector(
+        llvm::map_range(operands, [](Value v) { return v.getType(); }));
+    operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+    SmallVector<size_t> newRetIndices;
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+    std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+              ShapedType::kDynamic);
+    DenseI64ArrayAttr newConstOffsetsAttr =
+        rewriter.getDenseI64ArrayAttr(newConstOffsets);
+    ValueRange currentOffsets =
+        ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+    SmallVector<Value> newOffsets = currentOffsets;
+    rewriter.setInsertionPointAfter(newWarpOp);
+
+    if (!matrixOp.getSubgroupBlockIoAttr()) {
+      newOffsets = computeDistributedOffsetsForMatrixOp(
+          rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+          currentOffsets);
     }
+
+    xegpu::StoreMatrixOp::create(
+        rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+        ValueRange(newOffsets), newConstOffsetsAttr,
+        matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+    rewriter.eraseOp(matrixOp);
     return success();
   }
 };
@@ -1551,16 +1615,15 @@ struct XeGPUSubgroupDistributePass final
 
 void xegpu::populateXeGPUSubgroupDistributePatterns(
     RewritePatternSet &patterns) {
-  patterns
-      .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
-           DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
-           VectorMultiReductionDistribution, LoadDistribution,
-           StoreDistribution, VectorTransposeDistribution,
-           VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
-           MatrixOpDistribution<xegpu::StoreMatrixOp>,
-           MemrefExtractAlignedPointerAsIndexDistribution>(
-          patterns.getContext(),
-          /*pattern benefit=*/regularPatternBenefit);
+  patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+               LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+               GpuBarrierDistribution, VectorMultiReductionDistribution,
+               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
+               VectorBitcastDistribution, LoadMatrixDistribution,
+               StoreMatrixDistribution,
+               MemrefExtractAlignedPointerAsIndexDistribution>(
+      patterns.getContext(),
+      /*pattern benefit=*/regularPatternBenefit);
   patterns.add<VectorShapeCastDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/highPatternBenefit);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 93e23cea9c7dd..35072f0529072 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
   // descriptors to be accessed, based on the layout information.
   ArrayRef<int64_t> wgShape = op.getDataShape();
-  auto maybeDescOffsets = layout.computeDistributedCoords(
+  auto maybeDescOffsets = layout.computeDistributedOffsets(
       rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
   if (failed(maybeDescOffsets))
     return failure();
@@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       // Get subgroup id
       Value sgId =
           gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-      auto sgOffsets = layout.computeDistributedCoords(
+      auto sgOffsets = layout.computeDistributedOffsets(
           rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
       if (failed(sgOffsets))
         return failure();
@@ -1053,7 +1053,7 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto sgOffsets = layout.computeDistributedCoords(
+    auto sgOffsets = layout.computeDistributedOffsets(
         rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
     if (failed(sgOffsets))
       return failure();
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index fe129428dc189..da4151024edb5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -291,19 +291,22 @@ gpu.module @xevm_module{
 // CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
 // CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
 // CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index
 // CHECK: %[[LANE_ID:.*]] = gpu.lane_id
 // CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
 // CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
 // CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
 // CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
-// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]]
 // CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
 // CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
 gpu.module @xevm_module{
   gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
     %c0 = arith.constant 0 : index
-    %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
-    xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    %c1 = arith.constant 1 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+    xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
     gpu.return
   }
 }
@@ -317,9 +320,10 @@ gpu.module @xevm_module{
 gpu.module @xevm_module{
   gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
     %c0 = arith.constant 0 : index
-    %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+    %c1 = arith.constant 1 : index
+    %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
       !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
-    xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+    xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
       vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
     gpu.return
   }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 4408e827a97fc..61ebdce5d7995 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,7 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto maybeOffsets = sliceAttr.computeDistributedCoords(
+    auto maybeOffsets = sliceAttr.computeDistributedOffsets(
         rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
     if (failed(maybeOffsets))
       return failure();

>From 5965b543738799f504a171624384b4c388fc0deb Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 28 Oct 2025 17:30:33 +0000
Subject: [PATCH 5/6] Remove DistributionLevel enum

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 32 +++++++------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 30 ++++++++---------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  5 ++-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 12 +++----
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  4 +--
 5 files changed, 34 insertions(+), 49 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 794a84c839548..699a7c7e0cf98 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -167,16 +167,6 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
   let cppNamespace = "::mlir::xegpu";
 }
 
-def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
-def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
-def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
-def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
-      "Specify target level for offsets distribution utility.",
-      [XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::xegpu";
-}
-
 def XeGPU_FenceScopeAttr:
   EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
     let summary = [{Describes the scope of fence.
@@ -234,17 +224,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "xegpu::DistributeLayoutAttr",
                     "dropInstData">,
     InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
-                      indices based on the effective `level` layout.}],
+                      indices based on the effective layout level.}],
                     "FailureOr<SmallVector<Value>>",
                     "delinearizeId",
-                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
     InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
-                      assigned to a `level` identified by linearId. The shape parameter
-                      represents the higher-level problem size. Each `level` may access
+                      assigned to a level identified by linearId. The shape parameter
+                      represents the higher-level problem size. Each level may access
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "computeDistributedOffsets",
-                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
     InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
                      to some other layout according to given permutation of (0...n-1).}],
                     /*retTy=*/"bool",
@@ -487,16 +477,16 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     }
 
     /// Delinearizes a linear ID into its multidimensional indices
-    /// based on the effective `level` layout.
+    /// based on the effective level of the layout.
     FailureOr<SmallVector<Value>>
-    delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
+    delinearizeId(OpBuilder &builder, Location loc, Value linearId);
 
     /// Generates instructions to compute multidimensional offsets for dist units
-    /// assigned to a `level` identified by linearId. The shape parameter
+    /// assigned to a level identified by linearId. The shape parameter
     /// represents the higher-level problem size. Each `level` may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -653,7 +643,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// Delinearizes a linear subgroup ID into its multidimensional indices
     /// based on the effective subgroup layout.
     FailureOr<SmallVector<Value>>
-    delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
+    delinearizeId(OpBuilder &builder, Location loc, Value linearId);
 
     /// Generates instructions to compute multidimensional offsets for blocks
     /// assigned to a subgroup identified by linearId. The shape parameter
@@ -661,7 +651,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// multiple blocks according to round-robin distribution rules.
 
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index e335efefb608f..d162d36bef504 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -273,8 +273,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
 }
 
 FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
-                          xegpu::DistributionLevel idLevel) {
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
 
   // TODO: handle order attribute
   auto hasDefaultOrder = [&]() {
@@ -285,9 +284,9 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
   if (!hasDefaultOrder())
     return mlir::emitError(loc, "order attribute is currently not supported.");
   SmallVector<int64_t> layout;
-  if (idLevel == xegpu::DistributionLevel::SG) {
+  if (isForWorkgroup()) {
     layout = getEffectiveSgLayoutAsInt();
-  } else if (idLevel == xegpu::DistributionLevel::WI) {
+  } else if (isForSubgroup()) {
     layout = getEffectiveLaneLayoutAsInt();
   } else {
     return failure();
@@ -304,14 +303,13 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
-                                      Value linearId, ArrayRef<int64_t> shape,
-                                      xegpu::DistributionLevel targetLevel) {
+                                      Value linearId, ArrayRef<int64_t> shape) {
   SmallVector<int64_t> layout;
   SmallVector<int64_t> subShape;
-  if (targetLevel == DistributionLevel::SG) {
+  if (isForWorkgroup()) {
     layout = getEffectiveSgLayoutAsInt();
     subShape = getEffectiveSgDataAsInt();
-  } else if (targetLevel == DistributionLevel::WI) {
+  } else if (isForSubgroup()) {
     layout = getEffectiveLaneLayoutAsInt();
     subShape = getEffectiveLaneDataAsInt();
   } else {
@@ -325,7 +323,7 @@ LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
   }
 
   // delinearize Ids
-  auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
+  auto maybeIds = delinearizeId(builder, loc, linearId);
   if (failed(maybeIds))
     return failure();
   SmallVector<Value> ids = *maybeIds;
@@ -384,11 +382,10 @@ SliceAttr SliceAttr::flatten() const {
 }
 
 FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
-                         xegpu::DistributionLevel level) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
   SliceAttr attr = flatten();
   auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-  return parent.delinearizeId(builder, loc, linearId, level);
+  return parent.delinearizeId(builder, loc, linearId);
 }
 
 // Implements DistributeLayoutAttr::computeDistributedOffsets to generate
@@ -396,18 +393,17 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
 // LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
-                                     Value linearId, ArrayRef<int64_t> shape,
-                                     xegpu::DistributionLevel targetLevel) {
+                                     Value linearId, ArrayRef<int64_t> shape) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
   if (!isForWorkgroup())
     return failure();
 
   SmallVector<int64_t> layout;
   SmallVector<int64_t> subShape;
-  if (targetLevel == DistributionLevel::SG) {
+  if (isForWorkgroup()) {
     layout = getEffectiveSgLayoutAsInt();
     subShape = getEffectiveSgDataAsInt();
-  } else if (targetLevel == DistributionLevel::WI) {
+  } else if (isForSubgroup()) {
     layout = getEffectiveLaneLayoutAsInt();
     subShape = getEffectiveLaneDataAsInt();
   } else {
@@ -422,7 +418,7 @@ SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
   }
 
   // delinearize Ids
-  auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
+  auto maybeIds = delinearizeId(builder, loc, linearId);
   if (failed(maybeIds))
     return failure();
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c576172683f68..b9ec19f15b65c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -911,9 +911,8 @@ static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
     PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
     Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
   SmallVector<Value> newOffsets;
-  ;
-  auto maybeDescOffsets = layout.computeDistributedOffsets(
-      rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
+  auto maybeDescOffsets =
+      layout.computeDistributedOffsets(rewriter, loc, laneId, payloadShape);
   if (failed(maybeDescOffsets))
     return {};
   assert(maybeDescOffsets.value().size() == 1 &&
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 35072f0529072..5f8627bc75d4d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,8 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
   // descriptors to be accessed, based on the layout information.
   ArrayRef<int64_t> wgShape = op.getDataShape();
-  auto maybeDescOffsets = layout.computeDistributedOffsets(
-      rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+  auto maybeDescOffsets =
+      layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
   if (failed(maybeDescOffsets))
     return failure();
 
@@ -831,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       // Get subgroup id
       Value sgId =
           gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-      auto sgOffsets = layout.computeDistributedOffsets(
-          rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+      auto sgOffsets =
+          layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
       if (failed(sgOffsets))
         return failure();
 
@@ -1053,8 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto sgOffsets = layout.computeDistributedOffsets(
-        rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+    auto sgOffsets =
+        layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
     if (failed(sgOffsets))
       return failure();
 
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 61ebdce5d7995..ba5591a996eec 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,8 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
 
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-    auto maybeOffsets = sliceAttr.computeDistributedOffsets(
-        rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+    auto maybeOffsets =
+        sliceAttr.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
     if (failed(maybeOffsets))
       return failure();
 

>From 246761e4f3eac6a68e3307dc7c1503c185ca4aa0 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Wed, 29 Oct 2025 09:33:02 +0000
Subject: [PATCH 6/6] Improve verification

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 12 +++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 29 ++++++++--------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 21 +++++++++---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 34 +++++++++----------
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  6 ++--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 16 +++++++++
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  4 +--
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  2 +-
 8 files changed, 77 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 699a7c7e0cf98..7a9784703f0ec 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -228,12 +228,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "FailureOr<SmallVector<Value>>",
                     "delinearizeId",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
-    InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
+    InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
                       assigned to a level identified by linearId. The shape parameter
                       represents the higher-level problem size. Each level may access
                       multiple blocks according to round-robin distribution rules.}],
                     "FailureOr<SmallVector<SmallVector<Value>>>",
-                    "computeDistributedOffsets",
+                    "computeDistributedCoords",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
     InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
                      to some other layout according to given permutation of (0...n-1).}],
@@ -481,12 +481,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<Value>>
     delinearizeId(OpBuilder &builder, Location loc, Value linearId);
 
-    /// Generates instructions to compute multidimensional offsets for dist units
+    /// Generates instructions to compute multidimensional coordinates for dist units
     /// assigned to a level identified by linearId. The shape parameter
     /// represents the higher-level problem size. Each `level` may access
     /// multiple blocks according to round-robin distribution rules.
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -645,13 +645,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<Value>>
     delinearizeId(OpBuilder &builder, Location loc, Value linearId);
 
-    /// Generates instructions to compute multidimensional offsets for blocks
+    /// Generates instructions to compute multidimensional coordinates for blocks
     /// assigned to a subgroup identified by linearId. The shape parameter
     /// represents the workgroup-level problem size. Each subgroup may access
     /// multiple blocks according to round-robin distribution rules.
 
     FailureOr<SmallVector<SmallVector<Value>>>
-    computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
+    computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index d162d36bef504..5a9b15e73002d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -47,10 +47,11 @@ void XeGPUDialect::initialize() {
 // `delinearizedId` is used to identify a 16x32 of a subgroup in each
 // distribution unit.
 static SmallVector<SmallVector<Value>>
-genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
-           ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
-           ArrayRef<int64_t> srcShape) {
-  SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+               SmallVector<Value> delinearizedId,
+               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+               ArrayRef<int64_t> srcShape) {
+  SmallVector<SmallVector<Value>> coordinates;
 
   // A distribution unit must be less than or equal to `srcShape`
   SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
@@ -89,9 +90,9 @@ genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
               arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
         });
 
-    offsets.push_back(mods);
+    coordinates.push_back(mods);
   }
-  return offsets;
+  return coordinates;
 }
 
 // Checks if the given shape can be evenly distributed based on the layout
@@ -298,12 +299,12 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
-                                      Value linearId, ArrayRef<int64_t> shape) {
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+                                     Value linearId, ArrayRef<int64_t> shape) {
   SmallVector<int64_t> layout;
   SmallVector<int64_t> subShape;
   if (isForWorkgroup()) {
@@ -328,7 +329,7 @@ LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
     return failure();
   SmallVector<Value> ids = *maybeIds;
 
-  return genOffsets(builder, loc, ids, layout, subShape, shape);
+  return genCoordinates(builder, loc, ids, layout, subShape, shape);
 }
 
 //===----------------------------------------------------------------------===//
@@ -388,12 +389,12 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
   return parent.delinearizeId(builder, loc, linearId);
 }
 
-// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
 // instructions for computing multi-dimensional offsets when distributed by
 // LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
-                                     Value linearId, ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+                                    Value linearId, ArrayRef<int64_t> shape) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
   if (!isForWorkgroup())
     return failure();
@@ -428,7 +429,7 @@ SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
   SmallVector<Value> sgIds =
       XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
 
-  return genOffsets(builder, loc, sgIds, layout, subShape, shape);
+  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
 }
 
 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 68f49d648e738..5f1de23265b2b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,7 +175,7 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
 
 LogicalResult
 IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
-                      UnitAttr subgroup_block_io,
+                      UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
                       function_ref<InFlightDiagnostic()> emitError) {
 
   if (!dataTy) {
@@ -191,7 +191,20 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
 
   ArrayRef<int64_t> dataShape = dataTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
-
+  if (subgroup_block_io && layout) {
+    auto laneData = layout.getEffectiveLaneDataAsInt();
+    if (!laneData.empty()) {
+      bool isLaneDataLinear =
+          std::all_of(laneData.begin(), std::prev(laneData.end()),
+                      [](int x) { return x == 1; });
+      if (!isLaneDataLinear)
+        return emitError()
+               << "With subgroup_block_io, lane data must be linear.";
+      if (isLaneDataLinear && laneData.back() != 1)
+        return emitError()
+               << "With subgroup_block_io, lane data must be coalesced.";
+    }
+  }
   if (dataShape.size() == 2) {
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -1102,7 +1115,7 @@ LogicalResult LoadMatrixOp::verify() {
   MemDescType mdescTy = getMemDesc().getType();
 
   return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
-                               [&]() { return emitError(); });
+                               getLayoutAttr(), [&]() { return emitError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1126,7 +1139,7 @@ LogicalResult StoreMatrixOp::verify() {
   UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
   MemDescType mdescTy = getMemDesc().getType();
   return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
-                               [&]() { return emitError(); });
+                               getLayoutAttr(), [&]() { return emitError(); });
 }
 
 namespace mlir {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b9ec19f15b65c..29ccc0a48786b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -907,22 +907,22 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
-static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
+static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
     PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
     Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
-  SmallVector<Value> newOffsets;
-  auto maybeDescOffsets =
-      layout.computeDistributedOffsets(rewriter, loc, laneId, payloadShape);
-  if (failed(maybeDescOffsets))
+  SmallVector<Value> newCoords;
+  auto maybeCoords =
+      layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+  if (failed(maybeCoords))
     return {};
-  assert(maybeDescOffsets.value().size() == 1 &&
+  assert(maybeCoords.value().size() == 1 &&
          "Expected one set of distributed offsets");
   SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
-      rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+      rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
       getAsOpFoldResult(origOffsets));
-  newOffsets = llvm::to_vector(llvm::map_range(
+  newCoords = llvm::to_vector(llvm::map_range(
       ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
-  return newOffsets;
+  return newCoords;
 }
 
 /// Pattern for distributing xegpu::LoadMatrixOp.
@@ -969,7 +969,7 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
     if (failed(distPayloadByWarpOpOrFailure))
       return rewriter.notifyMatchFailure(
-          matrixOp, "The matrix op payload has no layout.");
+          matrixOp, "Failed to distribute matrix op payload based on layout.");
 
     SmallVector<Value> operands = {matrixOp.getMemDesc()};
     const unsigned offsetsStartIdx = operands.size();
@@ -992,17 +992,17 @@ struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
     ValueRange currentOffsets =
         ValueRange(newOperands).drop_front(offsetsStartIdx);
 
-    SmallVector<Value> newOffsets = currentOffsets;
+    SmallVector<Value> newCoords = currentOffsets;
     rewriter.setInsertionPointAfter(newWarpOp);
 
     if (!matrixOp.getSubgroupBlockIoAttr()) {
-      newOffsets = computeDistributedOffsetsForMatrixOp(
+      newCoords = computeDistributedCoordinatesForMatrixOp(
           rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
           currentOffsets);
     }
     xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
         rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
-        newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
+        newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
         matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
     // Resolve the output type and replace all uses.
     rewriter.replaceAllUsesWith(
@@ -1045,7 +1045,7 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
     if (failed(distPayloadByWarpOpOrFailure))
       return rewriter.notifyMatchFailure(
-          matrixOp, "The matrix op payload has no layout.");
+          matrixOp, "Failed to distribute matrix op payload based on layout.");
 
     SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
     const unsigned offsetsStartIdx = operands.size();
@@ -1069,18 +1069,18 @@ struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
     ValueRange currentOffsets =
         ValueRange(newOperands).drop_front(offsetsStartIdx);
 
-    SmallVector<Value> newOffsets = currentOffsets;
+    SmallVector<Value> newCoords = currentOffsets;
     rewriter.setInsertionPointAfter(newWarpOp);
 
     if (!matrixOp.getSubgroupBlockIoAttr()) {
-      newOffsets = computeDistributedOffsetsForMatrixOp(
+      newCoords = computeDistributedCoordinatesForMatrixOp(
           rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
           currentOffsets);
     }
 
     xegpu::StoreMatrixOp::create(
         rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
-        ValueRange(newOffsets), newConstOffsetsAttr,
+        ValueRange(newCoords), newConstOffsetsAttr,
         matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
     rewriter.eraseOp(matrixOp);
     return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 5f8627bc75d4d..79eea55c8b78a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -115,7 +115,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
   // descriptors to be accessed, based on the layout information.
   ArrayRef<int64_t> wgShape = op.getDataShape();
   auto maybeDescOffsets =
-      layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
+      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
   if (failed(maybeDescOffsets))
     return failure();
 
@@ -832,7 +832,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
       Value sgId =
           gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
       auto sgOffsets =
-          layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
+          layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
       if (failed(sgOffsets))
         return failure();
 
@@ -1054,7 +1054,7 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
     auto sgOffsets =
-        layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
+        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
     if (failed(sgOffsets))
       return failure();
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0b0ef27e39233..f5a271811f1e3 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -890,3 +890,19 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
   xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
   return
 }
+
+// -----
+func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, %arg1: vector<2x16xf32>) {
+  // expected-error@+1 {{With subgroup_block_io, lane data must be linear}}
+  xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+        vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>
+  return
+}
+
+// -----
+func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>, %arg1: vector<16x2xf32>) {
+  // expected-error@+1 {{With subgroup_block_io, lane data must be coalesced}}
+  xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} :
+        vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32]>>
+  return
+}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index da4151024edb5..30308de3f22e3 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -321,9 +321,9 @@ gpu.module @xevm_module{
   gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
     %c0 = arith.constant 0 : index
     %c1 = arith.constant 1 : index
-    %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+    %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
       !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
-    xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+    xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
       vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
     gpu.return
   }
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index ba5591a996eec..93d51441f5b81 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -201,7 +201,7 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
     Value sgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
     auto maybeOffsets =
-        sliceAttr.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
+        sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape);
     if (failed(maybeOffsets))
       return failure();
 



More information about the Mlir-commits mailing list