[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor layout propagation utilities (PR #179016)

Jianhui Li llvmlistbot at llvm.org
Tue Feb 3 22:16:12 PST 2026


================
@@ -471,6 +468,152 @@ LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
       getLaneData(), getOrder());
 }
 
+// Derive a new layout in which the sg_data, inst_data and lane_data entries
+// for dimension `dim` are replaced by `sgData`, `instData` and `laneData`.
+//
+// Passing -1 for a value leaves the corresponding entry unchanged. A field is
+// also left unchanged when `dim` is not a valid index into its effective
+// vector (negative, or >= its size), so an out-of-range `dim` never indexes
+// out of bounds. sg_layout, lane_layout and order are carried over as-is. A
+// field whose effective vector is empty is emitted as a null
+// DenseI32ArrayAttr.
+DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
+                                            int64_t instData,
+                                            int64_t laneData) {
+
+  SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
+
+  // Overwrite vec[dim] unless `value` is the -1 "keep" sentinel or `dim` lies
+  // outside [0, vec.size()). The lower-bound check matters: a negative `dim`
+  // would otherwise pass the `dim < size` test and index out of bounds.
+  auto overrideDim = [dim](SmallVector<int64_t> &vec, int64_t value) {
+    if (value == -1 || dim < 0 || dim >= static_cast<int64_t>(vec.size()))
+      return;
+    vec[dim] = value;
+  };
+  overrideDim(sgDataVec, sgData);
+  overrideDim(instDataVec, instData);
+  overrideDim(laneDataVec, laneData);
+
+  // The attribute stores 32-bit integers: narrow the values, or yield a null
+  // attribute when the field has no entries.
+  auto toAttr = [&](const SmallVector<int64_t> &vec) -> DenseI32ArrayAttr {
+    if (vec.empty())
+      return DenseI32ArrayAttr();
+    SmallVector<int32_t> vec32(vec.begin(), vec.end());
+    return DenseI32ArrayAttr::get(getContext(), vec32);
+  };
+
+  return LayoutAttr::get(getContext(), getSgLayout(), toAttr(sgDataVec),
+                         toAttr(instDataVec), getLaneLayout(),
+                         toAttr(laneDataVec), getOrder());
+}
+
+// Derive a new layout by collapsing groups of dimensions.
+// Each inner array in `dimGroups` specifies a set of adjacent dimensions
+// that are collapsed into a single dimension in the derived layout.
+DistributeLayoutAttr
+LayoutAttr::collapseDims(SmallVector<SmallVector<int64_t>> dimGroups) const {
+
+  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
+  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
+  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
+  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
+  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
+
+  DenseI32ArrayAttr orderAttr = getOrder();
+  SmallVector<int64_t> orderVec;
+  if (orderAttr && !orderAttr.empty()) {
+    orderVec = llvm::to_vector(
+        llvm::map_range(orderAttr.asArrayRef(),
+                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
+  }
+
+  SmallVector<int64_t> collapsedSgLayout;
+  SmallVector<int64_t> collapsedSgData;
+  SmallVector<int64_t> collapsedInstData;
+  SmallVector<int64_t> collapsedLaneLayout;
+  SmallVector<int64_t> collapsedLaneData;
+  SmallVector<int64_t> collapsedOrder;
+  SetVector<int64_t> coveredDims;
+
+  for (const auto &group : dimGroups) {
+
+    // Collapse by multiplying values across dimension group
+    int64_t collapsedSg = 1, collapsedSgD = 1, collapsedInst = 1;
+    int64_t collapsedLaneL = 1, collapsedLaneD = 1;
+    int64_t collapsedOrderValue = -1;
+    int64_t dimBeforeCurrent = group.front() - 1;
+    for (int64_t dimIdx : group) {
+      // no two groups can cover the same dimension
+      if (!coveredDims.insert(dimIdx))
+        llvm::report_fatal_error(Twine("dimension ") + Twine(dimIdx) +
+                                 " is covered more than once");
+      // dims within group must be adjacent
+      if (dimBeforeCurrent != (dimIdx - 1))
+        llvm::report_fatal_error("dimensions being collapsed must be adjacent");
+      dimBeforeCurrent = dimIdx;
+
+      collapsedSg *= sgLayout[dimIdx];
+      collapsedSgD *= sgData[dimIdx];
+      collapsedInst *= instData[dimIdx];
+      collapsedLaneL *= laneLayout[dimIdx];
+      collapsedLaneD *= laneData[dimIdx];
+      if (!orderVec.empty())
+        collapsedOrderValue = orderVec[dimIdx]; // take the last one's order
+    }
+
+    collapsedSgLayout.push_back(collapsedSg);
+    collapsedSgData.push_back(collapsedSgD);
+    collapsedInstData.push_back(collapsedInst);
+    collapsedLaneLayout.push_back(collapsedLaneL);
+    collapsedLaneData.push_back(collapsedLaneD);
+    collapsedOrder.push_back(collapsedOrderValue);
+  }
+
+  // check covered all dimensions
+  if (coveredDims.size() != sgLayout.size())
+    llvm::report_fatal_error(
+        "not all dimensions are covered in collapseGroups");
+
+  // Create collapsed layout
+  SmallVector<int32_t> collapsedSgLayout32(collapsedSgLayout.begin(),
+                                           collapsedSgLayout.end());
+  SmallVector<int32_t> collapsedSgData32(collapsedSgData.begin(),
+                                         collapsedSgData.end());
+  SmallVector<int32_t> collapsedInstData32(collapsedInstData.begin(),
+                                           collapsedInstData.end());
+  SmallVector<int32_t> collapsedLaneLayout32(collapsedLaneLayout.begin(),
+                                             collapsedLaneLayout.end());
+  SmallVector<int32_t> collapsedLaneData32(collapsedLaneData.begin(),
+                                           collapsedLaneData.end());
+
+  // go through the values inside collapsedOrder, and re-map the order values to
+  // be in range of [0, N-1] where N is the number of dimensions in collapsed
+  // shape
+  SmallVector<int32_t> remappedOrder32;
----------------
Jianhui-Li wrote:

Say we collapse dims 2 and 3, turning [1, 2, 3, 4] into [1, 2, 4]. We then need to remap it to [1, 2, 3].

https://github.com/llvm/llvm-project/pull/179016


More information about the Mlir-commits mailing list