[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance multi-reduction layout propagation rules (PR #186308)
Artem Kroviakov
llvmlistbot at llvm.org
Wed Mar 18 08:20:48 PDT 2026
================
@@ -437,75 +437,118 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
return DenseI32ArrayAttr::get(context, vec32);
};
- // Extract original plain layout for workgroup/subgroup size recovery
- xegpu::SliceAttr consumerSliceLayout =
- dyn_cast<xegpu::SliceAttr>(consumerLayout);
- DistributeLayoutAttr plainLayout =
- consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
- : consumerLayout;
+ // Helper lambda to check if the layout from consumer can be reused for the
+ // source shape
+ auto isLayoutCompatibleWithSrcShape =
+ [&](ArrayRef<int64_t> srcShape,
+ xegpu::DistributeLayoutAttr srcLayout) -> bool {
+ SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
+ SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
+ if (static_cast<size_t>(srcLayout.getRank()) != srcShape.size())
+ return false;
+ for (size_t i = 0; i < srcShape.size(); i++) {
+ if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
+ return false;
+ if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
+ return false;
+ }
+ return true;
+ };
+ // Extract original plain layout for workgroup/subgroup size recovery
+ xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
+ rootPlainLayout = sliceAttr.getParent();
+ }
+ auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
+ const int workgroupSize = std::accumulate(
+ sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
const int subgroupSize = uArch->getSubgroupSize();
int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
- xegpu::DistributeLayoutAttr srcLayout;
+ xegpu::SliceAttr consumerSliceLayout =
+ dyn_cast<xegpu::SliceAttr>(consumerLayout);
----------------
akroviakov wrote:
Looks like the result of the `dyn_cast` to `xegpu::SliceAttr` is not checked for null before it is queried.
https://github.com/llvm/llvm-project/pull/186308
More information about the Mlir-commits
mailing list