[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (PR #186958)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 16 22:54:53 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Jianhui Li (Jianhui-Li)
<details>
<summary>Changes</summary>
This PR extends XeGPU lane layout to support wrap-around distribution, enabling replication of lane-level tensor tiles across all lanes when the tile size matches lane_data along a given dimension. Previously, distribution required the tile size to exceed the number of lanes × lane_data for even partitioning.
This PR also refactors layout attribute interface functions:
- `computeDistributedShape()` computes the distributed vector shape. It is shared by workgroup-to-subgroup and subgroup-to-lane distribution, which follow the same distribution rule (even or wrap-around).
- `computeStaticDistributedCoords()` computes the compile-time distributed coordinates of the sub-tiles assigned to each subgroup/lane. It is the compile-time counterpart of `computeDistributedCoords()` and is used by `isCompatibleWith()`.
---
Patch is 43.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/186958.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+75-20)
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+247-33)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp (-13)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp (+6-2)
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+16-19)
- (modified) mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp (+5-22)
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+4-12)
- (modified) mlir/test/Dialect/XeGPU/layout.mlir (+10)
- (modified) mlir/test/Dialect/XeGPU/propagate-layout.mlir (+6-3)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+10-10)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+8-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index ce0cce65373e5..7917a0dde556c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -270,6 +270,54 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"FailureOr<SmallVector<SmallVector<Value>>>",
"computeDistributedCoords",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+ InterfaceMethod<[{Statically computes multidimensional coordinates for all dist units
+ assigned to a compute unit identified by `linearId`. This is the
+ compile-time counterpart of `computeDistributedCoords`: it performs
+ the same delinearization and round-robin enumeration but operates
+ entirely on static integer values. Returns a list of coordinate
+ vectors, one per dist unit.}],
+ /*retTy=*/"SmallVector<SmallVector<int64_t>>",
+ /*methodName=*/"computeStaticDistributedCoords",
+ /*args=*/(ins "int64_t":$linearId, "ArrayRef<int64_t>":$shape)>,
+ InterfaceMethod<[{Computes the per-compute-unit shape by dividing each dimension of
+ `shape` by the corresponding layout factor (sg_layout or
+ lane_layout). For wrap-around dimensions where the division is uneven,
+ the tensor tile is broadcasted to all subgroups/lanes.}],
+ /*retTy=*/"FailureOr<SmallVector<int64_t>>",
+ /*methodName=*/"computeDistributedShape",
+ /*args=*/(ins "SmallVector<int64_t>":$shape),
+ /*methodBody=*/[{
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if ($_self.isForWorkgroup()) {
+ layout = $_self.getEffectiveSgLayoutAsInt();
+ subShape = $_self.getEffectiveSgDataAsInt();
+ } else if ($_self.isForSubgroup()) {
+ layout = $_self.getEffectiveLaneLayoutAsInt();
+ subShape = $_self.getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+ assert(
+ !subShape.empty() &&
+ "sgdata or lanedata cannot be empty for distributed shape computation");
+ SmallVector<int64_t> distributedShape(shape.size());
+ for (auto [i, dim] : llvm::enumerate(shape)) {
+ int64_t distri_unit = layout[i]*subShape[i];
+ if ((dim % distri_unit) == 0) {
+ // Evenly divisible case, divide the dimension by the layout factor.
+ distributedShape[i] = dim / layout[i];
+ assert((distributedShape[i] % subShape[i] == 0) &&
+ "Even distribution: sgdata or lanedata must divide the distributed dimension");
+ } else {
+ // wrap around case, the dimension size must be equal to subShape value
+ assert(dim == subShape[i] &&
+ "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
+ distributedShape[i] = dim;
+ }
+ }
+ return distributedShape;
+ }]>,
InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
@@ -282,28 +330,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"ArrayRef<int64_t>": $perm,
"xegpu::LayoutKind": $kind)>,
InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
- at a specific level of the layout hierarchy. Unlike isEqualTo,
- this compares only the effective (non-sliced) fields at the
- requested level.}],
+ at a specific level of the layout hierarchy regarding a given shape. }],
/*retTy=*/"bool",
/*methodName=*/"isCompatibleWith",
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
- "xegpu::LayoutKind": $level),
- /*methodBody=*/[{
- if (!other)
- return false;
- switch (level) {
- case xegpu::LayoutKind::Subgroup:
- return $_self.getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
- $_self.getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
- case xegpu::LayoutKind::InstData:
- return $_self.getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
- case xegpu::LayoutKind::Lane:
- return $_self.getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
- $_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
- }
- return false;
- }]>,
+ "SmallVector<int64_t>": $shape,
+ "xegpu::LayoutKind": $level)>,
InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
For LayoutAttr, this compares all fields.
For SliceAttr, this requires the same parent and same sliced dims.}],
@@ -559,12 +591,24 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
FailureOr<SmallVector<SmallVector<Value>>>
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+ ///Statically computes multidimensional coordinates for all dist units
+ ///assigned to a compute unit identified by `linearId`. This is the
+ ///compile-time counterpart of `computeDistributedCoords`.
+ SmallVector<SmallVector<int64_t>>
+ computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
+
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
/// Check if this layout is equal to another layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+ /// Check if this layout is compatible with another layout
+ /// at a specific level of the layout hierarchy regarding a given shape.
+ bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+ SmallVector<int64_t> shape,
+ xegpu::LayoutKind level);
+
/// Check if this layout is a transpose of another layout.
bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
}];
@@ -772,16 +816,27 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// multiple blocks according to round-robin distribution rules.
-
FailureOr<SmallVector<SmallVector<Value>>>
computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+ ///Statically computes multidimensional coordinates for all dist units
+ ///assigned to a compute unit identified by `linearId`. This is the
+ ///compile-time counterpart of `computeDistributedCoords`.
+ SmallVector<SmallVector<int64_t>>
+ computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
+
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
/// Check if this layout is equal to another layout.
bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
+ /// Check if this layout is compatible with another layout
+ /// at a specific level of the layout hierarchy regarding a given shape.
+ bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+ SmallVector<int64_t> shape,
+ xegpu::LayoutKind level);
+
/// Check if this layout is a transpose of another layout.
bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index f9aa01aca7172..0fbd339146d7b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -160,7 +160,7 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
// check LaneLayout and LaneData
auto maybeLaneShape =
tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
- attr.getEffectiveLaneDataAsInt(), false);
+ attr.getEffectiveLaneDataAsInt(), true);
return maybeLaneShape.has_value();
}
@@ -238,25 +238,17 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
<< lane_layout.size();
}
- // sg_data is optional for Workgroup layout, but its presence requires
- // sg_layout.
- if (sg_data) {
- if (!sg_layout)
- return emitError() << "expected sg_layout being used with sg_data";
- if (sg_data.size() != sg_layout.size())
- return emitError()
- << "expected sg_data and sg_layout to have the same rank";
- }
+ if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
+ return emitError() << "sg_layout and sg_data must be used together";
+ if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
+ return emitError()
+ << "expected sg_data and sg_layout to have the same rank";
- // lane_data is optional for Subgroup layout, but its presence requires
- // lane_layout.
- if (lane_data) {
- if (!lane_layout)
- return emitError() << "expected lane_layout being used with lane_data";
- if (lane_data.size() != lane_layout.size())
- return emitError()
- << "expected lane_data and lane_layout to have the same rank";
- }
+ if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
+ return emitError() << "lane_layout and lane_data must be used together";
+ if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
+ return emitError()
+ << "expected lane_data and lane_layout to have the same rank";
if (order) {
if (!sg_layout && !lane_layout)
@@ -373,12 +365,8 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
} else {
return failure();
}
- if (subShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, layout))
- subShape = derivedShape.value();
- else
- return failure();
- }
+ assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
+ "distributed coordinates computation");
// delinearize Ids
auto maybeIds = delinearizeId(builder, loc, linearId);
@@ -396,6 +384,70 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
return *this == dyn_cast<xegpu::LayoutAttr>(other);
}
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// LayoutAttr.
+SmallVector<SmallVector<int64_t>>
+LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
+ ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> layoutVec;
+ SmallVector<int64_t> subShape;
+ SmallVector<int64_t> instData;
+ if (isForWorkgroup()) {
+ layoutVec = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ instData = getEffectiveInstDataAsInt();
+ layoutVec = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ }
+ if (!instData.empty()) {
+ linearId = 0;
+ subShape = instData;
+ }
+ assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+ // Delinearize the linear ID using the order attribute.
+ DenseI32ArrayAttr orderAttr = getOrder();
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+ return static_cast<int64_t>(idx);
+ });
+ } else {
+ order =
+ llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, layoutVec.size())));
+ }
+ SmallVector<int64_t> delinearizedId(layoutVec.size());
+ int64_t remaining = linearId;
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
+ remaining = remaining / layoutVec[dimIdx];
+ }
+
+ // Compute distribution unit shape (clamped to srcShape).
+ SmallVector<int64_t> distUnitShape(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ distUnitShape[i] = std::min(shape[i], layoutVec[i] * subShape[i]);
+
+ // Compute local offset of this ID within a distribution unit.
+ SmallVector<int64_t> localOffset(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ localOffset[i] = delinearizedId[i] * subShape[i];
+
+ // Enumerate all distribution units and compute coordinates.
+ SmallVector<SmallVector<int64_t>> coordinates;
+ for (SmallVector<int64_t> unitOffs :
+ StaticTileOffsetRange(shape, distUnitShape)) {
+ SmallVector<int64_t> coord(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+ coordinates.push_back(coord);
+ }
+ return coordinates;
+}
+
// set the layout for unit dims: sg_data, inst_data and lane_data to 1
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
@@ -743,6 +795,45 @@ bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
return false;
}
+bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+ SmallVector<int64_t> shape,
+ xegpu::LayoutKind level) {
+ if (!other)
+ return false;
+ if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+ if (level == xegpu::LayoutKind::Subgroup)
+ return (getEffectiveSgLayoutAsInt() ==
+ other.getEffectiveSgLayoutAsInt() &&
+ getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
+ if (level == xegpu::LayoutKind::Lane)
+ return (getEffectiveLaneLayoutAsInt() ==
+ other.getEffectiveLaneLayoutAsInt() &&
+ getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt());
+ }
+ if (level == xegpu::LayoutKind::Subgroup) {
+ int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
+ for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+ auto coords = computeStaticDistributedCoords(id, shape);
+ auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+ if (coords != otherCoords)
+ return false;
+ }
+ }
+ if (level == xegpu::LayoutKind::InstData) {
+ return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
+ }
+ if (level == xegpu::LayoutKind::Lane) {
+ int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
+ for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
+ auto coords = computeStaticDistributedCoords(id, shape);
+ auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+ if (coords != otherCoords)
+ return false;
+ }
+ }
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
@@ -819,12 +910,8 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
return failure();
}
- if (subShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, layout))
- subShape = derivedShape.value();
- else
- return failure();
- }
+ if (subShape.empty())
+ return failure();
// delinearize Ids
auto maybeIds = delinearizeId(builder, loc, linearId);
@@ -834,10 +921,93 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
// The effective sgIds for offsets computing correspond
// to the dims that are not sliced.
ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
- SmallVector<Value> sgIds =
+ SmallVector<Value> canonicalIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
+ return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
+}
+
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
+/// only the non-sliced dimensions for coordinate computation.
+SmallVector<SmallVector<int64_t>>
+SliceAttr::computeStaticDistributedCoords(int64_t linearId,
+ ArrayRef<int64_t> shape) {
+ assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ SmallVector<int64_t> instData;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ instData = getEffectiveInstDataAsInt();
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ }
+ if (!instData.empty()) {
+ linearId = 0;
+ subShape = instData;
+ }
+
+ assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+ // Delinearize the ID using the parent layout (same as the IR version).
+ SliceAttr flattened = flatten();
+ auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
+ SmallVector<int64_t> parentLayoutVec;
+ if (parent.isForWorkgroup())
+ parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
+ else
+ parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
+
+ DenseI32ArrayAttr orderAttr = parent.getOrder();
+ SmallVector<int64_t> order;
+ if (orderAttr && !orderAttr.empty()) {
+ order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+ return static_cast<int64_t>(idx);
+ });
+ } else {
+ order = llvm::to_vector(
+ llvm::reverse(llvm::seq<int64_t>(0, parentLayoutVec.size())));
+ }
+ SmallVector<int64_t> allIds(parentLayoutVec.size());
+ int64_t remaining = linearId;
+ for (size_t i = 0; i < order.size(); ++i) {
+ int64_t dimIdx = order[i];
+ allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
+ if (i < order.size() - 1)
+ remaining = remaining / parentLayoutVec[dimIdx];
+ }
+
+ // The effective IDs for coordinate computation correspond
+ // to the dims that are not sliced.
+ ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
+ SmallVector<int64_t> canonicalIds =
+ XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
+
+ // Compute distribution unit shape (clamped to srcShape).
+ SmallVector<int64_t> distUnitShape(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
+
+ // Compute local offset of this ID within a distribution unit.
+ SmallVector<int64_t> localOffset(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i)
+ localOffset[i] = canonicalIds[i] * subShape[i];
+
+ // Enumerate all distribution units and compute coordinates.
+ SmallVector<SmallVector<int64...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/186958
More information about the Mlir-commits
mailing list