[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (PR #186958)

Charitha Saumya llvmlistbot at llvm.org
Thu Mar 19 10:48:21 PDT 2026


================
@@ -834,10 +921,93 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   // The effective sgIds for offsets computing correspond
   // to the dims that are not sliced.
   ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
-  SmallVector<Value> sgIds =
+  SmallVector<Value> canonicalIds =
       XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
 
-  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
+  return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
+}
+
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
+/// only the non-sliced dimensions for coordinate computation.
+SmallVector<SmallVector<int64_t>>
+SliceAttr::computeStaticDistributedCoords(int64_t linearId,
+                                          ArrayRef<int64_t> shape) {
+  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  SmallVector<int64_t> instData;
+  if (isForWorkgroup()) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    instData = getEffectiveInstDataAsInt();
+    layout = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  }
+  if (!instData.empty()) {
+    linearId = 0;
+    subShape = instData;
+  }
+
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+  // Delinearize the ID using the parent layout (same as the IR version).
+  SliceAttr flattened = flatten();
+  auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
+  SmallVector<int64_t> parentLayoutVec;
+  if (parent.isForWorkgroup())
+    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
+  else
+    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
+
+  DenseI32ArrayAttr orderAttr = parent.getOrder();
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+      return static_cast<int64_t>(idx);
+    });
+  } else {
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, parentLayoutVec.size())));
+  }
+  SmallVector<int64_t> allIds(parentLayoutVec.size());
+  int64_t remaining = linearId;
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
+    if (i < order.size() - 1)
+      remaining = remaining / parentLayoutVec[dimIdx];
+  }
+
+  // The effective IDs for coordinate computation correspond
+  // to the dims that are not sliced.
+  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
+  SmallVector<int64_t> canonicalIds =
+      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
+
+  // Compute distribution unit shape (clamped to srcShape).
+  SmallVector<int64_t> distUnitShape(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
+
+  // Compute local offset of this ID within a distribution unit.
+  SmallVector<int64_t> localOffset(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    localOffset[i] = canonicalIds[i] * subShape[i];
+
+  // Enumerate all distribution units and compute coordinates.
+  SmallVector<SmallVector<int64_t>> coordinates;
+  for (SmallVector<int64_t> unitOffs :
+       StaticTileOffsetRange(shape, distUnitShape)) {
+    SmallVector<int64_t> coord(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i)
+      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+    coordinates.push_back(coord);
+  }
+  return coordinates;
----------------
charithaintc wrote:

This part seems to be common to both the IR-based and the static coordinate computations. It would be good to move it into a helper and reuse it in both places; otherwise, changing only one place in the future will cause weird bugs.

https://github.com/llvm/llvm-project/pull/186958


More information about the Mlir-commits mailing list