[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)

Chao Chen llvmlistbot at llvm.org
Wed Aug 20 10:49:44 PDT 2025


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/154403

>From fd09d122269c0f53a6340e9cafc32fff95f711eb Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 19 Aug 2025 17:28:28 +0000
Subject: [PATCH 1/8] refactor

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h    |   2 +
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  45 ++++-
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   8 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  13 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |   4 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 157 ++++++++++++++++--
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  68 ++++++++
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |   4 +-
 8 files changed, 268 insertions(+), 33 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 3592da4c46364..ce33da9632c2b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class DistributLayoutAttrInterface;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index a94987885c9e0..adfd8bae75a5a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,22 +175,31 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def LayoutTrait: AttrInterface<"LayoutTrait"> {
+def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface"> {
   let cppNamespace = "::mlir::xegpu";
   let description = [{
     Common trait for all XeGPU layouts.
   }];
 
   let methods = [
+    InterfaceMethod<"Check whether the attribute specifies a workgroup-level layout",
+                    "bool",
+                    "isWgLayout">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
+    InterfaceMethod<"Get the number of effective subgroups",
+                    "int64_t",
+                    "getNumSubgroups">,
     InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgLayoutAsInt">,
     InterfaceMethod<"Get the SgData field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
+    InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
+                    "xegpu::DistributLayoutAttrInterface",
+                    "dropSgLayoutAndData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                       indices based on the effective subgroup layout.}],
                     "FailureOr<SmallVector<Value>>",
@@ -206,7 +215,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
   ];
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributLayoutAttrInterface]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -346,6 +355,13 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
       return 0;
     }
 
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
+      if (sgLayout.has_value())
+        return computeProduct(*sgLayout);
+      return 0;
+    }
+
     LayoutAttr dropSgLayoutAndData() {
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
@@ -393,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
 }
 
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributLayoutAttrInterface]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
@@ -420,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
   }];
 
   let parameters = (ins
-    "xegpu::LayoutTrait": $parent,
+    "xegpu::DistributLayoutAttrInterface": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
@@ -450,6 +466,13 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return parent.isSgLayout();
     }
 
+    int64_t getNumSubgroups() {
+      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
+      if (sgLayout.has_value())
+        return computeProduct(*sgLayout);
+      return 0;
+    }
+
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
@@ -474,6 +497,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
       return std::nullopt;
     }
 
+    SliceAttr dropSgLayoutAndData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropSgLayoutAndData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
+    SliceAttr dropInstData() {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      parent = parent.dropInstData();
+      return SliceAttr::get(getContext(), parent, attr.getDims());
+    }
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
     /// it will coalese two slice operations and return a simplified SliceAttr
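
A minimal sketch (illustration only, not part of the patch) of how a conversion pattern can program against the new interface instead of a concrete LayoutAttr or SliceAttr; the helper name is hypothetical:

    static void inspectLayout(xegpu::DistributLayoutAttrInterface layout) {
      if (!layout || !layout.isWgLayout())
        return;
      // Product of the effective sg_layout dims, e.g. [2, 4] -> 8 subgroups.
      int64_t numSgs = layout.getNumSubgroups();
      // Per-subgroup tile shape, when sg_data is specified.
      std::optional<SmallVector<int64_t>> sgData = layout.getSgDataAsInt();
      // The same layout with the workgroup-level fields stripped, to be
      // attached to the new subgroup-level ops.
      xegpu::DistributLayoutAttrInterface sgLevel = layout.dropSgLayoutAndData();
      (void)numSgs; (void)sgData; (void)sgLevel;
    }
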
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index eb54d6887681d..deea44cd14db0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1150,7 +1150,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributLayoutAttrInterface>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1175,7 +1175,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 
   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-                    "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1194,7 +1194,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<LayoutTrait>:$layout
+    OptionalAttr<DistributLayoutAttrInterface>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,7 +1213,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f45972..9f6e498854c18 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,8 +290,8 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions
+/// for computing multi-dimensional offsets when distributed by LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
@@ -322,7 +322,8 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+                  xegpu::DistributLayoutAttrInterface parent,
+                  DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";
 
@@ -340,7 +341,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }
 
 SliceAttr SliceAttr::flatten() const {
-  xegpu::LayoutTrait parent = getParent();
+  xegpu::DistributLayoutAttrInterface parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
 
   while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,8 +376,8 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }
 
-/// Implements LayoutTrait::getOffsets to generate instructions for
-/// computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions
+/// for computing multi-dimensional offsets when distributed by SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
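
The scalar arithmetic behind getOffsets in the non-round-robin case can be sketched as below (illustration only; the helper name is hypothetical, delinearize and computeSuffixProduct come from IndexingUtils.h, and the modulo mirrors the index.remu ops in the emitted IR):

    static SmallVector<int64_t> sgTileOffsets(int64_t sgId,
                                              ArrayRef<int64_t> sgLayout,
                                              ArrayRef<int64_t> sgData,
                                              ArrayRef<int64_t> wgShape) {
      // Delinearize the id, e.g. sgLayout = [2, 4]: linear id 5 -> [1, 1].
      SmallVector<int64_t> ids =
          delinearize(sgId, computeSuffixProduct(sgLayout));
      SmallVector<int64_t> offsets;
      for (auto [id, data, dim] : llvm::zip(ids, sgData, wgShape))
        offsets.push_back((id * data) % dim); // wrap when data is shared
      return offsets;
    }
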
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 906c71d8b8dad..05a3604ae2b43 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         LayoutTrait layout) {
+                         DistributLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          LayoutTrait layout) {
+                          DistributLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 8f1208e77ca5d..39077c1fb64b6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -55,17 +55,16 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
 }
 
 static std::pair<SmallVector<int64_t>, int>
-getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
+getSgShapeAndCount(ArrayRef<int64_t> shape,
+                   xegpu::DistributLayoutAttrInterface layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
-
   if (layout && layout.isWgLayout()) {
-    DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout();
-    auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    if (DenseI32ArrayAttr sgDataAttr = layout.getSgData())
-      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-    else
-      sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape);
+    SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
+    if (auto maybeSgData = layout.getSgDataAsInt())
+      sgShape = *maybeSgData;
+    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
+      sgShape = *maybeDerivedSgData;
     SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
     // Clamp distUnit to the original shape to handle cases where data is
     // shared among subgroups, which may cause distUnit to exceed the original
@@ -723,8 +722,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 // is lowered to:
 //   #a = #xegpu.layout<inst_data = [16, 16]>
 //   #b = #xegpu.layout<inst_data = [8, 16]>
-//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
-//   %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
+//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
+//   %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
 //   xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
 // clang-format on
 struct WgToSgConvertLayoutOp
@@ -884,6 +883,123 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// A callback function type used to create new load/store_matrix ops
+using CreatorFuncType =
+    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
+                            SmallVector<SmallVector<Value>> &descOffsets)>;
+
+/// Utility helper implementing the distribution logic shared by load_matrix
+/// and store_matrix operations.
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+LogicalResult distributeMatrixOp(
+    ConversionPatternRewriter &rewriter,
+    typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor, OpType op,
+    ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // adjust the linearId if the range specifier is present
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified) {
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    sgId = rewriter.create<index::SubOp>(loc, startOfRangeVal, sgId);
+  }
+
+  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeMdescOffsets))
+    return failure();
+
+  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+  callback(wgOffsets, *maybeMdescOffsets);
+  return success();
+}
+
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  return llvm::map_to_vector(
+      llvm::zip_equal(lhs, rhs), [&](auto p) -> OpFoldResult {
+        auto l = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<0>(p));
+        auto r = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<1>(p));
+        return rewriter.create<index::AddOp>(loc, l, r).getResult();
+      });
+}
+
+struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
+  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    VectorType valueTy = op.getRes().getType();
+    ArrayRef<int64_t> wgShape = valueTy.getShape();
+    Type elemTy = valueTy.getElementType();
+
+    // The callback function for creating new LoadMatrixOps:
+    // baseOffsets holds the original offsets of the op, and
+    // descOffsets holds the relative offsets into the mem_desc
+    // accessed by each subgroup op.
+    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
+                        SmallVector<SmallVector<Value>> descOffsets) {
+      auto layout = op.getLayoutAttr();
+      SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+      VectorType newResTy = VectorType::get(sgShape, elemTy);
+
+      SmallVector<Value> newOps;
+      for (auto offsets : descOffsets) {
+        SmallVector<OpFoldResult> sgOffsets =
+            add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
+        auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+            loc, newResTy, op.getMemDesc(), sgOffsets,
+            layout.dropSgLayoutAndData());
+        newOps.push_back(newOp);
+      }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+    };
+
+    return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback);
+  }
+};
+
+struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
+  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    VectorType valueTy = op.getData().getType();
+    ArrayRef<int64_t> wgShape = valueTy.getShape();
+
+    // The callback function for creating new StoreMatrixOps:
+    // baseOffsets holds the original offsets of the op, and
+    // descOffsets holds the relative offsets into the mem_desc
+    // accessed by each subgroup op.
+    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
+                        SmallVector<SmallVector<Value>> descOffsets) {
+      auto layout = op.getLayoutAttr();
+      for (auto [v, descOffsets] : llvm::zip(adaptor.getData(), descOffsets)) {
+        SmallVector<OpFoldResult> sgOffsets =
+            add(rewriter, loc, baseOffsets, getAsOpFoldResult(descOffsets));
+        rewriter.create<xegpu::StoreMatrixOp>(
+            loc, v, op.getMemDesc(), sgOffsets, layout.dropSgLayoutAndData());
+      }
+      rewriter.eraseOp(op);
+    };
+    return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback);
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -895,7 +1011,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
            WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
            WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
            WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
-           WgToSgArithConstantOp>(patterns.getContext());
+           WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
+          patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -985,7 +1102,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return xegpu::TensorDescType();
   };
 
-  auto isLegal = [&](xegpu::LayoutAttr layout) -> bool {
+  auto isLegal = [&](xegpu::DistributLayoutAttrInterface layout) -> bool {
     return !layout || !layout.isWgLayout();
   };
 
@@ -1002,9 +1119,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>(
-      [=](vector::BroadcastOp op) -> bool {
-        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
+      [=](xegpu::LoadMatrixOp op) -> bool {
+        return isLegal(op.getLayoutAttr());
+      });
+
+  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
+      [=](xegpu::StoreMatrixOp op) -> bool {
+        return isLegal(op.getLayoutAttr());
       });
 
   target.addDynamicallyLegalOp<arith::ConstantOp>(
@@ -1015,6 +1137,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
       [=](xegpu::ConvertLayoutOp op) -> bool {
         return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
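
A concrete trace of getSgShapeAndCount, using the shapes exercised by the new tests below (illustrative arithmetic only):

    // shape     = [64, 128], sg_layout = [2, 4], sg_data = [32, 32]
    // sgShape   = sg_data                          = [32, 32]
    // distUnit  = min(shape, sg_layout * sgShape)  = [64, 128]
    // count     = prod(shape / distUnit)           = 1  (one op per subgroup)
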
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index f4a49da71605f..5f851e9003a0e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -372,4 +372,72 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_load_matrix
+  // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
+    //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
+    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[off_y:%.+]] = index.add [[c0_2]], [[mod_y]]
+    //CHECK: [[c0_3:%.+]] = arith.constant 0 : index
+    //CHECK: [[off_x:%.+]] = index.add [[c0_3]], [[mod_x]]
+    //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
+    %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
+    gpu.return
+  }
+
+  //CHECK-LABEL: distribute_store_matrix
+  //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) {
+    //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
+    //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+    //CHECK: [[c0_3:%.+]] = arith.constant 0 : index
+    //CHECK: [[off_y:%.+]] = index.add [[c0_3]], [[mod_y]]
+    //CHECK: [[c0_4:%.+]] = arith.constant 0 : index
+    //CHECK: [[off_x:%.+]] = index.add [[c0_4]], [[mod_x]]
+    //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
+
+    gpu.return
+  }
+
 }
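
Tracing the CHECK chain above for a single subgroup (illustrative arithmetic only; the id split matches the affine maps in the output):

    // sg_layout = [2, 4], sg_data = [32, 32], wg shape 64x128, sgid = 5:
    //   id_y    = 5 / 4 = 1               id_x    = 5 % 4 = 1
    //   l_off_y = 1 * 32 = 32             l_off_x = 1 * 32 = 32
    //   off_y   = (32 + 0) % 64 = 32      off_x   = (32 + 0) % 128 = 32
    // i.e. subgroup 5 accesses the 32x32 tile at [32, 32] of the mem_desc.
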
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 58962714b7864..d94d285b1105d 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -156,8 +156,8 @@ struct TestXeGPUUnrollingPatterns
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 // Test pattern for distributing vector::StepOp from workgroup to subgroup.
-// Validates LayoutTrait interfaces for offset computation abstraction between
-// LayoutAttr and SliceAttr.
+// Validates DistributLayoutAttrInterface interfaces for offset computation
+// abstraction between LayoutAttr and SliceAttr.
 class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
   using OpConversionPattern<vector::StepOp>::OpConversionPattern;
 

>From 6fc6ec7b69d6d7d1e86ec05e876f80876fe46e6c Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 19 Aug 2025 18:35:43 +0000
Subject: [PATCH 2/8] refactor createNdOp

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h    |   2 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  10 +-
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  12 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  14 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |   4 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 209 ++++++++----------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |   8 +-
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  16 +-
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |   2 +-
 9 files changed, 134 insertions(+), 143 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index ce33da9632c2b..1d152f0c9ca9a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -24,7 +24,7 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
-class DistributLayoutAttrInterface;
+class DistributeLayoutAttrInterface;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index adfd8bae75a5a..de86141ad006a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,7 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface"> {
+def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> {
   let cppNamespace = "::mlir::xegpu";
   let description = [{
     Common trait for all XeGPU layouts.
@@ -198,7 +198,7 @@ def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface">
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
     InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
-                    "xegpu::DistributLayoutAttrInterface",
+                    "xegpu::DistributeLayoutAttrInterface",
                     "dropSgLayoutAndData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                       indices based on the effective subgroup layout.}],
@@ -215,7 +215,7 @@ def DistributLayoutAttrInterface: AttrInterface<"DistributLayoutAttrInterface">
   ];
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributLayoutAttrInterface]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -409,7 +409,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributLayoutAttrInterfa
 }
 
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributLayoutAttrInterface]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
@@ -436,7 +436,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributLayoutAttrInterface]
   }];
 
   let parameters = (ins
-    "xegpu::DistributLayoutAttrInterface": $parent,
+    "xegpu::DistributeLayoutAttrInterface": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index deea44cd14db0..3ba9eaa4a66da 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,6 +232,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return static_cast<unsigned>(MemorySpace::Global);
     }
 
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
+    }
+
   }];
 }
 
@@ -1150,7 +1154,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<DistributLayoutAttrInterface>:$layout
+    OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1175,7 +1179,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 
   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributLayoutAttrInterface": $layout)>,
+                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1194,7 +1198,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<DistributLayoutAttrInterface>:$layout
+    OptionalAttr<DistributeLayoutAttrInterface>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,7 +1217,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributLayoutAttrInterface": $layout)>,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9f6e498854c18..de118b7faea4d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,8 +290,9 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions
-/// for computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
@@ -322,7 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::DistributLayoutAttrInterface parent,
+                  xegpu::DistributeLayoutAttrInterface parent,
                   DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";
@@ -341,7 +342,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }
 
 SliceAttr SliceAttr::flatten() const {
-  xegpu::DistributLayoutAttrInterface parent = getParent();
+  xegpu::DistributeLayoutAttrInterface parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
 
   while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -376,8 +377,9 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }
 
-/// Implements DistributLayoutAttrInterface::getOffsets to generate instructions
-/// for computing multi-dimensional offsets when distributed by SliceAttr.
+/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// instructions for computing multi-dimensional offsets when distributed by
+/// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 05a3604ae2b43..0e22af900daf1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         DistributLayoutAttrInterface layout) {
+                         DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          DistributLayoutAttrInterface layout) {
+                          DistributeLayoutAttrInterface layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 39077c1fb64b6..ca1209e776d0e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -56,7 +56,7 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
 
 static std::pair<SmallVector<int64_t>, int>
 getSgShapeAndCount(ArrayRef<int64_t> shape,
-                   xegpu::DistributLayoutAttrInterface layout) {
+                   xegpu::DistributeLayoutAttrInterface layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
   if (layout && layout.isWgLayout()) {
@@ -76,6 +76,72 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
   return std::make_pair(sgShape, count);
 }
 
+// A utility helper that generates elementwise addition ops for index
+// computation. lhs and rhs are vectors of Values; if their ranks don't
+// match, left-alignment is performed.
+static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
+                                     Location loc, ArrayRef<OpFoldResult> lhs,
+                                     ArrayRef<OpFoldResult> rhs) {
+  SmallVector<OpFoldResult> reversedResult;
+  auto l = lhs.rbegin();
+  auto r = rhs.rbegin();
+  for (; l != lhs.rend() || r != rhs.rend(); ++l, ++r) {
+    if (l == lhs.rend()) {
+      reversedResult.push_back(*r);
+    } else if (r == rhs.rend()) {
+      reversedResult.push_back(*l);
+    } else {
+      auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l);
+      auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r);
+      auto add = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
+      reversedResult.push_back(add);
+    }
+  }
+  return llvm::to_vector(llvm::reverse(reversedResult));
+}
+
+// A callback function type used to create new load/store_matrix ops
+using CreatorFuncType =
+    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
+                            SmallVector<SmallVector<Value>> &descOffsets)>;
+
+/// Utility helper implementing the distribution logic shared by ops with offsets.
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
+              xegpu::StoreMatrixOp>::value>>
+static LogicalResult
+distributeOp(ConversionPatternRewriter &rewriter,
+             typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
+             OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+  Location loc = op.getLoc();
+  auto layout = op.getLayoutAttr();
+  if (!layout || !layout.isWgLayout())
+    return failure();
+
+  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
+
+  // adjust the linearId if the range specifier is present
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+  if (sgIdRangeSpecified) {
+    if (layout.getNumSubgroups() != endOfRange - startOfRange)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
+  }
+
+  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeMdescOffsets))
+    return failure();
+
+  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+  callback(wgOffsets, *maybeMdescOffsets);
+  return success();
+}
+
 /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor
 /// from a workgroup descriptor. It replaces the offsets and sizes with
 /// appropriate values for the subgroup.
@@ -136,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
-    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout)
-      return failure();
-    Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    // sgLayout must be present for workgroup-level distribution.
-    SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    else
-      return rewriter.notifyMatchFailure(
-          op, "sgLayout attribute is required in layout");
-
-    // Get the subgroup ID
-    Value linearSgId =
-        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-    int64_t startOfRange = -1, endOfRange = -1;
-    bool sgIdRangeSpecified =
-        isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-    if (sgIdRangeSpecified) {
-      int64_t sgCount = endOfRange - startOfRange;
-      if (computeProduct(sgLayout) != sgCount)
-        return rewriter.notifyMatchFailure(
-            op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get
-      // the adjusted sg id
-      Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      linearSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-    }
+    Type elemTy = tdescTy.getElementType();
 
-    auto maybeTdescOffsets =
-        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-    if (failed(maybeTdescOffsets))
-      return failure();
+    // The callback function for creating new CreateNdDescOps:
+    // baseOffsets holds the original offsets of the op, and
+    // descOffsets holds the relative offsets into the tensor
+    // accessed by each subgroup op.
+    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
+                        SmallVector<SmallVector<Value>> descOffsets) {
+      xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+      SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+      auto newTdescTy = xegpu::TensorDescType::get(
+          ctx, sgShape, elemTy, tdescTy.getEncoding(),
+          layout.dropSgLayoutAndData());
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    xegpu::TensorDescType newTdescTy =
-        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
-                                   layout.dropSgLayoutAndData());
+      SmallVector<Value> newOps;
+      for (auto offsets : descOffsets) {
+        SmallVector<OpFoldResult> sgOffsets =
+            add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
+        auto newOp = xegpu::CreateNdDescOp::create(
+            rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
+            op.getMixedSizes(), op.getMixedStrides());
 
-    SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
-    for (auto tdescOffsets : *maybeTdescOffsets) {
-      SmallVector<OpFoldResult> sgOffsets;
-      size_t rank = tdescOffsets.size();
-      for (size_t i = 0; i < rank; i++) {
-        size_t idx = origOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-        sgOffsets.push_back(add);
+        newOps.push_back(newOp);
       }
+      rewriter.replaceOpWithMultiple(op, {newOps});
+    };
 
-      auto newOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
-          op.getMixedSizes(), op.getMixedStrides());
-      newCreateNdOps.push_back(newOp);
-    }
-    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
-    return success();
+    return distributeOp(rewriter, adaptor, op, wgShape, callback);
   }
 };
 
@@ -883,59 +913,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
-// A callback function type used to create new load/store_matrix ops
-using CreatorFuncType =
-    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
-                            SmallVector<SmallVector<Value>> &descOffsets)>;
-
-/// Utility helper implementing the distribution logic shared by load_matrix
-/// and store_matrix operations.
-template <typename OpType,
-          typename = std::enable_if_t<llvm::is_one_of<
-              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
-LogicalResult distributeMatrixOp(
-    ConversionPatternRewriter &rewriter,
-    typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor, OpType op,
-    ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
-  Location loc = op.getLoc();
-  auto layout = op.getLayoutAttr();
-  if (!layout || !layout.isWgLayout())
-    return failure();
-
-  Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
-
-  // adjust the linearId if the range specifier is present
-  int64_t startOfRange = -1, endOfRange = -1;
-  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
-  if (sgIdRangeSpecified) {
-    if (layout.getNumSubgroups() != endOfRange - startOfRange)
-      return rewriter.notifyMatchFailure(
-          op, "sg_layout size must match the sg_id_range");
-    Value startOfRangeVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
-    sgId = rewriter.create<index::SubOp>(loc, startOfRangeVal, sgId);
-  }
-
-  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
-  if (failed(maybeMdescOffsets))
-    return failure();
-
-  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
-  callback(wgOffsets, *maybeMdescOffsets);
-  return success();
-}
-
-static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
-                                     Location loc, ArrayRef<OpFoldResult> lhs,
-                                     ArrayRef<OpFoldResult> rhs) {
-  return llvm::map_to_vector(
-      llvm::zip_equal(lhs, rhs), [&](auto p) -> OpFoldResult {
-        auto l = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<0>(p));
-        auto r = getValueOrCreateConstantIndexOp(rewriter, loc, std::get<1>(p));
-        return rewriter.create<index::AddOp>(loc, l, r).getResult();
-      });
-}
-
 struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
   LogicalResult
@@ -968,7 +945,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
       rewriter.replaceOpWithMultiple(op, {newOps});
     };
 
-    return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback);
+    return distributeOp(rewriter, adaptor, op, wgShape, callback);
   }
 };
 
@@ -996,7 +973,7 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
       }
       rewriter.eraseOp(op);
     };
-    return distributeMatrixOp(rewriter, adaptor, op, wgShape, callback);
+    return distributeOp(rewriter, adaptor, op, wgShape, callback);
   }
 };
 
@@ -1102,7 +1079,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return xegpu::TensorDescType();
   };
 
-  auto isLegal = [&](xegpu::DistributLayoutAttrInterface layout) -> bool {
+  auto isLegal = [&](xegpu::DistributeLayoutAttrInterface layout) -> bool {
     return !layout || !layout.isWgLayout();
   };
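
The left-alignment of the new add helper in action (illustration only):

    // lhs = [%a, %b, %c], rhs = [%x, %y]       (ranks 3 and 2)
    // add(lhs, rhs) = [%a, %b + %x, %c + %y]   (the extra leading lhs entry
    //                                           passes through unchanged)
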
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index e5cc65e6bd3d7..7dcdcca070ac9 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -30,9 +30,13 @@ gpu.module @test_round_robin_assignment {
     //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
     //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
     //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
+    //CHECK: [[modY:%.+]] = index.remu [[ADDY]], [[C128]]
     //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
-    //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
+    //CHECK: [[modX:%.+]] = index.remu [[ADDX]], [[C64_2]]
+    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
+    //CHECK: [[offX:%.+]] = index.add [[modX]], [[C0_3]]
+    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
+    //CHECK: [[offY:%.+]] = index.add [[modY]], [[C0_4]]
     //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 5f851e9003a0e..bdda77a69f22e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -17,9 +17,13 @@ gpu.module @test_1_1_assignment {
     //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
     //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
     //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
     //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
+    //CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_3]]
+    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_4]]
     //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -396,9 +400,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
     //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
     //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_y:%.+]] = index.add [[c0_2]], [[mod_y]]
+    //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_2]]
     //CHECK: [[c0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_x:%.+]] = index.add [[c0_3]], [[mod_x]]
+    //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_3]]
     //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
     %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
     %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
@@ -429,9 +433,9 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
     //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x]], [[c128]]
     //CHECK: [[c0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_y:%.+]] = index.add [[c0_3]], [[mod_y]]
+    //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_3]]
     //CHECK: [[c0_4:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_x:%.+]] = index.add [[c0_4]], [[mod_x]]
+    //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_4]]
     //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
     %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index d94d285b1105d..8d2fb85655c72 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -156,7 +156,7 @@ struct TestXeGPUUnrollingPatterns
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 // Test pattern for distributing vector::StepOp from workgroup to subgroup.
-// Validates DistributLayoutAttrInterface interfaces for offset computation
+// Validates DistributeLayoutAttrInterface interfaces for offset computation
 // abstraction between LayoutAttr and SliceAttr.
 class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
   using OpConversionPattern<vector::StepOp>::OpConversionPattern;

>From 9af1f7f417c5c4ee0e1a47830d8f061286cfe9e0 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 19 Aug 2025 20:20:07 +0000
Subject: [PATCH 3/8] fix typo

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index ca1209e776d0e..d31f0f1a75c51 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -752,8 +752,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
 // is lowered to:
 //   #a = #xegpu.layout<inst_data = [16, 16]>
 //   #b = #xegpu.layout<inst_data = [8, 16]>
-//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-//   %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+//   store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+//   %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
 //   xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
 // clang-format on
 struct WgToSgConvertLayoutOp

>From 93acad2807a0fd9ea9e3d5f594afc681dfb50840 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 20 Aug 2025 16:49:05 +0000
Subject: [PATCH 4/8] cleanup

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  66 +++
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 384 ++++++------------
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir |   8 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       |  58 +++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   |  76 +---
 5 files changed, 249 insertions(+), 343 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 3ba9eaa4a66da..3182552288ca6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -236,6 +236,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
     }
 
+    ArrayRef<int64_t> getDistributeShape() {
+      return getTensorDescShape();
+    }
+
   }];
 }
 
@@ -266,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.empty() && dynamics.empty())
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];
 
   let assemblyFormat = [{
@@ -347,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.empty() && dynamics.empty())
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];
 
   let assemblyFormat = [{
@@ -421,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.empty() && dynamics.empty())
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
+    }
+
+    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getTensorDescType().getShape();
+    }
+
   }];
 
    let assemblyFormat = [{
@@ -644,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     xegpu::TensorDescType getTensorDescType() {
       return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
+
   }];
 
   let assemblyFormat = [{
@@ -1185,6 +1242,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     SmallVector<OpFoldResult> getMixedOffsets() {
       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
     }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getRes().getType().getShape();
+    }
   }];
 
   let hasVerifier = 1;
@@ -1223,6 +1284,11 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     SmallVector<OpFoldResult> getMixedOffsets() {
       return getMixedValues(getConstOffsets(), getOffsets(), getContext());
     }
+
+    ArrayRef<int64_t> getDistributeShape() {
+      return getData().getType().getShape();
+    }
+
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index d31f0f1a75c51..76bda64ffac1e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -79,43 +79,42 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
 // An util helper to generate elementwise addition ops for index computing.
 // lhs and rhs are vectors of Values. If the rank of lhs and rhs doesn't match.
 // left-alignment is performed.
-static SmallVector<OpFoldResult> add(ConversionPatternRewriter &rewriter,
-                                     Location loc, ArrayRef<OpFoldResult> lhs,
-                                     ArrayRef<OpFoldResult> rhs) {
-  SmallVector<OpFoldResult> reversedResult;
-  auto l = lhs.rbegin();
-  auto r = rhs.rbegin();
-  for (; l != lhs.rend() || r != rhs.rend(); ++l, ++r) {
-    if (l == lhs.rend()) {
-      reversedResult.push_back(*r);
-    } else if (r == rhs.rend()) {
-      reversedResult.push_back(*l);
-    } else {
-      auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, *l);
-      auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, *r);
-      auto add = rewriter.createOrFold<index::AddOp>(loc, lval, rval);
-      reversedResult.push_back(add);
-    }
+static SmallVector<OpFoldResult>
+genIndexAdds(ConversionPatternRewriter &rewriter, Location loc,
+             ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {
+  // ensure a is at least as long as b
+  ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs;
+  ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs;
+  SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size()));
+  a = a.slice(a.size() - b.size());
+  for (auto [l, r] : llvm::zip(a, b)) {
+    auto lval = getValueOrCreateConstantIndexOp(rewriter, loc, l);
+    auto rval = getValueOrCreateConstantIndexOp(rewriter, loc, r);
+    results.push_back(rewriter.createOrFold<index::AddOp>(loc, lval, rval));
   }
-  return llvm::to_vector(llvm::reverse(reversedResult));
+  return results;
 }
 
-// A callback funtion type used to create new load/store_matrix ops
-using CreatorFuncType =
-    llvm::function_ref<void(ArrayRef<OpFoldResult> baseOffsets,
-                            SmallVector<SmallVector<Value>> &descOffsets)>;
-
-/// Utility helper for distributing logic shared by operations with offsets
-template <typename OpType,
-          typename = std::enable_if_t<llvm::is_one_of<
-              OpType, xegpu::CreateNdDescOp, xegpu::LoadMatrixOp,
-              xegpu::StoreMatrixOp>::value>>
+/// Utility helper for deriving the list of offsets for each sub-TensorDesc
+/// or sub-MemDesc accessed by the current subgroup (sgId), based on the
+/// associated distribute layout attribute, the distribution shape, and the
+/// original offsets of the op.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
+        xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
 static LogicalResult
-distributeOp(ConversionPatternRewriter &rewriter,
-             typename OpConversionPattern<OpType>::OneToNOpAdaptor adaptor,
-             OpType op, ArrayRef<int64_t> wgShape, CreatorFuncType callback) {
+genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
+               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
   Location loc = op.getLoc();
-  auto layout = op.getLayoutAttr();
+  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+  // not applicable to ops without offset operands.
+  if (origOffsets.empty())
+    return failure();
+
+  // not applicable to ops without workgroup layout attributes
+  xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
   if (!layout || !layout.isWgLayout())
     return failure();
 
@@ -133,12 +132,23 @@ distributeOp(ConversionPatternRewriter &rewriter,
     sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal);
   }
 
-  auto maybeMdescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
-  if (failed(maybeMdescOffsets))
+  // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
+  // descriptors to be accessed, based on the layout information.
+  ArrayRef<int64_t> wgShape = op.getDistributeShape();
+  auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+  if (failed(maybeDescOffsets))
     return failure();
 
-  SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
-  callback(wgOffsets, *maybeMdescOffsets);
+  // Compute the final global offsets for each accessed sub-tensor
+  // or sub-memory descriptor.
+  for (const auto &sgOffsets : *maybeDescOffsets) {
+    SmallVector<OpFoldResult> newOffsets =
+        genIndexAdds(rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
+    offsetsList.push_back(std::move(newOffsets));
+  }
+
   return success();
 }
 
@@ -193,44 +203,31 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    // Ensure that the op has explicit offsets specified (either dynamic or
-    // constant).
-    if (op.getMixedOffsets().empty())
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
-    Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
     Type elemTy = tdescTy.getElementType();
+    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    auto newTdescTy =
+        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+                                   layout.dropSgLayoutAndData());
 
-    // the call back function for creating new CreateNdOps,
-    // the baseOffsets is the origial offsets of the op, and
-    // descOffsets is the relative offsets to the mem_desc accessed
-    // by each subgroup op.
-    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
-                        SmallVector<SmallVector<Value>> descOffsets) {
-      xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
-      SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-      auto newTdescTy = xegpu::TensorDescType::get(
-          ctx, sgShape, elemTy, tdescTy.getEncoding(),
-          layout.dropSgLayoutAndData());
-
-      SmallVector<Value> newOps;
-      for (auto offsets : descOffsets) {
-        SmallVector<OpFoldResult> sgOffsets =
-            add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
-        auto newOp = xegpu::CreateNdDescOp::create(
-            rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
-            op.getMixedSizes(), op.getMixedStrides());
+    SmallVector<Value> newOps;
+    for (const auto &offsets : offsetsList) {
+      auto newOp = xegpu::CreateNdDescOp::create(
+          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
+          op.getMixedSizes(), op.getMixedStrides());
 
-        newOps.push_back(newOp);
-      }
-      rewriter.replaceOpWithMultiple(op, {newOps});
-    };
+      newOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newOps});
 
-    return distributeOp(rewriter, adaptor, op, wgShape, callback);
+    return success();
   }
 };
 
@@ -283,12 +280,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
   LogicalResult
   matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    SmallVector<Value> newLoadOps;
-
-    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
+    if (!op.getMixedOffsets().empty())
       return failure();
 
+    SmallVector<Value> newLoadOps;
     for (auto src : adaptor.getTensorDesc()) {
       xegpu::TensorDescType tdescTy =
           dyn_cast<xegpu::TensorDescType>(src.getType());
@@ -311,9 +306,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   LogicalResult
   matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-    if ((offsetSize != 0) || op.getConstOffsetsAttr())
+    if (!op.getMixedOffsets().empty())
       return failure();
 
     for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
@@ -325,100 +318,6 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
-// Utility function to compute global offsets for subgroup operations.
-// Returns a vector of new offsets for each subgroup, given the original op's
-// offsets and subgroup relative offsets.
-static SmallVector<SmallVector<OpFoldResult>>
-computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
-               ArrayRef<OpFoldResult> origOffsets,
-               ConversionPatternRewriter &rewriter) {
-  SmallVector<SmallVector<OpFoldResult>> finalOffsets;
-  Location loc = op->getLoc();
-  for (const auto &sgOffsets : sgOffsetsList) {
-    SmallVector<OpFoldResult> newOffsets;
-    size_t rank = sgOffsets.size();
-    for (size_t i = 0; i < rank; i++) {
-      size_t idx = origOffsets.size() - rank + i;
-      Value add = rewriter.createOrFold<index::AddOp>(
-          loc, sgOffsets[i],
-          getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-      newOffsets.push_back(add);
-    }
-    finalOffsets.push_back(std::move(newOffsets));
-  }
-  return finalOffsets;
-}
-
-// Utility function to get sgShape, sgOffsetList for a given
-// op.
-template <typename OpTy, typename AdaptorTy>
-LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
-                           ConversionPatternRewriter &rewriter,
-                           SmallVector<int64_t> &sgShape,
-                           SmallVector<SmallVector<Value>> &sgOffsetList) {
-  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
-  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
-    return failure();
-
-  Location loc = op.getLoc();
-  Value tdesc = op.getTensorDesc();
-  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
-  if (!tdescTy)
-    return failure();
-  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-  if (!layout)
-    return failure();
-
-  SmallVector<int64_t> sgLayout;
-  auto sgLayoutAttr = layout.getSgLayout();
-  if (!sgLayoutAttr)
-    return rewriter.notifyMatchFailure(
-        op, "sgLayout attribute is required in layout");
-  sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-
-  ArrayRef<int64_t> wgShape = tdescTy.getShape();
-  int count;
-  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
-
-  // Get the subgroup ID
-  Value linearSgId =
-      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-  int64_t startOfRange = -1, endOfRange = -1;
-  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-  if (sgIdRangeSpecified) {
-    int64_t sgCount = endOfRange - startOfRange;
-    if (computeProduct(sgLayout) != sgCount)
-      return rewriter.notifyMatchFailure(
-          op, "sg_layout size must match the sg_id_range");
-    Value startOfRangeVal =
-        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
-    linearSgId =
-        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-  }
-
-  auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-  if (failed(sgOffsets))
-    return failure();
-
-  sgOffsetList = *sgOffsets;
-  return success();
-}
-
-template <typename OpTy>
-SmallVector<OpFoldResult> getOffsets(OpTy op,
-                                     ConversionPatternRewriter &rewriter) {
-  SmallVector<OpFoldResult> origOffsets;
-  if (auto constOffsets = op.getConstOffsetsAttr()) {
-    for (auto attr : constOffsets.asArrayRef())
-      origOffsets.push_back(rewriter.getIndexAttr(attr));
-  }
-  for (auto v : op.getOffsets())
-    origOffsets.push_back(v);
-  return origOffsets;
-}
-
 // This pattern transforms the LoadNdOp with explicit offsets to load
 // subgroup data.
 struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
@@ -427,33 +326,24 @@ struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
   matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    SmallVector<Value> newLoadOps;
-    for (auto [offsets, tdesc] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
-      VectorType newResTy = VectorType::get(
-          sgShape,
-          dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
-      auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
-          op.getLoc(), newResTy, tdesc, offsets,
-          /*packed=*/nullptr,
-          /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
-          op.getL3HintAttr());
-      newLoadOps.push_back(newLoadOp);
+    SmallVector<Value> newOps;
+    for (auto [tdesc, offsets] :
+         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
+      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+      VectorType newResTy =
+          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
+      auto newOp = xegpu::LoadNdOp::create(
+          rewriter, op.getLoc(), newResTy, tdesc, offsets,
+          /*packed=*/nullptr, /*transpose=*/nullptr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      newOps.push_back(newOp);
     }
-    rewriter.replaceOpWithMultiple(op, {newLoadOps});
+    rewriter.replaceOpWithMultiple(op, {newOps});
+
     return success();
   }
 };
@@ -466,27 +356,18 @@ struct WgToSgStoreNdOpWithOffset
   LogicalResult
   matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    for (auto [offsets, tdesc, value] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
-      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+    for (auto [v, tdesc, offsets] :
+         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
+      rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets,
                                         op.getL1HintAttr(), op.getL2HintAttr(),
                                         op.getL3HintAttr());
     }
     rewriter.eraseOp(op);
+
     return success();
   }
 };
@@ -499,27 +380,18 @@ struct WgToSgPrefetchNdOpWithOffset
   LogicalResult
   matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    SmallVector<int64_t> sgShape;
-    SmallVector<SmallVector<Value>> sgOffsetList;
-
-    // Do the distribution from workgroup to subgroup and get subgroup offsets
-    if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
-    // Get the original workgroup offsets
-    SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
-
-    // Calculate the final offsets for each subgroup
-    auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
-
-    for (auto [offsets, tdesc] :
-         llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+    for (auto [tdesc, offsets] :
+         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
       rewriter.create<xegpu::PrefetchNdOp>(
           op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
           op.getL3HintAttr());
     }
     rewriter.eraseOp(op);
+
     return success();
   }
 };
@@ -918,34 +790,28 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
   LogicalResult
   matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
+
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
+      return failure();
+
+    ArrayRef<int64_t> wgShape = op.getDistributeShape();
     VectorType valueTy = op.getRes().getType();
-    ArrayRef<int64_t> wgShape = valueTy.getShape();
     Type elemTy = valueTy.getElementType();
 
-    // the call back function for creating new LoadMatrixOps,
-    // the baseOffsets is the origial offsets of the op, and
-    // descOffsets is the relative offsets to the mem_desc accessed
-    // by each subgroup op.
-    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
-                        SmallVector<SmallVector<Value>> descOffsets) {
-      auto layout = op.getLayoutAttr();
-      SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-      VectorType newResTy = VectorType::get(sgShape, elemTy);
-
-      SmallVector<Value> newOps;
-      for (auto offsets : descOffsets) {
-        SmallVector<OpFoldResult> sgOffsets =
-            add(rewriter, loc, baseOffsets, getAsOpFoldResult(offsets));
-        auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
-            loc, newResTy, op.getMemDesc(), sgOffsets,
-            layout.dropSgLayoutAndData());
-        newOps.push_back(newOp);
-      }
-      rewriter.replaceOpWithMultiple(op, {newOps});
-    };
+    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResTy = VectorType::get(sgShape, elemTy);
+    SmallVector<Value> newOps;
+    for (const auto &offsets : offsetsList) {
+      auto newOp = rewriter.create<xegpu::LoadMatrixOp>(
+          op.getLoc(), newResTy, op.getMemDesc(), offsets,
+          layout.dropSgLayoutAndData());
+      newOps.push_back(newOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newOps});
 
-    return distributeOp(rewriter, adaptor, op, wgShape, callback);
+    return success();
   }
 };
 
@@ -954,26 +820,18 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
   LogicalResult
   matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    VectorType valueTy = op.getData().getType();
-    ArrayRef<int64_t> wgShape = valueTy.getShape();
-
-    // the call back function for creating new StoreMatrixOps,
-    // the baseOffsets is the origial offsets of the op, and
-    // descOffsets is the relative offsets to the mem_desc accessed
-    // by each subgroup op.
-    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
-                        SmallVector<SmallVector<Value>> descOffsets) {
-      auto layout = op.getLayoutAttr();
-      for (auto [v, descOffsets] : llvm::zip(adaptor.getData(), descOffsets)) {
-        SmallVector<OpFoldResult> sgOffsets =
-            add(rewriter, loc, baseOffsets, getAsOpFoldResult(descOffsets));
-        rewriter.create<xegpu::StoreMatrixOp>(
-            loc, v, op.getMemDesc(), sgOffsets, layout.dropSgLayoutAndData());
-      }
-      rewriter.eraseOp(op);
-    };
-    return distributeOp(rewriter, adaptor, op, wgShape, callback);
+
+    SmallVector<SmallVector<OpFoldResult>> offsetsList;
+    if (failed(genOffsetsList(rewriter, op, offsetsList)))
+      return failure();
+
+    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
+      rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
+                                            offsets,
+                                            layout.dropSgLayoutAndData());
+    rewriter.eraseOp(op);
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 7dcdcca070ac9..e5cc65e6bd3d7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -30,13 +30,9 @@ gpu.module @test_round_robin_assignment {
     //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
     //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
     //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[modY:%.+]] = index.remu [[ADDY]], [[C128]]
+    //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
     //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
-    //CHECK: [[modX:%.+]] = index.remu [[ADDX]], [[C64_2]]
-    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[offX:%.+]] = index.add [[modX]], [[C0_3]]
-    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
-    //CHECK: [[offY:%.+]] = index.add [[modY]], [[C0_4]]
+    //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
     //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 07a0b86223c33..32157a7911f62 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -263,4 +263,62 @@ gpu.module @test_distribution {
   } {sg_id_range = #xegpu.range<[3, 19]>}
   gpu.return
   }
+
+  // CHECK-LABEL: distribute_load_matrix
+  // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
+    //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[off_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[off_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
+    //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
+    %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
+    gpu.return
+  }
+
+  //CHECK-LABEL: distribute_store_matrix
+  //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
+  gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) {
+    //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
+    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[c2:%.+]] = arith.constant 2 : index
+    //CHECK: [[c4:%.+]] = arith.constant 4 : index
+    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
+    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
+    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
+    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
+    //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
+    //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
+    //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
+    //CHECK: [[c64:%.+]] = arith.constant 64 : index
+    //CHECK: [[off_y:%.+]] = index.remu [[l_off_y]], [[c64]]
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[off_x:%.+]] = index.remu [[l_off_x]], [[c128]]
+    //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
+    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
+    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
+    xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index bdda77a69f22e..f4a49da71605f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -17,13 +17,9 @@ gpu.module @test_1_1_assignment {
     //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
     //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
     //CHECK: [[C256:%.+]] = arith.constant 256 : index
-    //CHECK: [[MODY:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
     //CHECK: [[C128:%.+]] = arith.constant 128 : index
-    //CHECK: [[MODX:%.+]] = index.remu [[UX]], [[C128]]
-    //CHECK: [[C0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[X:%.+]] = index.add [[MODX]], [[C0_3]]
-    //CHECK: [[C0_4:%.+]] = arith.constant 0 : index
-    //CHECK: [[Y:%.+]] = index.add [[MODY]], [[C0_4]]
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
     //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
       -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
@@ -376,72 +372,4 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
     gpu.return
   }
-
-  // CHECK-LABEL: distribute_load_matrix
-  // CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
-  gpu.func @distribute_load_matrix(%arg0: memref<32768xi8, 3>) {
-    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
-    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[c2:%.+]] = arith.constant 2 : index
-    //CHECK: [[c4:%.+]] = arith.constant 4 : index
-    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
-    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
-    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_y:%.+]] = index.mul [[id_y]], [[c32]]
-    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_x:%.+]] = index.mul [[id_x]], [[c32_1]]
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[c0_1:%.+]] = arith.constant 0 : index
-    //CHECK: [[l_off_y_0:%.+]] = arith.addi [[l_off_y]], [[c0]] : index
-    //CHECK: [[l_off_x_0:%.+]] = arith.addi [[l_off_x]], [[c0_1]] : index
-    //CHECK: [[c64:%.+]] = arith.constant 64 : index
-    //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y_0]], [[c64]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x_0]], [[c128]]
-    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_2]]
-    //CHECK: [[c0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_3]]
-    //CHECK: xegpu.load_matrix [[mdesc]][[[off_y]], [[off_x]]] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32>, index, index -> vector<32x32xf32>
-    %0 = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
-    %1 = xegpu.load_matrix %0[0, 0] <{layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32], lane_layout = [2, 8], lane_data = [1, 1]>}>: !xegpu.mem_desc<64x128xf32> -> vector<64x128xf32>
-    gpu.return
-  }
-
-  //CHECK-LABEL: distribute_store_matrix
-  //CHECK-SAME: [[arg0:%.+]]: memref<32768xi8, 3>
-  gpu.func @distribute_store_matrix(%arg0 : memref<32768xi8, 3>) {
-    //CHECK: [[cst:%.+]] = arith.constant dense<1.000000e+00> : vector<32x32xf32>
-    //CHECK: [[mdesc:%.+]] = xegpu.create_mem_desc [[arg0]] : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
-    //CHECK: [[sgid:%.+]] = gpu.subgroup_id : index
-    //CHECK: [[c2:%.+]] = arith.constant 2 : index
-    //CHECK: [[c4:%.+]] = arith.constant 4 : index
-    //CHECK: [[c4_0:%.+]] = arith.constant 4 : index
-    //CHECK: [[id_y:%.+]] = affine.apply #map()[[[sgid]]]
-    //CHECK: [[id_x:%.+]] = affine.apply #map1()[[[sgid]]]
-    //CHECK: [[c32:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_y_0:%.+]] = index.mul [[id_y]], [[c32]]
-    //CHECK: [[c32_1:%.+]] = arith.constant 32 : index
-    //CHECK: [[l_off_x_0:%.+]] = index.mul [[id_x]], [[c32_1]]
-    //CHECK: [[c0:%.+]] = arith.constant 0 : index
-    //CHECK: [[c0_2:%.+]] = arith.constant 0 : index
-    //CHECK: [[l_off_y:%.+]] = arith.addi [[l_off_y_0]], [[c0]] : index
-    //CHECK: [[l_off_x:%.+]] = arith.addi [[l_off_x_0]], [[c0_2]] : index
-    //CHECK: [[c64:%.+]] = arith.constant 64 : index
-    //CHECK: [[mod_y:%.+]] = index.remu [[l_off_y]], [[c64]]
-    //CHECK: [[c128:%.+]] = arith.constant 128 : index
-    //CHECK: [[mod_x:%.+]] = index.remu [[l_off_x]], [[c128]]
-    //CHECK: [[c0_3:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_x:%.+]] = index.add [[mod_x]], [[c0_3]]
-    //CHECK: [[c0_4:%.+]] = arith.constant 0 : index
-    //CHECK: [[off_y:%.+]] = index.add [[mod_y]], [[c0_4]]
-    //CHECK: xegpu.store_matrix [[cst]], [[mdesc]][[[off_y]], [[off_x]]] : vector<32x32xf32>, !xegpu.mem_desc<64x128xf32>, index, index
-    %cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} dense<1.0> : vector<64x128xf32>
-    %mdesc = xegpu.create_mem_desc %arg0 : memref<32768xi8, 3> -> !xegpu.mem_desc<64x128xf32>
-    xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
-
-    gpu.return
-  }
-
 }

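Taken together, this patch funnels create_nd_tdesc, the offset-carrying
load_nd/store_nd/prefetch_nd forms, and load_matrix/store_matrix through a
single helper chain. A condensed sketch of that flow for one op (sg_id_range
handling and error paths elided; names match the helpers introduced above):

    // Sketch: how genOffsetsList derives per-subgroup global offsets.
    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
    Value sgId =
        rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
    // Subgroup-relative offsets, one entry per accessed sub-tile.
    auto descOffsets =
        layout.getOffsets(rewriter, loc, sgId, op.getDistributeShape());
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    for (const auto &sgOffsets : *descOffsets)
      // Right-aligned add of relative offsets onto the original op offsets.
      offsetsList.push_back(genIndexAdds(
          rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets));
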
>From ce07282d88568f25f8c6fb29c7327f1e26624e1d Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 20 Aug 2025 16:56:29 +0000
Subject: [PATCH 5/8] rename isWgLayout to isForWorkgroup

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 16 ++++++++--------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp       |  6 +++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp           |  4 ++--
 .../Dialect/XeGPU/Transforms/XeGPUBlocking.cpp   | 10 +++++-----
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   | 11 ++++++-----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp      |  2 +-
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp    |  4 ++--
 7 files changed, 27 insertions(+), 26 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index de86141ad006a..fe1f127bcd6b6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -184,7 +184,7 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"
   let methods = [
     InterfaceMethod<"Check the availability of workgroup level layouts",
                     "bool",
-                    "isWgLayout">,
+                    "isForWorkgroup">,
     InterfaceMethod<"Get the rank of attribute",
                     "int64_t",
                     "getRank">,
@@ -337,12 +337,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf
   ];
 
   let extraClassDeclaration = [{
-    bool isWgLayout() {
+    bool isForWorkgroup() {
       return getSgLayout() != nullptr;
     }
 
-    bool isSgLayout() {
-      return !isWgLayout();
+    bool isForSubgroup() {
+      return !isForWorkgroup();
     }
 
     int64_t getRank() {
@@ -454,16 +454,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface
       return parent.getOrder();
     }
 
-    bool isWgLayout() const {
+    bool isForWorkgroup() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      return parent.isWgLayout();
+      return parent.isForWorkgroup();
     }
 
-    bool isSgLayout() const {
+    bool isForSubgroup() const {
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
-      return parent.isSgLayout();
+      return parent.isForSubgroup();
     }
 
     int64_t getNumSubgroups() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index de118b7faea4d..9e6702dda2de3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -271,7 +271,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
                                   Value linearId) {
   // delinearizeSubgroupId is only available for
   // workgroup-level layout attribute
-  if (!isWgLayout())
+  if (!isForWorkgroup())
     return failure();
 
   // TODO: handle order attribute
@@ -296,7 +296,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
 FailureOr<SmallVector<SmallVector<Value>>>
 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                        ArrayRef<int64_t> shape) {
-  if (!isWgLayout())
+  if (!isForWorkgroup())
     return failure();
 
   SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
@@ -384,7 +384,7 @@ FailureOr<SmallVector<SmallVector<Value>>>
 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                       ArrayRef<int64_t> shape) {
   assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
-  if (!isWgLayout())
+  if (!isForWorkgroup())
     return failure();
 
   SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0e22af900daf1..ff538ebed4bad 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -938,8 +938,8 @@ LogicalResult ConvertLayoutOp::verify() {
 
   // both input and target layouts should be WgLayout or SgLayout at the same
   // time.
-  if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) &&
-      (!srcLayout.isSgLayout() || !resLayout.isSgLayout()))
+  if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) &&
+      (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup()))
     return emitOpError("expected input layout and target layout be WgLayout or "
                        "SgLayout at the same time.");
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index d82c541f31359..b3144e4c1e55d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -141,7 +141,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const {
     value = (Value)operandOrResult;
 
   xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult);
-  if (layout && layout.isSgLayout()) {
+  if (layout && layout.isForSubgroup()) {
     if (auto inst_data = layout.getInstData())
       return llvm::to_vector_of<int64_t>(inst_data.asArrayRef());
 
@@ -205,12 +205,12 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const {
   bool hasWgLayoutOperands =
       llvm::any_of(op->getOpOperands(), [](OpOperand &opr) {
         xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr);
-        return layout && layout.isWgLayout();
+        return layout && layout.isForWorkgroup();
       });
   bool hasWgLayoutResults =
       llvm::any_of(op->getOpResults(), [](OpResult result) {
         xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result);
-        return layout && layout.isWgLayout();
+        return layout && layout.isForWorkgroup();
       });
   if (hasWgLayoutOperands || hasWgLayoutResults) {
     LDBG() << "skip unrolling for op with workgroup level layout: " << *op;
@@ -272,7 +272,7 @@ void XeGPUBlockingPass::runOnOperation() {
 
         auto layout =
             llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding());
-        if (layout && layout.isWgLayout())
+        if (layout && layout.isForWorkgroup())
           return failure();
 
         int count;
@@ -289,7 +289,7 @@ void XeGPUBlockingPass::runOnOperation() {
         ArrayRef<int64_t> shape = type.getShape();
 
         xegpu::LayoutAttr layout = type.getLayoutAttr();
-        if (layout && layout.isWgLayout())
+        if (layout && layout.isForWorkgroup())
           return failure();
 
         int count;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 76bda64ffac1e..55957d9b264fc 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -59,7 +59,7 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
                    xegpu::DistributeLayoutAttrInterface layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
-  if (layout && layout.isWgLayout()) {
+  if (layout && layout.isForWorkgroup()) {
     SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt().value();
     if (auto maybeSgData = layout.getSgDataAsInt())
       sgShape = *maybeSgData;
@@ -115,7 +115,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
 
   // not applicable to ops without workgroup layout attributes
   xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
-  if (!layout || !layout.isWgLayout())
+  if (!layout || !layout.isForWorkgroup())
     return failure();
 
   Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr);
@@ -249,7 +249,7 @@ struct WgToSgCreateNdOpNoOffset
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
     auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout || !layout.isWgLayout())
+    if (!layout || !layout.isForWorkgroup())
       return failure();
 
     Type elemTy = tdescTy.getElementType();
@@ -637,7 +637,8 @@ struct WgToSgConvertLayoutOp
     xegpu::LayoutAttr input = op.getInputLayout();
     xegpu::LayoutAttr target = op.getTargetLayout();
 
-    if (!input || !target || !input.isWgLayout() || !target.isWgLayout())
+    if (!input || !target || !input.isForWorkgroup() ||
+        !target.isForWorkgroup())
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");
 
@@ -938,7 +939,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   };
 
   auto isLegal = [&](xegpu::DistributeLayoutAttrInterface layout) -> bool {
-    return !layout || !layout.isWgLayout();
+    return !layout || !layout.isForWorkgroup();
   };
 
   target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 19eedbac0f76b..535e2b10353c9 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -40,7 +40,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) {
   auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
   // It only works for subgroup level layout, which only has lane_layout
   // and lane_data, and is to distribute a SIMD code into SIMT code.
-  if (!layout || !layout.isSgLayout())
+  if (!layout || !layout.isForSubgroup())
     return failure();
 
   SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef());
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 8d2fb85655c72..86bb3af326da2 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -82,7 +82,7 @@ struct TestXeGPUUnrollingPatterns
 
             if (auto layout = tdescTy.getLayoutAttr()) {
               auto inst_data = layout.getInstData();
-              if (inst_data && layout.isSgLayout())
+              if (inst_data && layout.isForSubgroup())
                 return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
                                             inst_data.asArrayRef().end());
             }
@@ -239,7 +239,7 @@ struct TestXeGPULayoutInterface
 
     ConversionTarget target(*ctx);
     auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
-      return !layout || !layout.isWgLayout();
+      return !layout || !layout.isForWorkgroup();
     };
 
     target.addDynamicallyLegalOp<vector::StepOp>(

>From 9ae490cf05850458f7f81635ba4be21dc8d26ac1 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 20 Aug 2025 17:37:58 +0000
Subject: [PATCH 6/8] cleanup getNumSubgroups

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 21 ++++++-------------
 1 file changed, 6 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index fe1f127bcd6b6..0fe4e22f50376 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -190,7 +190,12 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"
                     "getRank">,
     InterfaceMethod<"Get the num of effective subgroups",
                     "int64_t",
-                    "getNumSubgroups">,
+                    "getNumSubgroups", (ins), [{
+                        std::optional<SmallVector<int64_t>> sgLayout =
+                            llvm::cast<ConcreteAttr>(tablegen_opaque_val)
+                                .getSgLayoutAsInt();
+                        if (sgLayout.has_value())
+                          return computeProduct(*sgLayout);
+                        return 0;
+                    }], [{}]>,
     InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
                     "std::optional<SmallVector<int64_t>>",
                     "getSgLayoutAsInt">,
@@ -355,13 +360,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf
       return 0;
     }
 
-    int64_t getNumSubgroups() {
-      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
-      if (sgLayout.has_value())
-        return computeProduct(*sgLayout);
-      return 0;
-    }
-
     LayoutAttr dropSgLayoutAndData() {
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
@@ -466,13 +464,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface
       return parent.isForSubgroup();
     }
 
-    int64_t getNumSubgroups() {
-      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
-      if (sgLayout.has_value())
-        return computeProduct(*sgLayout);
-      return 0;
-    }
-
     /// Returns the SgLayout of the attribute, computed by applying
     /// the slice dimensions to the underlying LayoutAttr.
     std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {

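The interface default above replaces two identical copies that previously
lived in LayoutAttr and SliceAttr. For each implementing attribute it
expands to roughly the following (a sketch, not the literal generated code):

    // Sketch: effective subgroup count is the product of sg_layout dims.
    int64_t getNumSubgroups() {
      std::optional<SmallVector<int64_t>> sgLayout = getSgLayoutAsInt();
      if (sgLayout.has_value())
        return computeProduct(*sgLayout); // e.g. sg_layout = [2, 4] -> 8
      return 0; // no subgroup layout present
    }
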
>From 36e3e3d632e36c9fcbc98791389020e088d7285f Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 20 Aug 2025 17:45:39 +0000
Subject: [PATCH 7/8] update comments

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp       | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 55957d9b264fc..c5b497dcc695e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -76,9 +76,15 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
   return std::make_pair(sgShape, count);
 }
 
-// An util helper to generate elementwise addition ops for index computing.
-// lhs and rhs are vectors of Values. If the rank of lhs and rhs doesn't match.
-// left-alignment is performed.
+/// Generates element-wise addition ops of two arrays with automatic
+/// alignment: when the input arrays have different sizes, the shorter array
+/// is right-aligned with the longer array, and the unmatched leading
+/// elements of the longer array are preserved unchanged. This is commonly
+/// used for offset computation, where higher-dimensional offsets need to be
+/// added to lower-dimensional adjustments.
+///
+/// Example:
+///   lhs = [10, 20, 30], rhs = [5, 7]
+///   Result: [10, 25, 37] (20+5, 30+7, with 10 preserved)
 static SmallVector<OpFoldResult>
 genIndexAdds(ConversionPatternRewriter &rewriter, Location loc,
              ArrayRef<OpFoldResult> lhs, ArrayRef<OpFoldResult> rhs) {

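The right-alignment rule can be checked with a plain-integer model of the
helper (a standalone sketch; OpFoldResult values replaced by int64_t and the
IR-building parts dropped):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Model of genIndexAdds' alignment: right-align b against a, keep a's
    // unmatched leading elements, and add the overlapping tails.
    std::vector<int64_t> alignedAdd(std::vector<int64_t> a,
                                    std::vector<int64_t> b) {
      if (a.size() < b.size())
        a.swap(b); // ensure a is at least as long as b
      size_t lead = a.size() - b.size();
      for (size_t i = 0; i < b.size(); ++i)
        a[lead + i] += b[i];
      return a;
    }

    int main() {
      // Matches the doc comment: [10, 20, 30] + [5, 7] == [10, 25, 37].
      for (int64_t v : alignedAdd({10, 20, 30}, {5, 7}))
        std::printf("%lld ", (long long)v);
    }
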
>From 6d0458f2f5f2d5424a4cad7567309c6ababf787e Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 20 Aug 2025 17:49:27 +0000
Subject: [PATCH 8/8] cleanup

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h    |  2 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 10 ++++----
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 24 +++++++++----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  9 ++++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  4 ++--
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 12 +++++-----
 .../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp |  2 +-
 7 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 1d152f0c9ca9a..1481859e94a92 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -24,7 +24,7 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
-class DistributeLayoutAttrInterface;
+class DistributeLayoutAttr;
 class LayoutAttr;
 class SliceAttr;
 } // namespace xegpu
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 0fe4e22f50376..b4d696444cc44 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,7 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"> {
+def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
   let cppNamespace = "::mlir::xegpu";
   let description = [{
     Common trait for all XeGPU layouts.
@@ -203,7 +203,7 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"
                     "std::optional<SmallVector<int64_t>>",
                     "getSgDataAsInt">,
     InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
-                    "xegpu::DistributeLayoutAttrInterface",
+                    "xegpu::DistributeLayoutAttr",
                     "dropSgLayoutAndData">,
     InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
                       indices based on the effective subgroup layout.}],
@@ -220,7 +220,7 @@ def DistributeLayoutAttrInterface: AttrInterface<"DistributeLayoutAttrInterface"
   ];
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterface]> {
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -407,7 +407,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttrInterf
 }
 
 
-def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface]> {
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
   let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
 
   let description = [{
@@ -434,7 +434,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttrInterface
   }];
 
   let parameters = (ins
-    "xegpu::DistributeLayoutAttrInterface": $parent,
+    "xegpu::DistributeLayoutAttr": $parent,
     "DenseI64ArrayAttr": $dims
   );
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 3182552288ca6..f3eaf400e1e4c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -232,8 +232,8 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       return static_cast<unsigned>(MemorySpace::Global);
     }
 
-    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
-      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
     }
 
     ArrayRef<int64_t> getDistributeShape() {
@@ -279,8 +279,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
       return getMixedValues(statics, dynamics, getContext());
     }
 
-    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
-      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
     }
 
     ArrayRef<int64_t> getDistributeShape() {
@@ -377,8 +377,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
       return getMixedValues(statics, dynamics, getContext());
     }
 
-    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
-      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
     }
 
     ArrayRef<int64_t> getDistributeShape() {
@@ -469,8 +469,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
       return getMixedValues(statics, dynamics, getContext());
     }
 
-    xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
-      return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
+    xegpu::DistributeLayoutAttr getLayoutAttr() {
+      return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
     }
 
     ArrayRef<int64_t> getDistributeShape() {
@@ -1211,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<DistributeLayoutAttrInterface>:$layout
+    OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
   let assemblyFormat = [{
@@ -1236,7 +1236,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 
   let builders = [
     OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
-                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
+                    "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
@@ -1259,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<DistributeLayoutAttrInterface>:$layout
+    OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
                           prop-dict attr-dict `` `:` type(operands)}];
@@ -1278,7 +1278,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   }];
   let builders = [
     OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttrInterface": $layout)>,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
   ];
   let extraClassDeclaration = [{
     SmallVector<OpFoldResult> getMixedOffsets() {
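
Note (illustrative sketch, not part of the patch): each accessor above now
returns the DistributeLayoutAttr interface, so callers stay agnostic of the
concrete attribute. Assuming some `loadOp` with a layout in scope:

  if (xegpu::DistributeLayoutAttr layout = loadOp.getLayoutAttr()) {
    if (layout.isForWorkgroup())
      // Drop the workgroup-level fields once distribution has happened;
      // the result is again interface-typed.
      layout = layout.dropSgLayoutAndData();
  }
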
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 9e6702dda2de3..a2d708be0e937 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -290,7 +290,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return affine::delinearizeIndex(builder, loc, linearId, dims);
 }
 
-/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// Implements DistributeLayoutAttr::getOffsets to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// LayoutAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
@@ -323,8 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
 //===----------------------------------------------------------------------===//
 LogicalResult
 SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
-                  xegpu::DistributeLayoutAttrInterface parent,
-                  DenseI64ArrayAttr dims) {
+                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
   if (!parent || !dims)
     return emitError() << "expected parent layout and dims attribute";
 
@@ -342,7 +341,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
 }
 
 SliceAttr SliceAttr::flatten() const {
-  xegpu::DistributeLayoutAttrInterface parent = getParent();
+  xegpu::DistributeLayoutAttr parent = getParent();
   SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
 
   while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -377,7 +376,7 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
   return parent.delinearizeSubgroupId(builder, loc, linearId);
 }
 
-/// Implements DistributeLayoutAttrInterface::getOffsets to generate
+/// Implements DistributeLayoutAttr::getOffsets to generate
 /// instructions for computing multi-dimensional offsets when distributed by
 /// SliceAttr.
 FailureOr<SmallVector<SmallVector<Value>>>
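
Note (hedged sketch, not from the patch): the delegation described by the
two comments above -- SliceAttr first folds nested slices via flatten() and
then forwards to its parent -- looks roughly like this; the free-function
name is invented for illustration:

  static FailureOr<SmallVector<Value>>
  delinearizeViaParent(OpBuilder &b, Location loc, Value linearId,
                       xegpu::SliceAttr slice) {
    // flatten() collapses a chain of slices so the parent is a LayoutAttr.
    xegpu::SliceAttr flat = slice.flatten();
    return flat.getParent().delinearizeSubgroupId(b, loc, linearId);
  }
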
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ff538ebed4bad..c8d180b973f05 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -984,7 +984,7 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
                          TypedValue<MemDescType> memDesc,
                          llvm::ArrayRef<OpFoldResult> offsets,
-                         DistributeLayoutAttrInterface layout) {
+                         DistributeLayoutAttr layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
@@ -1014,7 +1014,7 @@ LogicalResult LoadMatrixOp::verify() {
 void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
                           TypedValue<MemDescType> memDesc,
                           llvm::ArrayRef<OpFoldResult> offsets,
-                          DistributeLayoutAttrInterface layout) {
+                          DistributeLayoutAttr layout) {
   llvm::SmallVector<Value> dynamicOffsets;
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
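
Note (illustrative only): with both builders retyped, a call site can pass
either concrete layout attribute where the interface is expected. Assuming
`rewriter`, `loc`, `resTy`, `memDesc`, `offsets`, and `sliceAttr` in scope:

  // Either a LayoutAttr or a SliceAttr converts to the interface type.
  xegpu::DistributeLayoutAttr layout = sliceAttr;
  auto load = rewriter.create<xegpu::LoadMatrixOp>(loc, resTy, memDesc,
                                                   offsets, layout);
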
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c5b497dcc695e..09aa1e61c20e6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -56,7 +56,7 @@ static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange,
 
 static std::pair<SmallVector<int64_t>, int>
 getSgShapeAndCount(ArrayRef<int64_t> shape,
-                   xegpu::DistributeLayoutAttrInterface layout) {
+                   xegpu::DistributeLayoutAttr layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
   if (layout && layout.isForWorkgroup()) {
@@ -120,7 +120,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
     return failure();
 
   // not applicable to ops without workgroup layout attributes
-  xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
   if (!layout || !layout.isForWorkgroup())
     return failure();
 
@@ -217,7 +217,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     xegpu::TensorDescType tdescTy = op.getType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
     Type elemTy = tdescTy.getElementType();
-    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     auto newTdescTy =
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
@@ -806,7 +806,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
     VectorType valueTy = op.getRes().getType();
     Type elemTy = valueTy.getElementType();
 
-    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResTy = VectorType::get(sgShape, elemTy);
     SmallVector<Value> newOps;
@@ -832,7 +832,7 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
     if (failed(genOffsetsList(rewriter, op, offsetsList)))
       return failure();
 
-    xegpu::DistributeLayoutAttrInterface layout = op.getLayoutAttr();
+    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
     for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
       rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(),
                                             offsets,
@@ -944,7 +944,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return xegpu::TensorDescType();
   };
 
-  auto isLegal = [&](xegpu::DistributeLayoutAttrInterface layout) -> bool {
+  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
     return !layout || !layout.isForWorkgroup();
   };
 
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 86bb3af326da2..200323c7a4e51 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -156,7 +156,7 @@ struct TestXeGPUUnrollingPatterns
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 // Test pattern for distributing vector::StepOp from workgroup to subgroup.
-// Validates DistributeLayoutAttrInterface interfaces for offset computation
+// Validates the DistributeLayoutAttr interface for the offset computation
 // abstraction between LayoutAttr and SliceAttr.
 class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
   using OpConversionPattern<vector::StepOp>::OpConversionPattern;


