[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance multi-reduction layout propagation rules (PR #186308)
Charitha Saumya
llvmlistbot at llvm.org
Tue Mar 17 11:54:33 PDT 2026
================
@@ -437,75 +437,118 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
return DenseI32ArrayAttr::get(context, vec32);
};
- // Extract original plain layout for workgroup/subgroup size recovery
- xegpu::SliceAttr consumerSliceLayout =
- dyn_cast<xegpu::SliceAttr>(consumerLayout);
- DistributeLayoutAttr plainLayout =
- consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
- : consumerLayout;
+ // Helper lambda to check if the layout from consumer can be reused for the
+ // source shape
+ auto isLayoutCompatibleWithSrcShape =
+ [&](ArrayRef<int64_t> srcShape,
+ xegpu::DistributeLayoutAttr srcLayout) -> bool {
+ SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
+ if (static_cast<size_t>(srcLayout.getRank()) != srcShape.size())
+ return false;
+ for (size_t i = 0; i < srcShape.size(); i++) {
+ if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
+ return false;
+ if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
+ return false;
+ }
+ return true;
+ };
+ // Extract original plain layout for workgroup/subgroup size recovery
+ xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
+ rootPlainLayout = sliceAttr.getParent();
+ }
+ auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
+ const int workgroupSize = std::accumulate(
+ sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
const int subgroupSize = uArch->getSubgroupSize();
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
- xegpu::DistributeLayoutAttr srcLayout;
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
+ SmallVector<int64_t> consumerSgLayout =
+ consumerLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> consumerLaneLayout =
+ consumerLayout.getEffectiveLaneLayoutAsInt();
+ SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
+ DenseI32ArrayAttr orderAttr = consumerLayout.getOrder();
+ xegpu::DistributeLayoutAttr srcLayout;
if (layoutKind == xegpu::LayoutKind::Subgroup) {
- auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
- const int workgroupSize = std::accumulate(
- sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
- SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
- SmallVector<int64_t> consumerSgLayout =
- consumerLayout.getEffectiveSgLayoutAsInt();
- int remainingSgCount = workgroupSize;
- int consumerIdx = consumerSgLayout.size() - 1;
-
- // First pass: Match consumer's layout on non-reduction dimensions
- for (int i = srcRank - 1; i >= 0; i--) {
- if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
- sgLayout[i] = consumerSgLayout[consumerIdx];
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by consumer sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
- remainingSgCount /= sgLayout[i];
- consumerIdx--;
+ if (consumerSliceLayout &&
+ consumerSliceLayout.getDims().asArrayRef().equals(reductionDims) &&
+ isLayoutCompatibleWithSrcShape(srcShape,
+ consumerSliceLayout.getParent())) {
+ int64_t sgDataValue = -1;
+ srcLayout = consumerSliceLayout.getParent();
+ SmallVector<int64_t> sgLayoutFromConsumer =
+ srcLayout.getEffectiveSgLayoutAsInt();
+ for (int dim = 0; dim < srcRank; dim++) {
+ sgDataValue = srcShape[dim] / sgLayoutFromConsumer[dim];
+ srcLayout = srcLayout.setDimData(dim, sgDataValue, -1, -1);
}
- }
+ } else {
- // Second pass: Distribute remaining subgroups across reduction dimensions
- for (int i = srcRank - 1; i >= 0; i--) {
- if (llvm::is_contained(reductionDims, i)) {
- sgLayout[i] =
- std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
- assert((srcShape[i] % sgLayout[i] == 0) &&
- "source shape not divisible by sg_layout");
- sgData[i] = srcShape[i] / sgLayout[i];
- remainingSgCount /= sgLayout[i];
+ SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
+ int remainingSgCount = workgroupSize;
+ int consumerIdx = consumerSgLayout.size() - 1;
+
+ // First pass: Match consumer's layout on non-reduction dimensions
+ for (int i = srcRank - 1; i >= 0; i--) {
+ if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
+ sgLayout[i] = consumerSgLayout[consumerIdx];
+ assert((srcShape[i] % sgLayout[i] == 0) &&
+ "source shape not divisible by consumer sg_layout");
+ sgData[i] = srcShape[i] / sgLayout[i];
+ remainingSgCount /= sgLayout[i];
+ order[i] = consumerOrder[consumerIdx];
+ consumerIdx--;
+ }
}
- }
- assert(remainingSgCount == 1 && "not all subgroups distributed");
- srcLayout = xegpu::LayoutAttr::get(
- context, toInt32Attr(sgLayout), toInt32Attr(sgData),
- /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
- /*lane_data =*/nullptr, /*order =*/nullptr);
+ // Second pass: Distribute remaining subgroups across reduction dimensions
+ int64_t remainOrder = consumerSgLayout.size();
+ for (int i = srcRank - 1; i >= 0; i--) {
+ if (llvm::is_contained(reductionDims, i)) {
+ sgLayout[i] =
+ std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
----------------
charithaintc wrote:
An example in the function's doc comment would help readers understand this logic.
nit: Can't we do this in a single loop?
https://github.com/llvm/llvm-project/pull/186308
More information about the Mlir-commits
mailing list