[Mlir-commits] [mlir] c962234 - [mlir][xegpu] Add definition of SliceAttr (#150146)

llvmlistbot at llvm.org
Fri Aug 8 09:27:20 PDT 2025


Author: Chao Chen
Date: 2025-08-08T11:27:17-05:00
New Revision: c96223434c64d32c3f397d20a8ed1d9749aae441

URL: https://github.com/llvm/llvm-project/commit/c96223434c64d32c3f397d20a8ed1d9749aae441
DIFF: https://github.com/llvm/llvm-project/commit/c96223434c64d32c3f397d20a8ed1d9749aae441.diff

LOG: [mlir][xegpu] Add definition of SliceAttr (#150146)

---------

Co-authored-by: Charitha Saumya <136391709+charithaintc@users.noreply.github.com>

Added: 
    mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
    mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/invalid.mlir
    mlir/test/Dialect/XeGPU/layout.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
    mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
index 3f8cac4dc07c3..728f1aa859061 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
 mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
 add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
+mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
+add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 8e2784f40ad39..3592da4c46364 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
@@ -22,17 +23,19 @@
 namespace mlir {
 namespace xegpu {
 class TensorDescType;
+class LayoutAttr;
+class SliceAttr;
 } // namespace xegpu
 } // namespace mlir
 
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+
 #define GET_ATTRDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
 #define GET_TYPEDEF_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
-
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-
 #define GET_OP_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
 

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 64eb21cbc3c4c..1f420c13ebae0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,38 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
+def LayoutTrait: AttrInterface<"LayoutTrait"> {
+  let cppNamespace = "::mlir::xegpu";
+  let description = [{
+    Common trait for all XeGPU layouts.
+  }];
+
+  let methods = [
+    InterfaceMethod<"Get the rank of the attribute",
+                    "int64_t",
+                    "getRank">,
+    InterfaceMethod<"Get the SgLayout field of the attribute as an integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getSgLayoutAsInt">,
+    InterfaceMethod<"Get the SgData field of the attribute as an integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getSgDataAsInt">,
+    InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
+                      indices based on the effective subgroup layout.}],
+                    "FailureOr<SmallVector<Value>>",
+                    "delinearizeSubgroupId",
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
+    InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
+                      assigned to a subgroup identified by linearId. The shape parameter
+                      represents the workgroup-level problem size. Each subgroup may access
+                      multiple blocks according to round-robin distribution rules.}],
+                    "FailureOr<SmallVector<SmallVector<Value>>>",
+                    "getOffsets",
+                    (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
+  ];
+}
+
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
   let summary = [{
     Describes the data distribution to subgroups and work-items for a tensor
     specified by the tensor descriptor.
@@ -330,12 +361,143 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
       return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
                              getLaneLayout(), getLaneData(), getOrder());
     }
+
+    std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
+      if (DenseI32ArrayAttr layout = getSgLayout())
+        return llvm::to_vector_of<int64_t>(layout.asArrayRef());
+      return std::nullopt;
+    }
+
+    std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
+      if (DenseI32ArrayAttr data = getSgData())
+        return llvm::to_vector_of<int64_t>(data.asArrayRef());
+      return std::nullopt;
+    }
+
+    /// Delinearizes a linear subgroup ID into its multidimensional indices
+    /// based on the effective subgroup layout.
+    FailureOr<SmallVector<Value>>
+    delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+
+    /// Generates instructions to compute multidimensional offsets for blocks
+    /// assigned to a subgroup identified by linearId. The shape parameter
+    /// represents the workgroup-level problem size. Each subgroup may access
+    /// multiple blocks according to round-robin distribution rules.
+    FailureOr<SmallVector<SmallVector<Value>>>
+    getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+
   }];
 
   let assemblyFormat = "`<` struct(params) `>`";
   let genVerifyDecl = 1;
 }
 
+
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+  let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
+
+  let description = [{
+    Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
+    However, whereas LayoutAttr requires the data to have the same rank as the attribute,
+    SliceAttr permits the data to have a lower rank. In this case, compute units along the
+    sliced dimensions (given by `$dims`) share the data, provided that the remaining rank
+    matches the data rank. SliceAttr is commonly used by operations such as
+    vector.multi_reduction and vector.broadcast.
+
+    Example:
+    ```
+    #l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+    #r = #xegpu.slice<#l, dims = [0]>
+
+    %exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
+    %red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
+    %bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
+    ```
+    In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to
+    a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a
+    single reduction result of type vector<32xf32>.
+
+  }];
+
+  let parameters = (ins
+    "xegpu::LayoutTrait": $parent,
+    "DenseI64ArrayAttr": $dims
+  );
+
+  let extraClassDeclaration = [{
+
+    int64_t getRank() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      return parent.getRank() - attr.getDims().size();
+    }
+
+    DenseI32ArrayAttr getOrder() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      return parent.getOrder();
+    }
+
+    bool isWgLayout() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      return parent.isWgLayout();
+    }
+
+    bool isSgLayout() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      return parent.isSgLayout();
+    }
+
+    /// Returns the SgLayout of the attribute, computed by applying
+    /// the slice dimensions to the underlying LayoutAttr.
+    std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      if (auto layout = parent.getSgLayoutAsInt()) {
+        ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+        return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
+      }
+      return std::nullopt;
+    }
+
+    /// Returns the SgData of the attribute, computed by applying
+    /// the slice dimensions to the underlying LayoutAttr.
+    std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
+      SliceAttr attr = flatten();
+      auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+      if (auto data = parent.getSgDataAsInt()) {
+        ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+        return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
+      }
+      return std::nullopt;
+    }
+
+    /// Flattens a nested SliceAttr. E.g., for the 2-level nested SliceAttr
+    /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
+    /// it coalesces the two slice operations and returns the simplified SliceAttr
+    /// #xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0, 1]>
+    SliceAttr flatten() const;
+
+    /// Delinearizes a linear subgroup ID into its multidimensional indices
+    /// based on the effective subgroup layout.
+    FailureOr<SmallVector<Value>>
+    delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+
+    /// Generates instructions to compute multidimensional offsets for blocks
+    /// assigned to a subgroup identified by linearId. The shape parameter
+    /// represents the workgroup-level problem size. Each subgroup may access
+    /// multiple blocks according to round-robin distribution rules.
+    FailureOr<SmallVector<SmallVector<Value>>>
+    getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+
+  }];
+
+  let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
+  let genVerifyDecl = 1;
+}
+
 def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
   let summary = [{Specifies a half-open range}];
   let description = [{

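A note on the flatten() declaration above: nested slices compose, and an outer slice's dims index into the already-sliced rank, so the slices must be applied innermost-first to recover which root dimensions survive. The following standalone sketch mirrors that composition, with plain std:: containers standing in for the MLIR attributes (helper names here are illustrative, not part of the patch):

```
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

using Dims = std::vector<int64_t>;

// Same semantics as XeGPUDialect::slice: drop entries whose index is in dims.
static Dims slice(const Dims &shape, const Dims &dims) {
  Dims result;
  for (size_t i = 0; i < shape.size(); ++i)
    if (std::find(dims.begin(), dims.end(), (int64_t)i) == dims.end())
      result.push_back(shape[i]);
  return result;
}

// Composes a stack of slice dims (outermost first) over a layout of the
// given rank into one equivalent dims list against that layout.
static Dims flattenDims(int64_t rank, const std::vector<Dims> &slicedDims) {
  Dims indices(rank);
  std::iota(indices.begin(), indices.end(), 0);
  // Apply the slices innermost-first to find the surviving root indices.
  Dims remaining = indices;
  for (auto it = slicedDims.rbegin(); it != slicedDims.rend(); ++it)
    remaining = slice(remaining, *it);
  // Everything that did not survive is a flattened sliced dim.
  return slice(indices, remaining);
}

int main() {
  // slice<slice<layout<sg_layout=[4,8,12]>, dims=[0]>, dims=[0]>
  //   == slice<layout<sg_layout=[4,8,12]>, dims=[0,1]>
  assert(flattenDims(3, {{0}, {0}}) == Dims({0, 1}));
  return 0;
}
```
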
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 549018b61d6fb..76d58e5ea2424 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
       /// Checks if the given shape can be evenly distributed based on the layout
       /// and data factors provided by the LayoutAttr.
       static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+
+      /// Drops the shape entries at the specified dims and returns the rest. E.g.,
+      /// for shape = [32, 64, 8] and dims = [0, 2], it returns [64].
+      template<typename T, typename U>
+      static llvm::SmallVector<T> slice(llvm::ArrayRef<T> shape, llvm::ArrayRef<U> dims) {
+        llvm::SmallVector<T> result;
+        for (auto [i, v]: llvm::enumerate(shape)) {
+          if (!llvm::is_contained(dims, i))
+            result.push_back(v);
+        }
+        return result;
+      }
     }];
 }
 

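The slice helper above keeps exactly the entries whose index is not listed in dims (out-of-range dims simply never match, so they are ignored). An equivalent standalone sketch, with the standard library substituted for llvm::SmallVector/ArrayRef:

```
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Drops the entries of `shape` whose index appears in `dims` and returns
// the remaining entries in order.
static std::vector<int64_t> slice(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &dims) {
  std::vector<int64_t> result;
  for (size_t i = 0; i < shape.size(); ++i)
    if (std::find(dims.begin(), dims.end(), static_cast<int64_t>(i)) ==
        dims.end())
      result.push_back(shape[i]);
  return result;
}

int main() {
  // Matches the documented example: shape = [32, 64, 8], dims = [0, 2].
  assert(slice({32, 64, 8}, {0, 2}) == std::vector<int64_t>({64}));
  return 0;
}
```
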
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97ccfdf6d..7c6a4f37db9af 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,11 +7,14 @@ add_mlir_dialect_library(MLIRXeGPUDialect
 
   DEPENDS
   MLIRXeGPUIncGen
+  MLIRXeGPUAttrInterfaceIncGen
   MLIRXeGPUAttrsIncGen
   MLIRXeGPUEnumsIncGen
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRIndexDialect
+  MLIRAffineUtils
   MLIRArithUtils
   MLIRDialectUtils
   MLIRIR

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3c0ca114a62d4..d997296a22c20 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,12 +6,16 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
 
 using std::optional;
 
@@ -33,6 +37,57 @@ void XeGPUDialect::initialize() {
       >();
 }
 
+/// Generates instructions to compute offsets for a subgroup identified by
+/// its multidimensional indices (sgId), using the specified subgroup layout
+/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
+/// dimensions (sizePerWg).
+static SmallVector<SmallVector<Value>>
+genOffsetsComputingInsts(OpBuilder &builder, Location loc,
+                         SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
+                         ArrayRef<int64_t> sizePerSg,
+                         ArrayRef<int64_t> sizePerWg) {
+
+  SmallVector<SmallVector<Value>> offsets;
+
+  // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
+  SmallVector<Value> localOffsets = llvm::map_to_vector(
+      llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+        return builder.createOrFold<index::MulOp>(
+            loc, std::get<0>(t),
+            builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+      });
+
+  // distUnit[i] is the minimum value between sizePerWg[i] and
+  // sgLayout[i] * sizePerSg[i]
+  SmallVector<int64_t> distUnit = llvm::map_to_vector(
+      llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
+      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+  for (SmallVector<int64_t> unitOffs :
+       StaticTileOffsetRange(sizePerWg, distUnit)) {
+    SmallVector<Value> base =
+        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
+          return builder.create<arith::ConstantIndexOp>(loc, d);
+        });
+
+    SmallVector<Value> adds = llvm::map_to_vector(
+        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
+          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
+                                                     std::get<1>(t));
+        });
+
+    SmallVector<Value> mods = llvm::map_to_vector(
+        llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+          return builder.createOrFold<index::RemUOp>(
+              loc, std::get<0>(t),
+              builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+        });
+
+    offsets.push_back(mods);
+  }
+  return offsets;
+}
+
 // Checks if the given shape can be evenly distributed based on the layout
 // and data factors provided by the LayoutAttr.
 bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
@@ -211,6 +266,148 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   return success();
 }
 
+FailureOr<SmallVector<Value>>
+LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+                                  Value linearId) {
+  // delinearizeSubgroupId is only available for
+  // workgroup-level layout attribute
+  if (!isWgLayout())
+    return failure();
+
+  // TODO: handle order attribute
+  auto hasDefaultOrder = [&]() {
+    DenseI32ArrayAttr order = getOrder();
+    return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
+                         llvm::reverse(order.asArrayRef())));
+  };
+  if (!hasDefaultOrder())
+    return mlir::emitError(loc, "order attribute is currently not supported.");
+
+  auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value {
+    return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+  });
+
+  return affine::delinearizeIndex(builder, loc, linearId, dims);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by LayoutAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+                       ArrayRef<int64_t> shape) {
+  if (!isWgLayout())
+    return failure();
+
+  SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+  SmallVector<int64_t> sgShape;
+  if (auto maybeSgShape = getSgDataAsInt())
+    sgShape = maybeSgShape.value();
+  else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+    sgShape = derivedShape.value();
+  else
+    return failure();
+
+  // delinearize Ids
+  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  if (failed(maybeIds))
+    return failure();
+  SmallVector<Value> sgIds = *maybeIds;
+
+  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+                                  shape);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+  if (!parent || !dims)
+    return emitError() << "expected parent layout and dims attribute";
+
+  int64_t rank = parent.getRank();
+
+  // check every element in dims is unique and smaller than rank
+  llvm::SmallDenseSet<int64_t> seen;
+  for (int64_t dim : dims.asArrayRef()) {
+    if (dim < 0 || dim >= rank)
+      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
+    if (!seen.insert(dim).second)
+      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
+  }
+  return success();
+}
+
+SliceAttr SliceAttr::flatten() const {
+  xegpu::LayoutTrait parent = getParent();
+  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
+
+  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
+    parent = sliceAttr.getParent();
+    slicedDims.push_back(sliceAttr.getDims());
+  }
+
+  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
+  SmallVector<int64_t> indices =
+      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
+
+  // get the remaining (flattened) dims by applying all the slice ops in slicedDims
+  SmallVector<int64_t> remainingDims(indices);
+  for (auto dim : llvm::reverse(slicedDims))
+    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
+                                        dim.asArrayRef());
+
+  // get the flattened sliced dims by slicing away the remaining dims
+  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
+      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
+
+  return xegpu::SliceAttr::get(
+      getContext(), layoutAttr,
+      DenseI64ArrayAttr::get(getContext(), flattenedDims));
+}
+
+FailureOr<SmallVector<Value>>
+SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+                                 Value linearId) {
+  SliceAttr attr = flatten();
+  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+  return parent.delinearizeSubgroupId(builder, loc, linearId);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by SliceAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+                      ArrayRef<int64_t> shape) {
+  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+  if (!isWgLayout())
+    return failure();
+
+  SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+  SmallVector<int64_t> sgShape;
+  if (auto maybeSgShape = getSgDataAsInt())
+    sgShape = maybeSgShape.value();
+  else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+    sgShape = derivedShape.value();
+  else
+    return failure();
+
+  // delinearize Ids
+  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+  if (failed(maybeIds))
+    return failure();
+
+  // The effective sgIds for offsets computing correspond
+  // to the dims that are not sliced.
+  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
+  SmallVector<Value> sgIds =
+      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
+
+  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+                                  shape);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//

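To follow the arithmetic in genOffsetsComputingInsts without reading the emitted IR, here is a constant-folded sketch that runs the same computation on plain integers (2-D only for brevity; the in-tree StaticTileOffsetRange makes it rank-generic, and the input values below are illustrative, chosen so that dim 0 is distributed round-robin):

```
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using Vec = std::vector<int64_t>;

// localOffset[i] = sgId[i] * sizePerSg[i];
// distUnit[i]    = min(sizePerWg[i], sgLayout[i] * sizePerSg[i]);
// one offset per distribution unit, wrapped back into sizePerWg.
static std::vector<Vec> computeOffsets(const Vec &sgId, const Vec &sgLayout,
                                       const Vec &sizePerSg,
                                       const Vec &sizePerWg) {
  Vec local = {sgId[0] * sizePerSg[0], sgId[1] * sizePerSg[1]};
  Vec distUnit = {std::min(sizePerWg[0], sgLayout[0] * sizePerSg[0]),
                  std::min(sizePerWg[1], sgLayout[1] * sizePerSg[1])};
  std::vector<Vec> offsets;
  for (int64_t y = 0; y < sizePerWg[0]; y += distUnit[0])
    for (int64_t x = 0; x < sizePerWg[1]; x += distUnit[1])
      offsets.push_back({(y + local[0]) % sizePerWg[0],
                         (x + local[1]) % sizePerWg[1]});
  return offsets;
}

int main() {
  // sg_layout = [8, 4], sg_data = [16, 64] over a 256x128 workgroup shape:
  // subgroup [1, 2] owns two blocks, at [16, 0] and [144, 0].
  for (const Vec &o : computeOffsets({1, 2}, {8, 4}, {16, 64}, {256, 128}))
    std::printf("[%lld, %lld]\n", (long long)o[0], (long long)o[1]);
  return 0;
}
```

Because distUnit[0] = min(256, 8 * 16) = 128 is half of sizePerWg[0], each subgroup receives two row-blocks 128 rows apart, which is the round-robin case the updated tests exercise.
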
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..fc11fa810a35c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -931,6 +931,9 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
 } // namespace xegpu
 } // namespace mlir
 
+namespace mlir {
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
+} // namespace mlir
 #include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
 #define GET_OP_CLASSES
 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..4a5525c8abb30 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -125,39 +125,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
 struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
 
-  // Calculate offset for each subgroup
-  static SmallVector<OpFoldResult>
-  calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
-                         const SmallVector<OpFoldResult> &originalOffsets,
-                         const SmallVector<Value> &localOffset,
-                         const SmallVector<int64_t> &distUnitBaseAddr,
-                         const SmallVector<int64_t> &distUnitShape) {
-    assert(localOffset.size() == distUnitBaseAddr.size() &&
-           "localOffset and distUnitBaseAddr must have the same rank");
-
-    SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
-                                            originalOffsets.end());
-    size_t rank = localOffset.size();
-    for (size_t i = 0; i < rank; ++i) {
-      size_t dimIdx = originalOffsets.size() - rank + i;
-      Value constOffset =
-          arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]);
-      Value offset =
-          rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
-      Value modValue =
-          arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]);
-      Value offsetMod =
-          rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
-      Value origOffset = getValueOrCreateConstantIndexOp(
-          rewriter, loc, originalOffsets[dimIdx]);
-      Value globalOffset =
-          rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
-      globalOffsets[dimIdx] = globalOffset;
-    }
-
-    return globalOffsets;
-  }
-
   LogicalResult
   matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -177,74 +144,56 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
       return rewriter.notifyMatchFailure(
           op, "sgLayout attribute is required in layout");
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-
-    // TODO : Handle order attribute
     // Get the subgroup ID
-    auto linearSgId =
+    Value linearSgId =
         gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
 
-    // Create constants for layout dimensions
-    SmallVector<Value> sgLayoutDim(sgLayout.size());
-    SmallVector<Value> sgDataDim(sgShape.size());
-
-    for (size_t i = 0; i < sgLayout.size(); i++) {
-      sgLayoutDim[i] =
-          arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
-      sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
-    }
-
     int64_t startOfRange = -1, endOfRange = -1;
     bool sgIdRangeSpecified =
         isSgIdRangeSpecified(op, startOfRange, endOfRange);
 
-    Value adjustedSgId = linearSgId;
     if (sgIdRangeSpecified) {
       int64_t sgCount = endOfRange - startOfRange;
       if (computeProduct(sgLayout) != sgCount)
         return rewriter.notifyMatchFailure(
             op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get the adjusted
-      // sg id
+      // Subtract startOfRange from the original subgroup id to get
+      // the adjusted sg id
       Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      adjustedSgId =
+          rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+      linearSgId =
           rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
     }
 
-    auto deLinearizeSgId =
-        affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
-    if (failed(deLinearizeSgId))
+    auto maybeTdescOffsets =
+        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+    if (failed(maybeTdescOffsets))
       return failure();
-    SmallVector<Value> sgIds = *deLinearizeSgId;
-
-    // Calculate distribution unit shape and local offsets for subgroup
-    SmallVector<int64_t> distUnitShape(sgLayout.size());
-    SmallVector<Value> localOffset(sgLayout.size());
-    for (size_t i = 0; i < sgLayout.size(); i++) {
-      distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
-      localOffset[i] =
-          rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
-    }
-
-    SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
 
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     xegpu::TensorDescType newTdescTy =
         xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                    layout.dropSgLayoutAndData());
+
     SmallVector<Value> newCreateNdOps;
-    for (SmallVector<int64_t> distUnitBaseAddr :
-         StaticTileOffsetRange(wgShape, distUnitShape)) {
-      SmallVector<OpFoldResult> globalOffsets =
-          calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
-                                 distUnitBaseAddr, distUnitShape);
-
-      auto newCreateNdOp = xegpu::CreateNdDescOp::create(
-          rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
+    SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+
+    for (auto tdescOffsets : *maybeTdescOffsets) {
+      SmallVector<OpFoldResult> sgOffsets;
+      size_t rank = tdescOffsets.size();
+      for (size_t i = 0; i < rank; i++) {
+        size_t idx = wgOffsets.size() - rank + i;
+        Value add = rewriter.createOrFold<index::AddOp>(
+            loc, tdescOffsets[i],
+            getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+        sgOffsets.push_back(add);
+      }
+
+      auto newOp = xegpu::CreateNdDescOp::create(
+          rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
           op.getMixedSizes(), op.getMixedStrides());
-      newCreateNdOps.push_back(newCreateNdOp);
+      newCreateNdOps.push_back(newOp);
     }
-
     rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
     return success();
   }

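The rewritten WgToSgCreateNdOp now delegates all subgroup-offset math to LayoutTrait::getOffsets and only needs to add the results to the op's original mixed offsets, aligned to the trailing dimensions. A minimal sketch of that alignment step, with plain integers standing in for the emitted index.add ops:

```
#include <cassert>
#include <cstdint>
#include <vector>

// Adds the rank-aligned subgroup offsets to the trailing dims of the
// original workgroup-level offsets; leading dims pass through unchanged.
static std::vector<int64_t> combine(const std::vector<int64_t> &wgOffsets,
                                    const std::vector<int64_t> &sgOffsets) {
  std::vector<int64_t> result(wgOffsets);
  size_t rank = sgOffsets.size();
  for (size_t i = 0; i < rank; ++i) {
    size_t idx = wgOffsets.size() - rank + i;
    result[idx] = wgOffsets[idx] + sgOffsets[i];
  }
  return result;
}

int main() {
  // A 3-D source indexed with 2-D tile offsets: only the last two dims move.
  assert(combine({5, 0, 0}, {16, 64}) == std::vector<int64_t>({5, 16, 64}));
  return 0;
}
```
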
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..948f136d78709 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -743,3 +743,22 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
         #xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2], order = [0, 1, 2]>>
   return
 }
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{repeated dim (2) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [2, 2]>
+func.func @slice_attr_repeat_dim() {
+  %offsets = arith.constant {layout_result_0 = #s} dense<8> : vector<16x8xindex>
+  return
+}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{invalid dim (3) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [3]>
+func.func @slice_attr_invalid_dim() {
+  %offsets = arith.constant {layout_result_0 = #s} dense<8> : vector<16x8xindex>
+  return
+}
+

diff --git a/mlir/test/Dialect/XeGPU/layout.mlir b/mlir/test/Dialect/XeGPU/layout.mlir
index 017dacc8d629a..e4b4e22e5cf97 100644
--- a/mlir/test/Dialect/XeGPU/layout.mlir
+++ b/mlir/test/Dialect/XeGPU/layout.mlir
@@ -50,4 +50,27 @@ gpu.func @convert_layout_wg(%a: vector<32x64xf16>) {
   gpu.return
 }
 
+gpu.func @slice_attr() {
+  //CHECK: arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
+  %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
+  gpu.return
+}
+
+gpu.func @nested_slice_attr() {
+  //CHECK: arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>, dims = [1]>} dense<8> : vector<16xindex>
+  %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>, dims = [1]>} dense<8> : vector<16xindex>
+  gpu.return
+}
+
+gpu.func @softmax_dim_0(%arg0: vector<256x128xf32>) -> vector<256x128xf32> {
+  %cst = arith.constant dense<0.000000e+00> : vector<128xf32>
+  %0 = math.exp %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xf32>
+  //CHECK: vector.multi_reduction <add>, {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
+  %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
+  //CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<128xf32> to vector<256x128xf32>
+  %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<128xf32> to vector<256x128xf32>
+  %3 = arith.divf %0, %2 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xf32>
+  gpu.return %3 : vector<256x128xf32>
+}
+
 }

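For the softmax_dim_0 test above, the sliced layout's effective factors can be checked by hand (illustrative arithmetic only, not part of the test):

```
#include <cassert>
#include <cstdint>

int main() {
  // #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>
  // has effective sgLayout = [4] and sgData = [32], so the vector<128xf32>
  // reduction result splits into 128 / 32 = 4 chunks, one per column of
  // subgroups; the 8 subgroups within a column share the same chunk.
  int64_t sgLayout = 4, sgData = 32, wgShape = 128;
  assert(wgShape / sgData == sgLayout);
  return 0;
}
```
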
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
new file mode 100644
index 0000000000000..547c7355e00c6
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
+
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 8)>
+gpu.module @test {
+  gpu.func @slice_attr() -> vector<128xindex> {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+    gpu.return %step : vector<128xindex>
+  }
+
+  gpu.func @nested_slice_attr() -> vector<128xindex> {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
+    //CHECK: [[c32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+    //CHECK: [[c0:%.+]] = arith.constant 0 : index
+    //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+    //CHECK: [[c128:%.+]] = arith.constant 128 : index
+    //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+    //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+    //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+    //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+    %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
+    gpu.return %0 : vector<128xindex>
+  }
+
+}
\ No newline at end of file

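The affine map in the CHECK lines falls out of delinearizing the linear subgroup id over sg_layout = [4, 8] and then dropping the sliced dim. In plain integers, with an assumed id of 13:

```
#include <cassert>
#include <cstdint>

int main() {
  // sg_layout = [4, 8]; ids are row-major, so for linear id 13:
  const int64_t layoutX = 8;
  int64_t linearId = 13;
  int64_t y = linearId / layoutX; // affine.apply floordiv 8 -> 1
  int64_t x = linearId % layoutX; // mod 8                   -> 5
  assert(y == 1 && x == 5);
  // With dims = [1] sliced away, only y drives the 1-D offsets, which is
  // why a single affine_map<()[s0] -> (s0 floordiv 8)> appears above.
  return 0;
}
```
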
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 628a4857d1253..e5cc65e6bd3d7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,5 +1,8 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+#map = affine_map<()[s0] -> (s0 floordiv 4)>
+#map1 = affine_map<()[s0] -> (s0 mod 4)>
+
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: create_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -12,6 +15,30 @@ gpu.module @test_round_robin_assignment {
       gpu.return
     }
 
+  // CHECK-LABEL: create_nd_tdesc_with_shared_data
+  // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
+    //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
+    //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
+    //CHECK: [[C16:%.+]] = arith.constant 16 : index
+    //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
+    //CHECK: [[C64:%.+]] = arith.constant 64 : index
+    //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
+    //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
+    //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
+    //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+    gpu.return
+  }
+
   // CHECK-LABEL: load_nd_tdesc
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
   gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d4b00372bc193..180ba8a162c9f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,34 +4,26 @@
 //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: create_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
   gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
-  // CHECK: %[[SGID:.*]] = gpu.subgroup_id
-  // CHECK: %[[C8:.*]] = arith.constant 8 : index
-  // CHECK: %[[C32:.*]] = arith.constant 32 : index
-  // CHECK: %[[C4:.*]] = arith.constant 4 : index
-  // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
-  // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
-  // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
-  // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
-  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
-  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[C256:.*]] = arith.constant 256 : index
-  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
-  // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
-  // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
-  // CHECK: %[[C128:.*]] = arith.constant 128 : index
-  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
-  // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-  // CHECK: gpu.return
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-  gpu.return
+    //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+    //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+    //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+    //CHECK: [[C32:%.+]] = arith.constant 32 : index
+    //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+    //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
+    //CHECK: [[C0:%.+]] = arith.constant 0 : index
+    //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+    //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+    //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+    //CHECK: [[C256:%.+]] = arith.constant 256 : index
+    //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+    //CHECK: [[C128:%.+]] = arith.constant 128 : index
+    //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+    //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
   }
 
   // CHECK-LABEL: load_nd_tdesc
@@ -347,7 +339,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
   // CHECK-LABEL: @subgroup_id_range_nested_if
   gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
     %sg_id = gpu.subgroup_id : index
-    %c1 = arith.constant 1 : i1 
+    %c1 = arith.constant 1 : i1
     %c3 = arith.constant 3 : index
     %c32 = arith.constant 32 : index
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>

diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c6245b637c2a7..3bea8efcdb0ae 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -7,11 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -147,12 +150,118 @@ struct TestXeGPUUnrollingPatterns
   }
 };
 
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "test-xegpu-layout-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+// Test pattern for distributing vector::StepOp from workgroup to subgroup.
+// Validates LayoutTrait interfaces for offset computation abstraction between
+// LayoutAttr and SliceAttr.
+class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto layoutName = xegpu::getLayoutName(op->getResult(0));
+    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+    if (!sliceAttr || sliceAttr.getRank() != 1)
+      return failure();
+
+    std::optional<SmallVector<int64_t>> sgShape = sliceAttr.getSgDataAsInt();
+    if (!sgShape)
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    auto wgShape = type.getShape();
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(maybeOffsets))
+      return failure();
+
+    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+    Value base = vector::StepOp::create(rewriter, loc, newTy);
+    SmallVector<Value> newOps;
+    for (auto offsets : *maybeOffsets) {
+      Value bcast =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+      newOps.push_back(add);
+    }
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+struct TestXeGPULayoutInterface
+    : public PassWrapper<TestXeGPULayoutInterface,
+                         OperationPass<gpu::GPUModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPULayoutInterface)
+
+  StringRef getArgument() const final { return "test-xegpu-layout-interface"; }
+
+  StringRef getDescription() const final {
+    return "Test the implementation of XeGPU Layout interfaces";
+  }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect>();
+    registry.insert<memref::MemRefDialect>();
+    registry.insert<xegpu::XeGPUDialect>();
+    registry.insert<vector::VectorDialect>();
+    registry.insert<index::IndexDialect>();
+  }
+
+  TestXeGPULayoutInterface() = default;
+  TestXeGPULayoutInterface(const TestXeGPULayoutInterface &pass)
+      : PassWrapper(pass) {}
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    TypeConverter typeConverter;
+    auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
+                               mlir::ValueRange inputs,
+                               mlir::Location loc) -> mlir::Value {
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResult(0);
+    };
+    typeConverter.addSourceMaterialization(materializeCast);
+    typeConverter.addTargetMaterialization(materializeCast);
+
+    RewritePatternSet patterns(ctx);
+    patterns.add<TestStepOpPattern>(typeConverter, ctx);
+
+    ConversionTarget target(*ctx);
+    auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
+      return !layout || !layout.isWgLayout();
+    };
+
+    target.addDynamicallyLegalOp<vector::StepOp>(
+        [&](vector::StepOp op) -> bool {
+          auto layoutName = xegpu::getLayoutName(op->getResult(0));
+          auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+          return isLegal(sliceAttr);
+        });
+
+    target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+
+    (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace test {
 void registerTestXeGPULowerings() {
   PassRegistration<TestXeGPUUnrollingPatterns>();
+  PassRegistration<TestXeGPULayoutInterface>();
 }
 } // namespace test
 } // namespace mlir


More information about the Mlir-commits mailing list