[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (PR #186958)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 16 22:54:53 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Jianhui Li (Jianhui-Li)

<details>
<summary>Changes</summary>

This PR extends the XeGPU lane layout to support wrap-around distribution, which replicates a lane-level tensor tile across all lanes when the tile size along a given dimension equals lane_data. Previously, distribution required the tile size along each dimension to be a multiple of the number of lanes × lane_data in that dimension, so that it could be partitioned evenly.
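Reading the `computeDistributedShape()` body in the diff, the per-dimension rule can be summarized as follows. The notation is ours, not the PR's: $s_i$ is the tile size, $\ell_i$ the lane_layout entry, $u_i$ the lane_data entry, and $d_i$ the per-lane size for dimension $i$. At workgroup level, sg_layout and sg_data play the same roles. Any other size is rejected by an assertion.

```latex
d_i =
\begin{cases}
  s_i / \ell_i, & \text{if } s_i \bmod (\ell_i \, u_i) = 0 \quad \text{(even distribution)},\\
  s_i,          & \text{if } s_i = u_i \quad \text{(wrap-around: tile replicated on every lane)}.
\end{cases}
```

For example, with lane_layout = [1, 16] and lane_data = [1, 1], an [8, 32] tile distributes evenly to [8, 2] per lane. An [8, 1] tile takes the wrap-around branch in dimension 1 ($1 \bmod 16 \neq 0$ and $1 = u_1$), so each of the 16 lanes holds the full [8, 1] tile.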

This PR also refactors layout attribute interface functions:

computeDistributedShape() computes the distributed vector shape. It is shared by workgroup-to-subgroup and subgroup-to-lane distribution, which follow the same distribution rule (even or wrap-around).
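The shared rule can be sketched as a standalone function. `sketchDistributedShape` is a hypothetical free function over plain `std::vector`s, not the real interface method. `layout` and `data` stand for sg_layout/sg_data or lane_layout/lane_data. Unlike the PR, which asserts, this sketch returns `std::nullopt` when neither the even nor the wrap-around rule applies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical sketch of the per-dimension rule in computeDistributedShape().
// layout ~ sg_layout / lane_layout, data ~ sg_data / lane_data.
// Returns std::nullopt where the PR's method would hit an assertion.
std::optional<std::vector<int64_t>>
sketchDistributedShape(const std::vector<int64_t> &shape,
                       const std::vector<int64_t> &layout,
                       const std::vector<int64_t> &data) {
  if (shape.size() != layout.size() || shape.size() != data.size())
    return std::nullopt;
  std::vector<int64_t> result(shape.size());
  for (std::size_t i = 0; i < shape.size(); ++i) {
    int64_t distUnit = layout[i] * data[i];
    if (shape[i] % distUnit == 0) {
      // Even distribution: each compute unit gets shape[i] / layout[i].
      result[i] = shape[i] / layout[i];
    } else if (shape[i] == data[i]) {
      // Wrap-around: the whole dimension is replicated on every unit.
      result[i] = shape[i];
    } else {
      // Neither rule applies.
      return std::nullopt;
    }
  }
  return result;
}
```

For example, `sketchDistributedShape({8, 1}, {1, 16}, {1, 1})` yields `{8, 1}`: dimension 1 takes the wrap-around branch, so every lane keeps the whole tile along it.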

computeStaticDistributedCoords() computes, at compile time, the coordinates of the sub-tiles assigned to each subgroup or lane. It is the compile-time counterpart of computeDistributedCoords() and is used by isCompatibleWith().
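The static coordinate enumeration can be sketched the same way. The hypothetical `sketchStaticCoords` follows the steps visible in `LayoutAttr::computeStaticDistributedCoords` in the diff: delinearize the ID by `order`, clamp the distribution unit to the shape, add the per-ID local offset, and wrap modulo the shape. It omits the inst_data special case. It also visits units in row-major order, which is an assumption about `StaticTileOffsetRange`'s default traversal.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Coord = std::vector<int64_t>;

// Hypothetical sketch (inst_data handling omitted). `order` lists dimensions
// from fastest- to slowest-varying, as in the delinearization loop of the diff.
std::vector<Coord> sketchStaticCoords(int64_t linearId,
                                      const std::vector<int64_t> &shape,
                                      const std::vector<int64_t> &layout,
                                      const std::vector<int64_t> &data,
                                      const std::vector<int64_t> &order) {
  const std::size_t rank = shape.size();
  // 1. Delinearize the ID, peeling off the fastest-varying dimension first.
  Coord id(rank, 0);
  int64_t remaining = linearId;
  for (int64_t dim : order) {
    id[dim] = remaining % layout[dim];
    remaining /= layout[dim];
  }
  // 2. Distribution unit = layout * data (clamped to shape); local offset.
  Coord unit(rank), local(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    unit[i] = std::min(shape[i], layout[i] * data[i]);
    local[i] = id[i] * data[i];
  }
  // 3. Visit every distribution unit and wrap each coordinate modulo shape.
  std::vector<Coord> coords;
  Coord unitOff(rank, 0);
  for (bool done = false; !done;) {
    Coord c(rank);
    for (std::size_t i = 0; i < rank; ++i)
      c[i] = (unitOff[i] + local[i]) % shape[i];
    coords.push_back(c);
    // Advance the unit-offset odometer, last dimension fastest.
    done = true;
    for (std::size_t d = rank; d-- > 0;) {
      unitOff[d] += unit[d];
      if (unitOff[d] < shape[d]) {
        done = false;
        break;
      }
      unitOff[d] = 0;
    }
  }
  return coords;
}
```

With lane_layout = [1, 16], lane_data = [1, 1] and an [8, 1] tile, every lane ID returns the same eight coordinates, (0, 0) through (7, 0). This is the replication that wrap-around distribution describes.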

---

Patch is 43.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/186958.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+75-20) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+247-33) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (-13) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+6-2) 
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+16-19) 
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+5-22) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+4-12) 
- (modified) mlir/test/Dialect/XeGPU/layout.mlir (+10) 
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+6-3) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+10-10) 
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+8-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index ce0cce65373e5..7917a0dde556c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -270,6 +270,54 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "computeDistributedCoords",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+    InterfaceMethod<[{Statically computes multidimensional coordinates for all dist units
+                      assigned to a compute unit identified by `linearId`. This is the
+                      compile-time counterpart of `computeDistributedCoords`: it performs
+                      the same delinearization and round-robin enumeration but operates
+                      entirely on static integer values. Returns a list of coordinate
+                      vectors, one per dist unit.}],
+                    /*retTy=*/"SmallVector<SmallVector<int64_t>>",
+                    /*methodName=*/"computeStaticDistributedCoords",
+                    /*args=*/(ins "int64_t":$linearId, "ArrayRef<int64_t>":$shape)>,
+    InterfaceMethod<[{Computes the per-compute-unit shape by dividing each dimension of
+                      `shape` by the corresponding layout factor (sg_layout or
+                      lane_layout). For wrap-around dimensions, where the dimension size
+                      equals sg_data/lane_data, the tensor tile is broadcast to all subgroups/lanes.}],
+                    /*retTy=*/"FailureOr<SmallVector<int64_t>>",
+                    /*methodName=*/"computeDistributedShape",
+                    /*args=*/(ins "SmallVector<int64_t>":$shape),
+                    /*methodBody=*/[{
+                      SmallVector<int64_t> layout;
+                      SmallVector<int64_t> subShape;
+                      if ($_self.isForWorkgroup()) {
+                        layout = $_self.getEffectiveSgLayoutAsInt();
+                        subShape = $_self.getEffectiveSgDataAsInt();
+                      } else if ($_self.isForSubgroup()) {
+                        layout = $_self.getEffectiveLaneLayoutAsInt();
+                        subShape = $_self.getEffectiveLaneDataAsInt();
+                      } else {
+                        return failure();
+                      }
+                      assert(
+                          !subShape.empty() &&
+                          "sgdata or lanedata cannot be empty for distributed shape computation");
+                      SmallVector<int64_t> distributedShape(shape.size());
+                      for (auto [i, dim] : llvm::enumerate(shape)) {
+                        int64_t distri_unit = layout[i]*subShape[i];
+                        if ((dim % distri_unit) == 0) {
+                          // Evenly divisible case, divide the dimension by the layout factor.
+                          distributedShape[i] = dim / layout[i];
+                          assert((distributedShape[i] % subShape[i] == 0) &&
+                                "Even distribution: sgdata or lanedata must divide the distributed dimension");
+                        } else {
+                          // Wrap-around case: the dimension size must equal the subShape value.
+                          assert(dim == subShape[i] &&
+                                "Wrap-around distribution: sgdata or lanedata must be the same as the tensor tile shape");
+                          distributedShape[i] = dim;
+                        }
+                      }
+                      return distributedShape;
+                    }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
@@ -282,28 +330,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                                   "ArrayRef<int64_t>": $perm,
                                   "xegpu::LayoutKind": $kind)>,
     InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
-                     at a specific level of the layout hierarchy. Unlike isEqualTo,
-                     this compares only the effective (non-sliced) fields at the
-                     requested level.}],
+                     at a specific level of the layout hierarchy regarding a given shape. }],
                     /*retTy=*/"bool",
                     /*methodName=*/"isCompatibleWith",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
-                                  "xegpu::LayoutKind": $level),
-                    /*methodBody=*/[{
-                      if (!other)
-                        return false;
-                      switch (level) {
-                        case xegpu::LayoutKind::Subgroup:
-                          return $_self.getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
-                                 $_self.getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
-                        case xegpu::LayoutKind::InstData:
-                          return $_self.getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
-                        case xegpu::LayoutKind::Lane:
-                          return $_self.getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
-                                 $_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
-                      }
-                      return false;
-                    }]>,
+                                  "SmallVector<int64_t>": $shape,
+                                  "xegpu::LayoutKind": $level)>,
     InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
                      For LayoutAttr, this compares all fields.
                      For SliceAttr, this requires the same parent and same sliced dims.}],
@@ -559,12 +591,24 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<SmallVector<Value>>>
     computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    /// Statically computes multidimensional coordinates for all dist units
+    /// assigned to a compute unit identified by `linearId`. This is the
+    /// compile-time counterpart of `computeDistributedCoords`.
+    SmallVector<SmallVector<int64_t>>
+    computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
+
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
 
     /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
 
+    /// Check if this layout is compatible with another layout 
+    /// at a specific level of the layout hierarchy regarding a given shape.
+    bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level);
+
     /// Check if this layout is a transpose of another layout.
     bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
   }];
@@ -772,16 +816,27 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// assigned to a subgroup identified by linearId. The shape parameter
     /// represents the workgroup-level problem size. Each subgroup may access
     /// multiple blocks according to round-robin distribution rules.
-
     FailureOr<SmallVector<SmallVector<Value>>>
     computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    /// Statically computes multidimensional coordinates for all dist units
+    /// assigned to a compute unit identified by `linearId`. This is the
+    /// compile-time counterpart of `computeDistributedCoords`.
+    SmallVector<SmallVector<int64_t>>
+    computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
+
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
 
     /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
 
+    /// Check if this layout is compatible with another layout 
+    /// at a specific level of the layout hierarchy regarding a given shape.
+    bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level);
+
     /// Check if this layout is a transpose of another layout.
     bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index f9aa01aca7172..0fbd339146d7b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -160,7 +160,7 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
   // check LaneLayout and LaneData
   auto maybeLaneShape =
       tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
-                    attr.getEffectiveLaneDataAsInt(), false);
+                    attr.getEffectiveLaneDataAsInt(), true);
   return maybeLaneShape.has_value();
 }
 
@@ -238,25 +238,17 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                        << lane_layout.size();
   }
 
-  // sg_data is optional for Workgroup layout, but its presence requires
-  // sg_layout.
-  if (sg_data) {
-    if (!sg_layout)
-      return emitError() << "expected sg_layout being used with sg_data";
-    if (sg_data.size() != sg_layout.size())
-      return emitError()
-             << "expected sg_data and sg_layout to have the same rank";
-  }
+  if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
+    return emitError() << "sg_layout and sg_data must be used together";
+  if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
+    return emitError()
+           << "expected sg_data and sg_layout to have the same rank";
 
-  // lane_data is optional for Subgroup layout, but its presence requires
-  // lane_layout.
-  if (lane_data) {
-    if (!lane_layout)
-      return emitError() << "expected lane_layout being used with lane_data";
-    if (lane_data.size() != lane_layout.size())
-      return emitError()
-             << "expected lane_data and lane_layout to have the same rank";
-  }
+  if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
+    return emitError() << "lane_layout and lane_data must be used together";
+  if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
+    return emitError()
+           << "expected lane_data and lane_layout to have the same rank";
 
   if (order) {
     if (!sg_layout && !lane_layout)
@@ -373,12 +365,8 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   } else {
     return failure();
   }
-  if (subShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, layout))
-      subShape = derivedShape.value();
-    else
-      return failure();
-  }
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
+                              "distributed coordinates computation");
 
   // delinearize Ids
   auto maybeIds = delinearizeId(builder, loc, linearId);
@@ -396,6 +384,70 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
   return *this == dyn_cast<xegpu::LayoutAttr>(other);
 }
 
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// LayoutAttr.
+SmallVector<SmallVector<int64_t>>
+LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
+                                           ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> layoutVec;
+  SmallVector<int64_t> subShape;
+  SmallVector<int64_t> instData;
+  if (isForWorkgroup()) {
+    layoutVec = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    instData = getEffectiveInstDataAsInt();
+    layoutVec = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  }
+  if (!instData.empty()) {
+    linearId = 0;
+    subShape = instData;
+  }
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+  // Delinearize the linear ID using the order attribute.
+  DenseI32ArrayAttr orderAttr = getOrder();
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+      return static_cast<int64_t>(idx);
+    });
+  } else {
+    order =
+        llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, layoutVec.size())));
+  }
+  SmallVector<int64_t> delinearizedId(layoutVec.size());
+  int64_t remaining = linearId;
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
+    remaining = remaining / layoutVec[dimIdx];
+  }
+
+  // Compute distribution unit shape (clamped to shape).
+  SmallVector<int64_t> distUnitShape(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    distUnitShape[i] = std::min(shape[i], layoutVec[i] * subShape[i]);
+
+  // Compute local offset of this ID within a distribution unit.
+  SmallVector<int64_t> localOffset(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    localOffset[i] = delinearizedId[i] * subShape[i];
+
+  // Enumerate all distribution units and compute coordinates.
+  SmallVector<SmallVector<int64_t>> coordinates;
+  for (SmallVector<int64_t> unitOffs :
+       StaticTileOffsetRange(shape, distUnitShape)) {
+    SmallVector<int64_t> coord(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i)
+      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+    coordinates.push_back(coord);
+  }
+  return coordinates;
+}
+
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
 DistributeLayoutAttr
 LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
@@ -743,6 +795,45 @@ bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
   return false;
 }
 
+bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level) {
+  if (!other)
+    return false;
+  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+    if (level == xegpu::LayoutKind::Subgroup)
+      return (getEffectiveSgLayoutAsInt() ==
+                  other.getEffectiveSgLayoutAsInt() &&
+              getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
+    if (level == xegpu::LayoutKind::Lane)
+      return (getEffectiveLaneLayoutAsInt() ==
+                  other.getEffectiveLaneLayoutAsInt() &&
+              getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt());
+  }
+  if (level == xegpu::LayoutKind::Subgroup) {
+    int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  if (level == xegpu::LayoutKind::InstData) {
+    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
+  }
+  if (level == xegpu::LayoutKind::Lane) {
+    int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -819,12 +910,8 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
     return failure();
   }
 
-  if (subShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, layout))
-      subShape = derivedShape.value();
-    else
-      return failure();
-  }
+  if (subShape.empty())
+    return failure();
 
   // delinearize Ids
   auto maybeIds = delinearizeId(builder, loc, linearId);
@@ -834,10 +921,93 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   // The effective sgIds for offsets computing correspond
   // to the dims that are not sliced.
   ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
-  SmallVector<Value> sgIds =
+  SmallVector<Value> canonicalIds =
       XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
 
-  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
+  return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
+}
+
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
+/// only the non-sliced dimensions for coordinate computation.
+SmallVector<SmallVector<int64_t>>
+SliceAttr::computeStaticDistributedCoords(int64_t linearId,
+                                          ArrayRef<int64_t> shape) {
+  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  SmallVector<int64_t> instData;
+  if (isForWorkgroup()) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    instData = getEffectiveInstDataAsInt();
+    layout = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  }
+  if (!instData.empty()) {
+    linearId = 0;
+    subShape = instData;
+  }
+
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+  // Delinearize the ID using the parent layout (same as the IR version).
+  SliceAttr flattened = flatten();
+  auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
+  SmallVector<int64_t> parentLayoutVec;
+  if (parent.isForWorkgroup())
+    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
+  else
+    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
+
+  DenseI32ArrayAttr orderAttr = parent.getOrder();
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+      return static_cast<int64_t>(idx);
+    });
+  } else {
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, parentLayoutVec.size())));
+  }
+  SmallVector<int64_t> allIds(parentLayoutVec.size());
+  int64_t remaining = linearId;
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
+    if (i < order.size() - 1)
+      remaining = remaining / parentLayoutVec[dimIdx];
+  }
+
+  // The effective IDs for coordinate computation correspond
+  // to the dims that are not sliced.
+  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
+  SmallVector<int64_t> canonicalIds =
+      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
+
+  // Compute distribution unit shape (clamped to shape).
+  SmallVector<int64_t> distUnitShape(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
+
+  // Compute local offset of this ID within a distribution unit.
+  SmallVector<int64_t> localOffset(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    localOffset[i] = canonicalIds[i] * subShape[i];
+
+  // Enumerate all distribution units and compute coordinates.
+  SmallVector<SmallVector<int64...
[truncated]

``````````
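The rewritten `LayoutAttr::isCompatibleWith` in the diff first compares the effective order. When the orders match, it compares layout and data directly. Otherwise, it checks that every ID receives the same coordinates under both layouts. The sketch below models this for a single distribution level with a hypothetical `LayoutModel` struct, not the MLIR attribute. It repeats the coordinate enumeration so the block stands alone.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, simplified model of one layout level (sg_* or lane_*).
struct LayoutModel {
  std::vector<int64_t> layout, data, order; // order: fastest dim first
};

using Coord = std::vector<int64_t>;

// Same enumeration as computeStaticDistributedCoords (inst_data omitted).
static std::vector<Coord> coordsFor(const LayoutModel &l, int64_t linearId,
                                    const std::vector<int64_t> &shape) {
  const std::size_t rank = shape.size();
  Coord id(rank, 0), unit(rank), local(rank), off(rank, 0);
  for (int64_t dim : l.order) { // delinearize, fastest dim first
    id[dim] = linearId % l.layout[dim];
    linearId /= l.layout[dim];
  }
  for (std::size_t i = 0; i < rank; ++i) {
    unit[i] = std::min(shape[i], l.layout[i] * l.data[i]);
    local[i] = id[i] * l.data[i];
  }
  std::vector<Coord> coords;
  for (bool done = false; !done;) {
    Coord c(rank);
    for (std::size_t i = 0; i < rank; ++i)
      c[i] = (off[i] + local[i]) % shape[i];
    coords.push_back(c);
    done = true; // advance the unit-offset odometer, last dim fastest
    for (std::size_t d = rank; d-- > 0;) {
      off[d] += unit[d];
      if (off[d] < shape[d]) {
        done = false;
        break;
      }
      off[d] = 0;
    }
  }
  return coords;
}

// Mirrors the structure of LayoutAttr::isCompatibleWith at one level.
bool sketchIsCompatible(const LayoutModel &a, const LayoutModel &b,
                        const std::vector<int64_t> &shape) {
  if (a.order == b.order)
    return a.layout == b.layout && a.data == b.data;
  int64_t numIds = 1;
  for (int64_t n : a.layout)
    numIds *= n;
  for (int64_t id = 0; id < numIds; ++id)
    if (coordsFor(a, id, shape) != coordsFor(b, id, shape))
      return false;
  return true;
}
```

Two layouts whose order attributes differ can still be compatible. With sg_layout = [4, 1], for example, both orders delinearize every ID to the same position. With sg_layout = [2, 2], swapping the order moves ID 1 from (0, 8) to (8, 0), so the layouts are incompatible.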

</details>


https://github.com/llvm/llvm-project/pull/186958

