[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance multi-reduction layout propagation rules (PR #186308)

Charitha Saumya llvmlistbot at llvm.org
Tue Mar 17 11:54:33 PDT 2026


================
@@ -437,75 +437,118 @@ xegpu::SliceAttr xegpu::setupMultiReductionResultLayout(
     return DenseI32ArrayAttr::get(context, vec32);
   };
 
-  // Extract original plain layout for workgroup/subgroup size recovery
-  xegpu::SliceAttr consumerSliceLayout =
-      dyn_cast<xegpu::SliceAttr>(consumerLayout);
-  DistributeLayoutAttr plainLayout =
-      consumerSliceLayout ? consumerSliceLayout.flatten().getParent()
-                          : consumerLayout;
+  // Helper lambda to check if the layout from consumer can be reused for the
+  // source shape
+  auto isLayoutCompatibleWithSrcShape =
+      [&](ArrayRef<int64_t> srcShape,
+          xegpu::DistributeLayoutAttr srcLayout) -> bool {
+    SmallVector<int64_t> sgLayout = srcLayout.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> laneLayout = srcLayout.getEffectiveLaneLayoutAsInt();
+    if (static_cast<size_t>(srcLayout.getRank()) != srcShape.size())
+      return false;
+    for (size_t i = 0; i < srcShape.size(); i++) {
+      if (!sgLayout.empty() && srcShape[i] % sgLayout[i] != 0)
+        return false;
+      if (!laneLayout.empty() && srcShape[i] % laneLayout[i] != 0)
+        return false;
+    }
+    return true;
+  };
 
+  // Extract original plain layout for workgroup/subgroup size recovery
+  xegpu::DistributeLayoutAttr rootPlainLayout = consumerLayout;
+  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(rootPlainLayout)) {
+    rootPlainLayout = sliceAttr.getParent();
+  }
+  auto sgLayoutVec = rootPlainLayout.getEffectiveSgLayoutAsInt();
+  const int workgroupSize = std::accumulate(
+      sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
   const int subgroupSize = uArch->getSubgroupSize();
   int64_t maxReduceVectorSize = 1; // could extend to spirv vector Size
 
-  xegpu::DistributeLayoutAttr srcLayout;
+  xegpu::SliceAttr consumerSliceLayout =
+      dyn_cast<xegpu::SliceAttr>(consumerLayout);
+  SmallVector<int64_t> consumerSgLayout =
+      consumerLayout.getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> consumerLaneLayout =
+      consumerLayout.getEffectiveLaneLayoutAsInt();
+  SmallVector<int64_t> consumerOrder = consumerLayout.getEffectiveOrderAsInt();
+  DenseI32ArrayAttr orderAttr = consumerLayout.getOrder();
 
+  xegpu::DistributeLayoutAttr srcLayout;
   if (layoutKind == xegpu::LayoutKind::Subgroup) {
-    auto sgLayoutVec = plainLayout.getEffectiveSgLayoutAsInt();
-    const int workgroupSize = std::accumulate(
-        sgLayoutVec.begin(), sgLayoutVec.end(), 1, std::multiplies<int64_t>());
-    SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank);
-    SmallVector<int64_t> consumerSgLayout =
-        consumerLayout.getEffectiveSgLayoutAsInt();
-    int remainingSgCount = workgroupSize;
-    int consumerIdx = consumerSgLayout.size() - 1;
-
-    // First pass: Match consumer's layout on non-reduction dimensions
-    for (int i = srcRank - 1; i >= 0; i--) {
-      if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
-        sgLayout[i] = consumerSgLayout[consumerIdx];
-        assert((srcShape[i] % sgLayout[i] == 0) &&
-               "source shape not divisible by consumer sg_layout");
-        sgData[i] = srcShape[i] / sgLayout[i];
-        remainingSgCount /= sgLayout[i];
-        consumerIdx--;
+    if (consumerSliceLayout &&
+        consumerSliceLayout.getDims().asArrayRef().equals(reductionDims) &&
+        isLayoutCompatibleWithSrcShape(srcShape,
+                                       consumerSliceLayout.getParent())) {
+      int64_t sgDataValue = -1;
+      srcLayout = consumerSliceLayout.getParent();
+      SmallVector<int64_t> sgLayoutFromConsumer =
+          srcLayout.getEffectiveSgLayoutAsInt();
+      for (int dim = 0; dim < srcRank; dim++) {
+        sgDataValue = srcShape[dim] / sgLayoutFromConsumer[dim];
+        srcLayout = srcLayout.setDimData(dim, sgDataValue, -1, -1);
       }
-    }
+    } else {
 
-    // Second pass: Distribute remaining subgroups across reduction dimensions
-    for (int i = srcRank - 1; i >= 0; i--) {
-      if (llvm::is_contained(reductionDims, i)) {
-        sgLayout[i] =
-            std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
-        assert((srcShape[i] % sgLayout[i] == 0) &&
-               "source shape not divisible by sg_layout");
-        sgData[i] = srcShape[i] / sgLayout[i];
-        remainingSgCount /= sgLayout[i];
+      SmallVector<int64_t> sgLayout(srcRank), sgData(srcRank), order(srcRank);
+      int remainingSgCount = workgroupSize;
+      int consumerIdx = consumerSgLayout.size() - 1;
+
+      // First pass: Match consumer's layout on non-reduction dimensions
+      for (int i = srcRank - 1; i >= 0; i--) {
+        if (!llvm::is_contained(reductionDims, i) && consumerIdx >= 0) {
+          sgLayout[i] = consumerSgLayout[consumerIdx];
+          assert((srcShape[i] % sgLayout[i] == 0) &&
+                 "source shape not divisible by consumer sg_layout");
+          sgData[i] = srcShape[i] / sgLayout[i];
+          remainingSgCount /= sgLayout[i];
+          order[i] = consumerOrder[consumerIdx];
+          consumerIdx--;
+        }
       }
-    }
 
-    assert(remainingSgCount == 1 && "not all subgroups distributed");
-    srcLayout = xegpu::LayoutAttr::get(
-        context, toInt32Attr(sgLayout), toInt32Attr(sgData),
-        /*inst_data =*/nullptr, /*lane_layout =*/nullptr,
-        /*lane_data =*/nullptr, /*order =*/nullptr);
+      // Second pass: Distribute remaining subgroups across reduction dimensions
+      int64_t remainOrder = consumerSgLayout.size();
+      for (int i = srcRank - 1; i >= 0; i--) {
+        if (llvm::is_contained(reductionDims, i)) {
+          sgLayout[i] =
+              std::min(srcShape[i], static_cast<int64_t>(remainingSgCount));
----------------
charithaintc wrote:

An example in the function's doc comment would help readers understand this logic.

Nit: can't these two passes be merged into a single loop?

https://github.com/llvm/llvm-project/pull/186308


More information about the Mlir-commits mailing list