[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance multi-reduction layout propagation rules (PR #186308)

Artem Kroviakov llvmlistbot at llvm.org
Wed Mar 18 08:20:48 PDT 2026


================
@@ -437,75 +437,118 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     return DenseI32ArrayAttr::get(context, vec32);
   };
 
-  // Extract original plain layout for workgroup/subgroup size recovery
-  xegpu::SliceAttr consumerSliceLayout =
-      dyn_cast<xegpu::SliceAttr>(consumerLayout);
-  DistributeLayoutAttr plainLayout =
-      consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
-                          : consumerLayout;
+  // Helper lambda to check if the layout from consumer can be reused for the
+  // source shape
+  auto isLayoutCompatibleWithSrcShape =
+      [&](ArrayRef<int64_t> srcShape,
+          xegpu::DistributeLayoutAttr srcLayout) -> bool {
+    SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
+    if (static_cast<size_t>(srcLayout.getRank()) != srcShape.size())
+      return false;
+    for (size_t i = 0; i < srcShape.size(); i++) {
+      if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
+        return false;
+      if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
+        return false;
+    }
+    return true;
+  };
 
+  // Extract original plain layout for workgroup/subgroup size recovery
+  xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
+  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
+    rootPlainLayout = sliceAttr.getParent();
+  }
+  auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
+  const int workgroupSize = std::accumulate(
+      sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
   const int subgroupSize = uArch->getSubgroupSize();
   int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
 
-  xegpu::DistributeLayoutAttr srcLayout;
+  xegpu::SliceAttr consumerSliceLayout =
+      dyn_cast<xegpu::SliceAttr>(consumerLayout);
----------------
akroviakov wrote:

It looks like the result of this `dyn_cast` is not checked for null before it is queried.

https://github.com/llvm/llvm-project/pull/186308


More information about the Mlir-commits mailing list