[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (PR #186958)
Charitha Saumya
llvmlistbot at llvm.org
Thu Mar 19 10:48:21 PDT 2026
================
@@ -834,10 +921,93 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
// The effective sgIds for offsets computing correspond
// to the dims that are not sliced.
ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
- SmallVector<Value> sgIds =
+ SmallVector<Value> canonicalIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
+ return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
+}
+
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
+/// only the non-sliced dimensions for coordinate computation.
+SmallVector<SmallVector<int64_t>>
+SliceAttr::computeStaticDistributedCoords(int64_t linearId,
+ ArrayRef<int64_t> shape) {
+ assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ SmallVector<int64_t> instData;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ instData = getEffectiveInstDataAsInt();
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ }
+ if (!instData.empty()) {
+ linearId = 0;
+ subShape = instData;
+ }
+
+ assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+ // Delinearize the ID using the parent layout (same as the IR version).
+ SliceAttr flattened = flatten();
+ auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
+ SmallVector<int64_t> parentLayoutVec;
+ if (parent.isForWorkgroup())
+ parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
+ else
+ parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
+
+ DenseI32ArrayAttr orderAttr = parent.getOrder();
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+ return static_cast<int64_t>(idx);
+ });
+ } else {
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, parentLayoutVec.size())));
+ }
+ SmallVector<int64_t> allIds(parentLayoutVec.size());
+ int64_t remaining = linearId;
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
+ if (i < order.size() - 1)
+ remaining = remaining / parentLayoutVec[dimIdx];
+ }
+
+ // The effective IDs for coordinate computation correspond
+ // to the dims that are not sliced.
+ ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
+ SmallVector<int64_t> canonicalIds =
+ XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
+
+ // Compute distribution unit shape (clamped to srcShape).
+ SmallVector<int64_t> distUnitShape(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
+
+ // Compute local offset of this ID within a distribution unit.
+ SmallVector<int64_t> localOffset(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ localOffset[i] = canonicalIds[i] * subShape[i];
+
+ // Enumerate all distribution units and compute coordinates.
+ SmallVector<SmallVector<int64_t>> coordinates;
+ for (SmallVector<int64_t> unitOffs :
+ StaticTileOffsetRange(shape, distUnitShape)) {
+ SmallVector<int64_t> coord(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+ coordinates.push_back(coord);
+ }
+ return coordinates;
----------------
charithaintc wrote:
This part seems common. It would be good to move it into a helper and reuse it in both places; otherwise, changing only one place in the future will cause subtle bugs.
https://github.com/llvm/llvm-project/pull/186958
More information about the Mlir-commits
mailing list