[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (PR #186958)

Charitha Saumya llvmlistbot at llvm.org
Thu Mar 19 10:48:22 PDT 2026


================
@@ -871,6 +1041,50 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
           (flattenedThis.getDims() == flattenedOther.getDims()));
 }
 
+bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                 SmallVector<int64_t> shape,
+                                 xegpu::LayoutKind level) {
+  if (!other)
+    return false;
+  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+    // short cut when order is the same, no need to compute coords and compare
+    if (level == xegpu::LayoutKind::Subgroup)
+      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
+          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
+        return true;
+    if (level == xegpu::LayoutKind::Lane)
+      if (getEffectiveLaneLayoutAsInt() ==
+              other.getEffectiveLaneLayoutAsInt() &&
+          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
+        return true;
+  }
+
+  auto flattenedThis = flatten();
+  auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
+  if (level == xegpu::LayoutKind::Subgroup) {
+    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  if (level == xegpu::LayoutKind::InstData) {
+    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
+  }
+  if (level == xegpu::LayoutKind::Lane) {
+    int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  return true;
----------------
charithaintc wrote:

This looks like the same logic as the Subgroup loop above (iterate over ids and compare the distributed coords). Could it be moved to a helper? 

https://github.com/llvm/llvm-project/pull/186958


More information about the Mlir-commits mailing list