[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance XeGPU lane layout to support "wrap-around" distribution (PR #186958)

Jianhui Li llvmlistbot at llvm.org
Thu Mar 19 15:08:40 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/186958

>From 78033e901efab7bbc61d97c3ba712dbe0bbcd192 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 14 Mar 2026 21:13:31 +0000
Subject: [PATCH 1/7] add computeDistributedShape and use it in two
 distribution processes

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 57 ++++++++++++
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 87 +++++++++++++++++--
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 44 ++++++----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 56 +++++++-----
 mlir/test/Dialect/XeGPU/invalid.mlir          |  8 --
 mlir/test/Dialect/XeGPU/layout.mlir           | 10 +++
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 20 ++---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 16 ++--
 8 files changed, 225 insertions(+), 73 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index ce0cce65373e5..7eba66d4485e2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -270,6 +270,63 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "computeDistributedCoords",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+    InterfaceMethod<[{Computes the per-compute-unit shape by dividing each dimension of
+                      `shape` by the corresponding layout factor (sg_layout or
+                      lane_layout). For wrap-around dimensions where the division is uneven,
+                      the tensor tile is broadcasted to all subgroups/lanes.}],
+                    /*retTy=*/"FailureOr<SmallVector<int64_t>>",
+                    /*methodName=*/"computeDistributedShape",
+                    /*args=*/(ins "SmallVector<int64_t>":$shape),
+                    /*methodBody=*/[{
+                      SmallVector<int64_t> layout;
+                      SmallVector<int64_t> subShape;
+                      if ($_self.isForWorkgroup()) {
+                        layout = $_self.getEffectiveSgLayoutAsInt();
+                        subShape = $_self.getEffectiveSgDataAsInt();
+                      } else if ($_self.isForSubgroup()) {
+                        layout = $_self.getEffectiveLaneLayoutAsInt();
+                        subShape = $_self.getEffectiveLaneDataAsInt();
+                      } else {
+                        return failure();
+                      }
+                      assert(
+                          !subShape.empty() &&
+                          "sgdata or lanedata cannot be empty for distributed shape computation");
+
+                      SmallVector<int64_t> distributedShape(shape.size());
+                      llvm::errs() << "computeDistributedShape:\n";
+                      llvm::errs() << "  shape: [";
+                        llvm::interleaveComma(shape, llvm::errs());
+                        llvm::errs() << "]\n";
+                      llvm::errs() << "  layout: [";
+                        llvm::interleaveComma(layout, llvm::errs());
+                        llvm::errs() << "]\n";
+                      llvm::errs() << "  subShape: [";
+                        llvm::interleaveComma(subShape, llvm::errs());
+                        llvm::errs() << "]\n";
+                      for (auto [i, dim] : llvm::enumerate(shape)) {
+                        int64_t distri_unit = layout[i]*subShape[i];
+                        if ((dim % distri_unit) == 0) {
+                          // Evenly divisible case, divide the dimension by the layout factor.
+                          distributedShape[i] = dim / layout[i];
+                          assert((distributedShape[i] % subShape[i] == 0) &&
+                                "Even distribution: sgdata or lanedata must divide the distributed dimension");
+                          llvm::errs() << "  dim[" << i << "]=" << dim
+                            << " evenly divisible, distributed=" << distributedShape[i] << "\n";
+                        } else {
+                          // wrap around case, the dimension size must be equal to subShape value
+                          assert(dim == subShape[i] &&
+                                "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
+                          distributedShape[i] = dim;
+                          llvm::errs() << "  dim[" << i << "]=" << dim
+                            << " wrap-around, kept=" << distributedShape[i] << "\n";
+                        }
+                      }
+                      llvm::errs() << "  result: [";
+                        llvm::interleaveComma(distributedShape, llvm::errs());
+                        llvm::errs() << "]\n";
+                      return distributedShape;
+                    }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index f9aa01aca7172..4402e0674208e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -160,7 +160,7 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
   // check LaneLayout and LaneData
   auto maybeLaneShape =
       tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
-                    attr.getEffectiveLaneDataAsInt(), false);
+                    attr.getEffectiveLaneDataAsInt(), true);
   return maybeLaneShape.has_value();
 }
 
@@ -205,6 +205,83 @@ LogicalResult ScatterTensorDescAttr::verify(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPU_DistributeLayoutAttr
+//===----------------------------------------------------------------------===//
+
+/// Computes the per-compute-unit shape by dividing each dimension of
+/// `shape` by the corresponding layout factor (sg_layout or
+/// lane_layout). For wrap-around dimensions where the division is uneven,
+/// the tensor tile is broadcasted to all subgroups/lanes.
+// FailureOr<SmallVector<int64_t>>
+// DistributeLayoutAttr::computeDistributedShape(SmallVector<int64_t> shape)
+// const {
+
+//   SmallVector<int64_t> layout;
+//   SmallVector<int64_t> subShape;
+//   if (isForWorkgroup()) {
+//     layout = getEffectiveSgLayoutAsInt();
+//     subShape = getEffectiveSgDataAsInt();
+//   } else if (isForSubgroup()) {
+//     layout = getEffectiveLaneLayoutAsInt();
+//     subShape = getEffectiveLaneDataAsInt();
+//   } else {
+//     return failure();
+//   }
+//   assert(
+//       !subShape.empty() &&
+//       "sgdata or lanedata cannot be empty for distributed shape
+//       computation");
+
+//   SmallVector<int64_t> distributedShape(shape);
+//   for (auto [i, dim] : llvm::enumerate(shape)) {
+//     if (dim % layout[i] != 0) {
+//       // wrap around case, the dimension size must be equal to subShape value
+//       assert(dim == subShape[i] &&
+//              "Wrap-around distribution: sgdata or lanedata must be same as "
+//              "tensor tile shape");
+//       distributedShape[i] = dim;
+//     } else {
+//       // Evenly divisible case, divide the dimension by the layout factor.
+//       distributedShape[i] = dim / layout[i];
+//       assert(distributedShape[i] % subShape[i] == 0 &&
+//              "Even distribution: sgdata or lanedata must divide the "
+//              "distributed dimension");
+//     }
+//   }
+//   return distributedShape;
+// }
+
+// bool DistributeLayoutAttr::isCompatibleWith(
+//     const xegpu::DistributeLayoutAttr &other, xegpu::LayoutKind level) {
+//   if (!other)
+//     return false;
+//   switch (level) {
+//   case xegpu::LayoutKind::Subgroup:
+//     if (getEffectiveSgLayoutAsInt() ==
+//                other.getEffectiveSgLayoutAsInt() &&
+//            getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt() &&
+//           getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+//       return true;
+//     }
+//   case xegpu::LayoutKind::InstData:
+//     if (getEffectiveInstDataAsInt() ==
+//            other.getEffectiveInstDataAsInt()) {
+//       return true;
+//     }
+//   case xegpu::LayoutKind::Lane:
+//     if (getEffectiveLaneLayoutAsInt() ==
+//                other.getEffectiveLaneLayoutAsInt() &&
+//            getEffectiveLaneDataAsInt() ==
+//                other.getEffectiveLaneDataAsInt() &&
+//                getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+//       return true;
+//     }
+//   }
+
+//   return false;
+// }
+
 //===----------------------------------------------------------------------===//
 // XeGPU_LayoutAttr
 //===----------------------------------------------------------------------===//
@@ -373,12 +450,8 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   } else {
     return failure();
   }
-  if (subShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, layout))
-      subShape = derivedShape.value();
-    else
-      return failure();
-  }
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
+                              "distributed coordinates computation");
 
   // delinearize Ids
   auto maybeIds = delinearizeId(builder, loc, linearId);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9f8dbc15f6422..c9998093171c6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -52,21 +52,25 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
                    xegpu::DistributeLayoutAttr layout) {
   int count = 1;
   SmallVector<int64_t> sgShape(shape);
-  if (layout && layout.isForWorkgroup()) {
-    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
-    if (!layout.getEffectiveSgDataAsInt().empty())
-      sgShape = layout.getEffectiveSgDataAsInt();
-    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
-      sgShape = *maybeDerivedSgData;
-    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
-    // Clamp distUnit to the original shape to handle cases where data is
-    // shared among subgroups, which may cause distUnit to exceed the original
-    // shape.
-    for (size_t i = 0; i < distUnit.size(); ++i)
-      distUnit[i] = std::min(shape[i], distUnit[i]);
-    count = computeProduct(shape) / computeProduct(distUnit);
-  }
-  return std::make_pair(sgShape, count);
+  auto distributedShape = layout.computeDistributedShape(
+      SmallVector<int64_t>(shape.begin(), shape.end()));
+  if (failed(distributedShape))
+    return std::make_pair(sgShape, count);
+  auto sgData = layout.getEffectiveSgDataAsInt();
+  count = computeProduct(distributedShape.value()) / computeProduct(sgData);
+  // auto sgLayout = layout.getEffectiveSgLayoutAsInt();
+  // auto sgData = layout.getEffectiveSgDataAsInt();
+  // SmallVector<int64_t> distUnit =
+  //     computeElementwiseMul(sgLayout, sgData);
+  // // Clamp distUnit to the original shape to handle wrap-around cases where
+  // data
+  // // is shared among subgroups, which may cause distUnit to exceed the
+  // original
+  // // shape.
+  // for (size_t i = 0; i < distUnit.size(); ++i)
+  //   distUnit[i] = std::min(shape[i], distUnit[i]);
+  // count = computeProduct(shape) / computeProduct(distUnit);
+  return std::make_pair(sgData, count);
 }
 
 /// Utility helper for deriving a list of offsets for each sub-TensorDescs
@@ -1712,16 +1716,20 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
   converter.addConversion(
       [&](xegpu::TensorDescType type,
           SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
+        xegpu::LayoutAttr layout = type.getLayoutAttr();
+        // Only convert WG-level tensor descs. SG-level or layout-less types
+        // are already legal and should pass through unchanged.
+        if (!layout || !layout.isForWorkgroup())
+          return std::nullopt;
+
         Type elemTy = type.getElementType();
         ArrayRef<int64_t> shape = type.getShape();
 
         int count;
         SmallVector<int64_t> subShape;
-        xegpu::LayoutAttr layout = type.getLayoutAttr();
         std::tie(subShape, count) = getSgShapeAndCount(shape, layout);
 
-        if (layout)
-          layout = layout.dropSgLayoutAndData();
+        layout = layout.dropSgLayoutAndData();
 
         auto newTy = xegpu::TensorDescType::get(
             type.getContext(), subShape, elemTy, type.getEncoding(), layout);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index a57bf8512ddec..9884a46a7d5b8 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -103,33 +103,45 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
 
+// FailureOr<VectorType>
+// xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
+//                                        VectorType originalType) {
+//   if (!layout)
+//     return failure();
+//   assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
+//          "Expecting a valid layout.");
+//   SmallVector<int64_t> effectiveLaneLayout =
+//       layout.getEffectiveLaneLayoutAsInt();
+//   assert(static_cast<size_t>(originalType.getRank()) >=
+//              effectiveLaneLayout.size() &&
+//          "Rank of the original vector type should be greater or equal to the
+//          " "size of the lane layout to distribute the vector type.");
+//   SmallVector<int64_t> distributedShape(originalType.getShape());
+//   // Only distribute the last `laneLayout.size()` dimensions. The remaining
+//   // dimensions are not distributed.
+//   unsigned distributionStart =
+//       originalType.getRank() - effectiveLaneLayout.size();
+//   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
+//     if (i < distributionStart)
+//       continue;
+//     // Check if the dimension can be distributed evenly.
+//     if (dim % effectiveLaneLayout[i - distributionStart] != 0)
+//       return failure();
+//     distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
+//   }
+//   return VectorType::get(distributedShape, originalType.getElementType());
+// }
+
 FailureOr<VectorType>
 xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                        VectorType originalType) {
   if (!layout)
     return failure();
-  assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
-         "Expecting a valid layout.");
-  SmallVector<int64_t> effectiveLaneLayout =
-      layout.getEffectiveLaneLayoutAsInt();
-  assert(static_cast<size_t>(originalType.getRank()) >=
-             effectiveLaneLayout.size() &&
-         "Rank of the original vector type should be greater or equal to the "
-         "size of the lane layout to distribute the vector type.");
-  SmallVector<int64_t> distributedShape(originalType.getShape());
-  // Only distribute the last `laneLayout.size()` dimensions. The remaining
-  // dimensions are not distributed.
-  unsigned distributionStart =
-      originalType.getRank() - effectiveLaneLayout.size();
-  for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
-    if (i < distributionStart)
-      continue;
-    // Check if the dimension can be distributed evenly.
-    if (dim % effectiveLaneLayout[i - distributionStart] != 0)
-      return failure();
-    distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
-  }
-  return VectorType::get(distributedShape, originalType.getElementType());
+  auto distributedShape = layout.computeDistributedShape(
+      SmallVector<int64_t>(originalType.getShape()));
+  if (failed(distributedShape))
+    return failure();
+  return VectorType::get(*distributedShape, originalType.getElementType());
 }
 
 std::string xegpu::getTemporaryLayoutName(const OpOperand &operand) {
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 53d497e4c2087..74c505e6dc6be 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -657,14 +657,6 @@ func.func @tensor_desc_invalid_map_layout_1(%src: memref<24x32xf32>) {
   return
 }
 
-// -----
-func.func @tensor_desc_invalid_map_data(%src: memref<24x32xf32>) {
-  %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
-      // expected-error at +1 {{cannot distribute [4, 8] using #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>}}
-      !xegpu.tensor_desc<4x8xf32,  #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
-  return
-}
-
 // -----
 func.func @tensor_desc_invalid_map_data_1(%src: memref<24x32xf32>) {
   %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
diff --git a/mlir/test/Dialect/XeGPU/layout.mlir b/mlir/test/Dialect/XeGPU/layout.mlir
index e4b4e22e5cf97..29670d0b5aadd 100644
--- a/mlir/test/Dialect/XeGPU/layout.mlir
+++ b/mlir/test/Dialect/XeGPU/layout.mlir
@@ -27,6 +27,16 @@ gpu.func @create_nd_tdesc_subgroup_3(%src: memref<128x128xf32>) {
   gpu.return
 }
 
+
+// -----
+// CHECK: func.func @create_nd_tdesc_wrap_around_layout(%[[arg0:.*]]: memref<24x32xf32>) {
+func.func @create_nd_tdesc_wrap_around_layout(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<4x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
+    %0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> ->
+      !xegpu.tensor_desc<4x8xf32,  #xegpu.layout<lane_layout = [2, 8], lane_data = [4, 1]>>
+  return
+}
+
 // CHECK: gpu.func @create_nd_tdesc_wg_1(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @create_nd_tdesc_wg_1(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [3, 2], sg_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 9e0ae881c8a7e..ecc5fe3dd75e0 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -104,28 +104,28 @@ gpu.module @test_distribution {
 
   // CHECK-LABEL: dpas_no_sg_data
   gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
-    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
     %tdesc_a = xegpu.create_nd_tdesc %a : memref<128x128xf16>
-      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
-    %load_a =  xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+    %load_a =  xegpu.load_nd %tdesc_a[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>}
-      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
       -> vector<128x128xf16>
     %tdesc_b = xegpu.create_nd_tdesc %b : memref<128x128xf16>
-      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>>
-    %load_b =  xegpu.load_nd %tdesc_b[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]> }
-      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+    %load_b =  xegpu.load_nd %tdesc_b[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]> }
+      : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>>
       -> vector<128x128xf16>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout_a = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>,
-       layout_b = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+       layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>,
-       layout_cd =  #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+       layout_cd =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16],  lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
       : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
     gpu.return
   }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 61ca028dd3ea1..df3fa880c9d6d 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -107,22 +107,22 @@ gpu.module @test_1_1_assignment {
   // CHECK-LABEL: dpas_no_sg_data
   gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
     %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<128x128xf16>
-      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
-    %load_a =  xegpu.load_nd %tdesc_a {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1],
+    %load_a =  xegpu.load_nd %tdesc_a {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1],
       order = [1, 0]>>
       -> vector<128x128xf16>
     %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x128xf16>
-      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+      -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>>
-    %load_b =  xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1],
+    %load_b =  xegpu.load_nd %tdesc_b {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>} : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1],
       order = [1, 0]>>
       -> vector<128x128xf16>
-    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x16xf16>, vector<16x16xf16> -> vector<16x16xf32>
+    // CHECK: %[[DPAS:.*]] = xegpu.dpas %{{.*}}, %{{.*}} {layout_a = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>, layout_b = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>, layout_cd = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<16x128xf16>, vector<128x16xf16> -> vector<16x16xf32>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout_a = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>,
-       layout_b = #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>,
-       layout_cd =  #xegpu.layout<sg_layout = [8, 8], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
+      {layout_a = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>,
+       layout_b = #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16], lane_layout = [1, 16], lane_data = [2, 1], order = [1, 0]>,
+       layout_cd =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>}
       : vector<128x128xf16>, vector<128x128xf16> -> vector<128x128xf32>
     gpu.return
   }

>From b072030f9eb76a23aa4cffd293e0b505b50c5798 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sun, 15 Mar 2026 05:54:44 +0000
Subject: [PATCH 2/7] WIP: add computeStaticDistributedCoords and
 shape-aware isCompatibleWith

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  56 ++-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 342 ++++++++++++------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   4 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |   2 +-
 4 files changed, 279 insertions(+), 125 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 7eba66d4485e2..04b058a91060f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -270,6 +270,15 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     "FailureOr<SmallVector<SmallVector<Value>>>",
                     "computeDistributedCoords",
                     (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+    InterfaceMethod<[{Statically computes multidimensional coordinates for all dist units
+                      assigned to a compute unit identified by `linearId`. This is the
+                      compile-time counterpart of `computeDistributedCoords`: it performs
+                      the same delinearization and round-robin enumeration but operates
+                      entirely on static integer values. Returns a list of coordinate
+                      vectors, one per dist unit.}],
+                    /*retTy=*/"SmallVector<SmallVector<int64_t>>",
+                    /*methodName=*/"computeStaticDistributedCoords",
+                    /*args=*/(ins "int64_t":$linearId, "ArrayRef<int64_t>":$shape)>,
     InterfaceMethod<[{Computes the per-compute-unit shape by dividing each dimension of
                       `shape` by the corresponding layout factor (sg_layout or
                       lane_layout). For wrap-around dimensions where the division is uneven,
@@ -339,28 +348,12 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                                   "ArrayRef<int64_t>": $perm,
                                   "xegpu::LayoutKind": $kind)>,
     InterfaceMethod</*desc=*/[{Check if this layout is compatible with another layout
-                     at a specific level of the layout hierarchy. Unlike isEqualTo,
-                     this compares only the effective (non-sliced) fields at the
-                     requested level.}],
+                     at a specific level of the layout hierarchy regarding a given shape. }],
                     /*retTy=*/"bool",
                     /*methodName=*/"isCompatibleWith",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
-                                  "xegpu::LayoutKind": $level),
-                    /*methodBody=*/[{
-                      if (!other)
-                        return false;
-                      switch (level) {
-                        case xegpu::LayoutKind::Subgroup:
-                          return $_self.getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
-                                 $_self.getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt();
-                        case xegpu::LayoutKind::InstData:
-                          return $_self.getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt();
-                        case xegpu::LayoutKind::Lane:
-                          return $_self.getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
-                                 $_self.getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
-                      }
-                      return false;
-                    }]>,
+                                  "SmallVector<int64_t>: %shape",
+                                  "xegpu::LayoutKind": $level)>,
     InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
                      For LayoutAttr, this compares all fields.
                      For SliceAttr, this requires the same parent and same sliced dims.}],
@@ -616,12 +609,24 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     FailureOr<SmallVector<SmallVector<Value>>>
     computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    ///Statically computes multidimensional coordinates for all dist units
+    ///assigned to a compute unit identified by `linearId`. This is the
+    ///compile-time counterpart of `computeDistributedCoords`.
+    SmallVector<SmallVector<int64_t>>
+    computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
+
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
 
     /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
 
+    /// Check if this layout is compatible with another layout 
+    /// at a specific level of the layout hierarchy regarding a given shape.
+    bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level);
+
     /// Check if this layout is a transpose of another layout.
     bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
   }];
@@ -829,16 +834,27 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     /// assigned to a subgroup identified by linearId. The shape parameter
     /// represents the workgroup-level problem size. Each subgroup may access
     /// multiple blocks according to round-robin distribution rules.
-
     FailureOr<SmallVector<SmallVector<Value>>>
     computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
 
+    ///Statically computes multidimensional coordinates for all dist units
+    ///assigned to a compute unit identified by `linearId`. This is the
+    ///compile-time counterpart of `computeDistributedCoords`.
+    SmallVector<SmallVector<int64_t>>
+    computeStaticDistributedCoords(int64_t linearId, ArrayRef<int64_t> shape);
+
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
 
     /// Check if this layout is equal to another layout.
     bool isEqualTo(const xegpu::DistributeLayoutAttr &other);
 
+    /// Check if this layout is compatible with another layout 
+    /// at a specific level of the layout hierarchy regarding a given shape.
+    bool isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level);
+
     /// Check if this layout is a transpose of another layout.
     bool isTransposeOf(const xegpu::DistributeLayoutAttr &other, ArrayRef<int64_t> perm, const xegpu::LayoutKind kind);
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 4402e0674208e..bd221c03757f8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -205,83 +205,6 @@ LogicalResult ScatterTensorDescAttr::verify(
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// XeGPU_DistributeLayoutAttr
-//===----------------------------------------------------------------------===//
-
-/// Computes the per-compute-unit shape by dividing each dimension of
-/// `shape` by the corresponding layout factor (sg_layout or
-/// lane_layout). For wrap-around dimensions where the division is uneven,
-/// the tensor tile is broadcasted to all subgroups/lanes.
-// FailureOr<SmallVector<int64_t>>
-// DistributeLayoutAttr::computeDistributedShape(SmallVector<int64_t> shape)
-// const {
-
-//   SmallVector<int64_t> layout;
-//   SmallVector<int64_t> subShape;
-//   if (isForWorkgroup()) {
-//     layout = getEffectiveSgLayoutAsInt();
-//     subShape = getEffectiveSgDataAsInt();
-//   } else if (isForSubgroup()) {
-//     layout = getEffectiveLaneLayoutAsInt();
-//     subShape = getEffectiveLaneDataAsInt();
-//   } else {
-//     return failure();
-//   }
-//   assert(
-//       !subShape.empty() &&
-//       "sgdata or lanedata cannot be empty for distributed shape
-//       computation");
-
-//   SmallVector<int64_t> distributedShape(shape);
-//   for (auto [i, dim] : llvm::enumerate(shape)) {
-//     if (dim % layout[i] != 0) {
-//       // wrap around case, the dimension size must be equal to subShape value
-//       assert(dim == subShape[i] &&
-//              "Wrap-around distribution: sgdata or lanedata must be same as "
-//              "tensor tile shape");
-//       distributedShape[i] = dim;
-//     } else {
-//       // Evenly divisible case, divide the dimension by the layout factor.
-//       distributedShape[i] = dim / layout[i];
-//       assert(distributedShape[i] % subShape[i] == 0 &&
-//              "Even distribution: sgdata or lanedata must divide the "
-//              "distributed dimension");
-//     }
-//   }
-//   return distributedShape;
-// }
-
-// bool DistributeLayoutAttr::isCompatibleWith(
-//     const xegpu::DistributeLayoutAttr &other, xegpu::LayoutKind level) {
-//   if (!other)
-//     return false;
-//   switch (level) {
-//   case xegpu::LayoutKind::Subgroup:
-//     if (getEffectiveSgLayoutAsInt() ==
-//                other.getEffectiveSgLayoutAsInt() &&
-//            getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt() &&
-//           getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
-//       return true;
-//     }
-//   case xegpu::LayoutKind::InstData:
-//     if (getEffectiveInstDataAsInt() ==
-//            other.getEffectiveInstDataAsInt()) {
-//       return true;
-//     }
-//   case xegpu::LayoutKind::Lane:
-//     if (getEffectiveLaneLayoutAsInt() ==
-//                other.getEffectiveLaneLayoutAsInt() &&
-//            getEffectiveLaneDataAsInt() ==
-//                other.getEffectiveLaneDataAsInt() &&
-//                getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
-//       return true;
-//     }
-//   }
-
-//   return false;
-// }
-
 //===----------------------------------------------------------------------===//
 // XeGPU_LayoutAttr
 //===----------------------------------------------------------------------===//
@@ -315,25 +238,17 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                        << lane_layout.size();
   }
 
-  // sg_data is optional for Workgroup layout, but its presence requires
-  // sg_layout.
-  if (sg_data) {
-    if (!sg_layout)
-      return emitError() << "expected sg_layout being used with sg_data";
-    if (sg_data.size() != sg_layout.size())
-      return emitError()
-             << "expected sg_data and sg_layout to have the same rank";
-  }
+  if ((sg_data && !sg_layout) || (!sg_data && sg_layout))
+    return emitError() << "sg_layout and sg_data must be used together";
+  if (sg_data.size() != sg_layout.size())
+    return emitError()
+           << "expected sg_data and sg_layout to have the same rank";
 
-  // lane_data is optional for Subgroup layout, but its presence requires
-  // lane_layout.
-  if (lane_data) {
-    if (!lane_layout)
-      return emitError() << "expected lane_layout being used with lane_data";
-    if (lane_data.size() != lane_layout.size())
-      return emitError()
-             << "expected lane_data and lane_layout to have the same rank";
-  }
+  if ((lane_data && !lane_layout) || (!lane_data && lane_layout))
+    return emitError() << "lane_layout and lane_data must be used together";
+  if (lane_data.size() != lane_layout.size())
+    return emitError()
+           << "expected lane_data and lane_layout to have the same rank";
 
   if (order) {
     if (!sg_layout && !lane_layout)
@@ -469,6 +384,70 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
   return *this == dyn_cast<xegpu::LayoutAttr>(other);
 }
 
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// LayoutAttr.
+SmallVector<SmallVector<int64_t>>
+LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
+                                           ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> layoutVec;
+  SmallVector<int64_t> subShape;
+  SmallVector<int64_t> instData;
+  if (isForWorkgroup()) {
+    layoutVec = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    instData = getEffectiveInstDataAsInt();
+    layoutVec = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  }
+  if (!instData.empty()) {
+    linearId = 0;
+    subShape = instData;
+  }
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+  // Delinearize the linear ID using the order attribute.
+  DenseI32ArrayAttr orderAttr = getOrder();
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+      return static_cast<int64_t>(idx);
+    });
+  } else {
+    order =
+        llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, layoutVec.size())));
+  }
+  SmallVector<int64_t> delinearizedId(layoutVec.size());
+  int64_t remaining = linearId;
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
+    remaining = remaining / layoutVec[dimIdx];
+  }
+
+  // Compute distribution unit shape (clamped to srcShape).
+  SmallVector<int64_t> distUnitShape(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    distUnitShape[i] = std::min(shape[i], layoutVec[i] * subShape[i]);
+
+  // Compute local offset of this ID within a distribution unit.
+  SmallVector<int64_t> localOffset(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    localOffset[i] = delinearizedId[i] * subShape[i];
+
+  // Enumerate all distribution units and compute coordinates.
+  SmallVector<SmallVector<int64_t>> coordinates;
+  for (SmallVector<int64_t> unitOffs :
+       StaticTileOffsetRange(shape, distUnitShape)) {
+    SmallVector<int64_t> coord(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i)
+      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+    coordinates.push_back(coord);
+  }
+  return coordinates;
+}
+
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
 DistributeLayoutAttr
 LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
@@ -816,6 +795,44 @@ bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
   return false;
 }
 
+bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level) {
+  if (!other)
+    return false;
+  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+    if (level == xegpu::LayoutKind::Subgroup)
+      return (getEffectiveSgLayoutAsInt() ==
+                  other.getEffectiveSgLayoutAsInt() &&
+              getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
+    if (level == xegpu::LayoutKind::Lane)
+      return (getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
+        getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
+  }
+  if (level == xegpu::LayoutKind::Subgroup) {
+    int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  if (level == xegpu::LayoutKind::InstData) {
+    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
+  }
+  if (level == xegpu::LayoutKind::Lane) {
+    int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -892,12 +909,8 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
     return failure();
   }
 
-  if (subShape.empty()) {
-    if (auto derivedShape = computeShapeRatio(shape, layout))
-      subShape = derivedShape.value();
-    else
-      return failure();
-  }
+  if (subShape.empty())
+    return failure();
 
   // delinearize Ids
   auto maybeIds = delinearizeId(builder, loc, linearId);
@@ -907,10 +920,92 @@ SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   // The effective sgIds for offsets computing correspond
   // to the dims that are not sliced.
   ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
-  SmallVector<Value> sgIds =
+  SmallVector<Value> canonicalIds =
       XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
 
-  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
+  return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
+}
+
+/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
+/// compute multi-dimensional offsets for a given linear ID when distributed by
+/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
+/// only the non-sliced dimensions for coordinate computation.
+SmallVector<SmallVector<int64_t>>
+SliceAttr::computeStaticDistributedCoords(int64_t linearId,
+                                          ArrayRef<int64_t> shape) {
+  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+
+  SmallVector<int64_t> layout;
+  SmallVector<int64_t> subShape;
+  if (isForWorkgroup()) {
+    layout = getEffectiveSgLayoutAsInt();
+    subShape = getEffectiveSgDataAsInt();
+  } else if (isForSubgroup()) {
+    instData = getEffectiveInstDataAsInt();
+    layoutVec = getEffectiveLaneLayoutAsInt();
+    subShape = getEffectiveLaneDataAsInt();
+  }
+  if (!instData.empty()) {
+    linearId = 0;
+    subShape = instData;
+  }
+
+  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
+
+  // Delinearize the ID using the parent layout (same as the IR version).
+  SliceAttr flattened = flatten();
+  auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
+  SmallVector<int64_t> parentLayoutVec;
+  if (parent.isForWorkgroup())
+    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
+  else
+    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
+
+  DenseI32ArrayAttr orderAttr = parent.getOrder();
+  SmallVector<int64_t> order;
+  if (orderAttr && !orderAttr.empty()) {
+    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
+      return static_cast<int64_t>(idx);
+    });
+  } else {
+    order = llvm::to_vector(
+        llvm::reverse(llvm::seq<int64_t>(0, parentLayoutVec.size())));
+  }
+  SmallVector<int64_t> allIds(parentLayoutVec.size());
+  int64_t remaining = linearId;
+  for (size_t i = 0; i < order.size(); ++i) {
+    int64_t dimIdx = order[i];
+    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
+    if (i < order.size() - 1)
+      remaining = remaining / parentLayoutVec[dimIdx];
+  }
+
+  // The effective IDs for coordinate computation correspond
+  // to the dims that are not sliced.
+  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
+  SmallVector<int64_t> canonicalIds =
+      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
+
+  // Compute distribution unit shape (clamped to srcShape).
+  SmallVector<int64_t> distUnitShape(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
+
+  // Compute local offset of this ID within a distribution unit.
+  SmallVector<int64_t> localOffset(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    localOffset[i] = canonicalIds[i] * subShape[i];
+
+  // Enumerate all distribution units and compute coordinates.
+  SmallVector<SmallVector<int64_t>> coordinates;
+  for (SmallVector<int64_t> unitOffs :
+       StaticTileOffsetRange(shape, distUnitShape)) {
+    SmallVector<int64_t> coord(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i)
+      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+    coordinates.push_back(coord);
+  }
+  return coordinates;
 }
 
 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
@@ -944,6 +1039,47 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
           (flattenedThis.getDims() == flattenedOther.getDims()));
 }
 
+bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                  SmallVector<int64_t> shape,
+                                  xegpu::LayoutKind level) {
+  if (!other)
+    return false;
+  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+    if (level == xegpu::LayoutKind::Subgroup)
+      return (getEffectiveSgLayoutAsInt() ==
+                  other.getEffectiveSgLayoutAsInt() &&
+              getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
+    if (level == xegpu::LayoutKind::Lane)
+      return (getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
+        getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
+  }
+
+  auto flattenedThis = flatten();
+  auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
+  if (level == xegpu::LayoutKind::Subgroup) {
+    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  if (level == xegpu::LayoutKind::InstData) {
+    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
+  }
+  if (level == xegpu::LayoutKind::Lane) {
+    int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
+    for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
+      auto coords = computeStaticDistributedCoords(id, shape);
+      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      if (coords != otherCoords)
+        return false;
+    }
+  }
+  return true;
+}
+
 xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
   if (sliceDimsToDrop.empty())
     return *this;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 7e8ad733fa0ee..e56e443feae0f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -2080,11 +2080,13 @@ struct ConvertLayoutDistribution
                                 PatternRewriter &rewriter) const override {
     auto inputLayout = op.getInputLayoutAttr();
     auto targetLayout = op.getTargetLayoutAttr();
+    auto resShape = cast<VectorType>(op.getResult().getType()).getShape();
 
     if (!inputLayout || !targetLayout)
       return rewriter.notifyMatchFailure(op, "missing layout attributes");
 
-    if (!inputLayout.isCompatibleWith(targetLayout, xegpu::LayoutKind::Lane)) {
+    if (!inputLayout.isCompatibleWith(targetLayout, resShape,
+                                      xegpu::LayoutKind::Lane)) {
       return rewriter.notifyMatchFailure(
           op, "lowering incompatible convert_layout not yet supported");
     }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c9998093171c6..9de5861d81fac 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -630,7 +630,7 @@ struct WgToSgConvertLayoutOp
     SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
 
     // Fast path: if sg_layout and sg_data are identical, no SLM needed
-    if (inputLayout.isCompatibleWith(targetLayout,
+    if (inputLayout.isCompatibleWith(targetLayout, wgShape,
                                      xegpu::LayoutKind::Subgroup)) {
       inputLayout = inputLayout.dropSgLayoutAndData();
       targetLayout = targetLayout.dropSgLayoutAndData();

>From edc26726254a8b02f471943bc32dec8825a567be Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 16 Mar 2026 01:17:06 +0000
Subject: [PATCH 3/7] enhance isCompatibleWith and computeStaticCoord
 utilities

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 27 ++++++++++---------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |  3 ++-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  3 ++-
 4 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 04b058a91060f..1c67244658d15 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -352,7 +352,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     /*retTy=*/"bool",
                     /*methodName=*/"isCompatibleWith",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other,
-                                  "SmallVector<int64_t>: %shape",
+                                  "SmallVector<int64_t>": $shape,
                                   "xegpu::LayoutKind": $level)>,
     InterfaceMethod</*desc=*/[{Check if this layout is equal to another layout.
                      For LayoutAttr, this compares all fields.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index bd221c03757f8..b0d34afe62e6a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -238,15 +238,15 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                        << lane_layout.size();
   }
 
-  if ((sg_data && !sg_layout) || (!sg_data && sg_layout))
+  if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
     return emitError() << "sg_layout and sg_data must be used together";
-  if (sg_data.size() != sg_layout.size())
+  if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
     return emitError()
            << "expected sg_data and sg_layout to have the same rank";
 
-  if ((lane_data && !lane_layout) || (!lane_data && lane_layout))
+  if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
     return emitError() << "lane_layout and lane_data must be used together";
-  if (lane_data.size() != lane_layout.size())
+  if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
     return emitError()
            << "expected lane_data and lane_layout to have the same rank";
 
@@ -806,8 +806,9 @@ bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                   other.getEffectiveSgLayoutAsInt() &&
               getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
     if (level == xegpu::LayoutKind::Lane)
-      return (getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
-        getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
+      return (getEffectiveLaneLayoutAsInt() ==
+                  other.getEffectiveLaneLayoutAsInt() &&
+              getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt());
   }
   if (level == xegpu::LayoutKind::Subgroup) {
     int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
@@ -937,12 +938,13 @@ SliceAttr::computeStaticDistributedCoords(int64_t linearId,
 
   SmallVector<int64_t> layout;
   SmallVector<int64_t> subShape;
+  SmallVector<int64_t> instData;
   if (isForWorkgroup()) {
     layout = getEffectiveSgLayoutAsInt();
     subShape = getEffectiveSgDataAsInt();
   } else if (isForSubgroup()) {
     instData = getEffectiveInstDataAsInt();
-    layoutVec = getEffectiveLaneLayoutAsInt();
+    layout = getEffectiveLaneLayoutAsInt();
     subShape = getEffectiveLaneDataAsInt();
   }
   if (!instData.empty()) {
@@ -1039,9 +1041,9 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
           (flattenedThis.getDims() == flattenedOther.getDims()));
 }
 
-bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
-                                  SmallVector<int64_t> shape,
-                                  xegpu::LayoutKind level) {
+bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
+                                 SmallVector<int64_t> shape,
+                                 xegpu::LayoutKind level) {
   if (!other)
     return false;
   if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
@@ -1050,8 +1052,9 @@ bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                   other.getEffectiveSgLayoutAsInt() &&
               getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
     if (level == xegpu::LayoutKind::Lane)
-      return (getEffectiveLaneLayoutAsInt() == other.getEffectiveLaneLayoutAsInt() &&
-        getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt();
+      return (getEffectiveLaneLayoutAsInt() ==
+                  other.getEffectiveLaneLayoutAsInt() &&
+              getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt());
   }
 
   auto flattenedThis = flatten();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index e56e443feae0f..6d6725617d306 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -2085,7 +2085,8 @@ struct ConvertLayoutDistribution
     if (!inputLayout || !targetLayout)
       return rewriter.notifyMatchFailure(op, "missing layout attributes");
 
-    if (!inputLayout.isCompatibleWith(targetLayout, resShape,
+    SmallVector<int64_t> resShapeVec(resShape.begin(), resShape.end());
+    if (!inputLayout.isCompatibleWith(targetLayout, resShapeVec,
                                       xegpu::LayoutKind::Lane)) {
       return rewriter.notifyMatchFailure(
           op, "lowering incompatible convert_layout not yet supported");
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9de5861d81fac..687617b5150cb 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -630,7 +630,8 @@ struct WgToSgConvertLayoutOp
     SmallVector<int64_t> targetSgData = targetLayout.getEffectiveSgDataAsInt();
 
     // Fast path: if sg_layout and sg_data are identical, no SLM needed
-    if (inputLayout.isCompatibleWith(targetLayout, wgShape,
+    SmallVector<int64_t> wgShapeVec(wgShape.begin(), wgShape.end());
+    if (inputLayout.isCompatibleWith(targetLayout, wgShapeVec,
                                      xegpu::LayoutKind::Subgroup)) {
       inputLayout = inputLayout.dropSgLayoutAndData();
       targetLayout = targetLayout.dropSgLayoutAndData();

>From 71f3a9efdecf9b755cd0f1c260faf06a93b852a4 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 17 Mar 2026 05:21:35 +0000
Subject: [PATCH 4/7] adding test and polish

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 32 +++++++--
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 13 ----
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 69 ++++++++++++++++++-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  8 +--
 mlir/test/Dialect/XeGPU/propagate-layout.mlir |  9 ++-
 5 files changed, 103 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index b0d34afe62e6a..ded262692ac75 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -1047,18 +1047,25 @@ bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
   if (!other)
     return false;
   if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
+    // short cut when order is the same, no need to compute coords and compare
     if (level == xegpu::LayoutKind::Subgroup)
-      return (getEffectiveSgLayoutAsInt() ==
-                  other.getEffectiveSgLayoutAsInt() &&
-              getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt());
+      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
+          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
+        return true;
     if (level == xegpu::LayoutKind::Lane)
-      return (getEffectiveLaneLayoutAsInt() ==
-                  other.getEffectiveLaneLayoutAsInt() &&
-              getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt());
+      if (getEffectiveLaneLayoutAsInt() ==
+              other.getEffectiveLaneLayoutAsInt() &&
+          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
+        return true;
   }
 
   auto flattenedThis = flatten();
   auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
+  // debug print parent
+  llvm::dbgs() << "Parent layout - sgLayout: ";
+  for (auto &l : parent.getEffectiveSgLayoutAsInt()) {
+    llvm::dbgs() << l << " ";
+  }
   if (level == xegpu::LayoutKind::Subgroup) {
     int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
     for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
@@ -1076,6 +1083,19 @@ bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
     for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
       auto coords = computeStaticDistributedCoords(id, shape);
       auto otherCoords = other.computeStaticDistributedCoords(id, shape);
+      llvm::dbgs() << "Lane comparison - id: " << id << ", coords: ";
+      for (auto &c : coords) {
+        llvm::dbgs() << "[";
+        llvm::interleaveComma(c, llvm::dbgs());
+        llvm::dbgs() << "] ";
+      }
+      llvm::dbgs() << "otherCoords: ";
+      for (auto &c : otherCoords) {
+        llvm::dbgs() << "[";
+        llvm::interleaveComma(c, llvm::dbgs());
+        llvm::dbgs() << "] ";
+      }
+      llvm::dbgs() << "\n";
       if (coords != otherCoords)
         return false;
     }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 57db05f4a3b74..08773a0cd2791 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -677,19 +677,6 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   auto srcShape = sourceTy.getShape();
   auto resShape = resultTy.getShape();
 
-  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
-  if (dimDiff == 0) {
-    bool hasUnitDim =
-        llvm::any_of(srcShape, [](int64_t dim) { return dim == 1; });
-    Operation *srcOp = broadcast.getSource().getDefiningOp();
-    if (!srcOp)
-      return;
-    bool produceByShapeCast = isa<vector::ShapeCastOp>(srcOp);
-    assert(
-        hasUnitDim && produceByShapeCast &&
-        "When broadcasting from unit-dim, the producer op must be shape_cast!");
-  }
-
   auto resultLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 6d6725617d306..f5cf7592b1dc5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1091,33 +1091,69 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
+    LLVM_DEBUG(DBGS() << "LoadDistribution: attempting to match warpOp: "
+                      << warpOp << "\n");
+
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    LLVM_DEBUG(DBGS() << "LoadDistribution: warpOp.getTerminator() = "
+                      << warpOp.getTerminator() << "\n");
+    LLVM_DEBUG(
+        DBGS() << "LoadDistribution: warpOp.getTerminator()->getPrevNode() = "
+               << lastNode << "\n");
+
+    auto loadGatherOp1 = dyn_cast_or_null<xegpu::LoadGatherOp>(lastNode);
+    // print loadGatherOp and its operands
+    if (loadGatherOp1) {
+      LLVM_DEBUG(DBGS() << "LoadDistribution: found LoadGatherOp: "
+                        << loadGatherOp1 << "\n");
+    } else {
+      LLVM_DEBUG(DBGS() << "LoadDistribution: no LoadGatherOp found as last op "
+                           "in warp region\n");
+    }
+
     OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
       // Check if the yield operand that was produced by the *last* scattered
       // load op to avoid sinking it before barriers (maintain memory order).
       return isa<xegpu::LoadGatherOp>(op) &&
              warpOp.getTerminator()->getPrevNode() == op;
     });
-    if (!producedByLastLoad)
+    if (!producedByLastLoad) {
+      LLVM_DEBUG(DBGS() << "LoadDistribution: no LoadGatherOp as last op in "
+                           "warp region, bailing out\n");
       return rewriter.notifyMatchFailure(
           warpOp, "The last op is not xegpu::LoadGatherOp");
+    }
 
     auto loadGatherOp =
         producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+    LLVM_DEBUG(DBGS() << "LoadDistribution: found LoadGatherOp: "
+                      << loadGatherOp << "\n");
     auto offsets = loadGatherOp.getOffsets();
     if (!offsets || !isa<VectorType>(offsets.getType()) ||
-        !isa<VectorType>(loadGatherOp.getMask().getType()))
+        !isa<VectorType>(loadGatherOp.getMask().getType())) {
+      LLVM_DEBUG(DBGS() << "LoadDistribution: offsets or mask are not vector "
+                           "types, bailing out\n");
       return rewriter.notifyMatchFailure(
           loadGatherOp,
           "Load op must have a vector arguments for offsets and mask");
+    }
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
     VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
     VectorType resultVecTy =
         cast<VectorType>(loadGatherOp.getResult().getType());
+    LLVM_DEBUG(DBGS() << "LoadDistribution: offsetsTy=" << offsetsTy
+                      << ", maskTy=" << maskTy
+                      << ", resultVecTy=" << resultVecTy << "\n");
     // add handling leading unit dimensions support
     int chunkSize = loadGatherOp.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+    LLVM_DEBUG(DBGS() << "LoadDistribution: chunkSize=" << chunkSize
+                      << ", effectiveVecRank=" << effectiveVecRank << "\n");
     for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
       if (resultVecTy.getShape()[i] != 1) {
+        LLVM_DEBUG(DBGS() << "LoadDistribution: non-unit leading dim at index "
+                          << i << " (size=" << resultVecTy.getShape()[i]
+                          << "), bailing out\n");
         return rewriter.notifyMatchFailure(
             loadGatherOp, "Only unit dimensions allowed for the leading "
                           "dimensions of the load vector!");
@@ -1127,6 +1163,8 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     auto layoutOffsets =
         xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
     auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
+    LLVM_DEBUG(DBGS() << "LoadDistribution: layoutOffsets=" << layoutOffsets
+                      << ", layoutMask=" << layoutMask << "\n");
 
     FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -1134,6 +1172,9 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
     if (failed(distOffsetsByWarpOpOrFailure) ||
         failed(distMaskByWarpOpOrFailure)) {
+      LLVM_DEBUG(
+          DBGS() << "LoadDistribution: failed to compute distributed types "
+                    "for offsets or mask, bailing out\n");
       return rewriter.notifyMatchFailure(
           loadGatherOp,
           "Some vector operands have no layouts, using defaults instead.");
@@ -1147,12 +1188,24 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
     VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
+    LLVM_DEBUG(DBGS() << "LoadDistribution: operandIdx=" << operandIdx
+                      << ", distResultTy=" << distResultTy
+                      << ", distOffsetsTy=" << distOffsetsTy
+                      << ", distMaskTy=" << distMaskTy << "\n");
 
     SmallVector<Type> operandTypesToYield = {operands[0].getType(),
                                              distOffsetsTy, distMaskTy};
 
+    LLVM_DEBUG(DBGS() << "LoadDistribution: creating new warp op with "
+                      << operands.size() << " operands and "
+                      << operandTypesToYield.size() << " types to yield\n");
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+    LLVM_DEBUG({
+      DBGS() << "LoadDistribution: newRetIndices=[";
+      llvm::interleaveComma(newRetIndices, llvm::dbgs());
+      llvm::dbgs() << "]\n";
+    });
 
     rewriter.setInsertionPointAfter(newWarpOp);
 
@@ -1166,24 +1219,36 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     VectorType distMaskTy1D =
         VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
                         distMaskByWarpOpOrFailure.value().getElementType());
+    LLVM_DEBUG(DBGS() << "LoadDistribution: 1D types: loadVecTy1D="
+                      << loadVecTy1D << ", distOffsetsTy1D=" << distOffsetsTy1D
+                      << ", distMaskTy1D=" << distMaskTy1D << "\n");
 
     Value distOffsetVal = resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
     Value distmaskVal = resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
+    LLVM_DEBUG(DBGS() << "LoadDistribution: resolved distOffsetVal="
+                      << distOffsetVal << ", distmaskVal=" << distmaskVal
+                      << "\n");
 
     SmallVector<Value> newLoadGatherOperands = {
         newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
 
+    LLVM_DEBUG(DBGS() << "LoadDistribution: creating distributed "
+                         "LoadGatherOp with result type "
+                      << loadVecTy1D << "\n");
     xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
         rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
         loadGatherOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
+    LLVM_DEBUG(DBGS() << "LoadDistribution: created new op: " << newOp << "\n");
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type and replace all uses.
     rewriter.replaceAllUsesWith(
         distributedVal,
         resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
+    LLVM_DEBUG(DBGS() << "LoadDistribution: successfully distributed "
+                         "LoadGatherOp\n");
     return success();
   }
 };
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 74c505e6dc6be..7390b47b3f8d9 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -758,8 +758,8 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
   %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
       !xegpu.tensor_desc<16x2xf32,
         #xegpu.scatter_tdesc_attr<chunk_size = 2>,
-        // expected-error at +1 {{expected sg_layout being used with sg_data}}
-        #xegpu.layout<sg_data = [16, 2], lane_layout = [8, 1], lane_data = [1, 2]>>
+        // expected-error at +1 {{sg_layout and sg_data must be used together}}
+        #xegpu.layout<sg_layout = [2, 1], lane_layout = [8, 1], lane_data = [1, 2]>>
   return
 }
 
@@ -768,8 +768,8 @@ func.func @tensor_desc_rank_mismatch(%src: ui64, %offsets: vector<16xindex>) {
   %1 = xegpu.create_tdesc %src, %offsets : ui64, vector<16xindex> ->
       !xegpu.tensor_desc<16x2xf32,
         #xegpu.scatter_tdesc_attr<chunk_size = 2>,
-        // expected-error at +1 {{expected lane_layout being used with lane_data}}
-        #xegpu.layout<inst_data = [16, 2], lane_data = [1, 2]>>
+        // expected-error at +1 {{lane_layout and lane_data must be used together}}
+        #xegpu.layout<inst_data = [16, 2], lane_layout = [16, 1]>>
   return
 }
 
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index ddd2d22108d1f..2dfade7369e53 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -739,7 +739,9 @@ gpu.module @test {
 // CHECK-SAME:       {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16>
 // CHECK-NEXT:    %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]]
 // CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
-// CHECK-NEXT:    vector.broadcast %[[SHAPECAST]]
+// CHECK-NEXT:    %[[EXP:.*]]  = math.exp %[[SHAPECAST]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16>
+// CHECK-NEXT:    vector.broadcast %[[EXP]]
 // CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
 
 func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
@@ -748,8 +750,9 @@ func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x1
   %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
   %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
   %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
-  %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16>
-  xegpu.store_nd %6, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  %6 = math.exp %5: vector<16x1xf16>
+  %7 = vector.broadcast %6 : vector<16x1xf16> to vector<16x16xf16>
+  xegpu.store_nd %7, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
   return
 }
 }

>From c9ac053d1a95cc17cefe0dc81143684c654bdabe Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 17 Mar 2026 05:41:25 +0000
Subject: [PATCH 5/7] remove debug print

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 18 -----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 18 -----
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 66 +------------------
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 12 ----
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   | 29 --------
 5 files changed, 1 insertion(+), 142 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1c67244658d15..7917a0dde556c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -301,18 +301,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       assert(
                           !subShape.empty() &&
                           "sgdata or lanedata cannot be empty for distributed shape computation");
-
                       SmallVector<int64_t> distributedShape(shape.size());
-                      llvm::errs() << "computeDistributedShape:\n";
-                      llvm::errs() << "  shape: [";
-                        llvm::interleaveComma(shape, llvm::errs());
-                        llvm::errs() << "]\n";
-                      llvm::errs() << "  layout: [";
-                        llvm::interleaveComma(layout, llvm::errs());
-                        llvm::errs() << "]\n";
-                      llvm::errs() << "  subShape: [";
-                        llvm::interleaveComma(subShape, llvm::errs());
-                        llvm::errs() << "]\n";
                       for (auto [i, dim] : llvm::enumerate(shape)) {
                         int64_t distri_unit = layout[i]*subShape[i];
                         if ((dim % distri_unit) == 0) {
@@ -320,20 +309,13 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                           distributedShape[i] = dim / layout[i];
                           assert((distributedShape[i] % subShape[i] == 0) &&
                                 "Even distribution: sgdata or lanedata must divide the distributed dimension");
-                          llvm::errs() << "  dim[" << i << "]=" << dim
-                            << " evenly divisible, distributed=" << distributedShape[i] << "\n";
                         } else {
                           // wrap around case, the dimension size must be equal to subShape value
                           assert(dim == subShape[i] &&
                                 "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
                           distributedShape[i] = dim;
-                          llvm::errs() << "  dim[" << i << "]=" << dim
-                            << " wrap-around, kept=" << distributedShape[i] << "\n";
                         }
                       }
-                      llvm::errs() << "  result: [";
-                        llvm::interleaveComma(distributedShape, llvm::errs());
-                        llvm::errs() << "]\n";
                       return distributedShape;
                     }]>,
     InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ded262692ac75..0fbd339146d7b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -1061,11 +1061,6 @@ bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
 
   auto flattenedThis = flatten();
   auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
-  // debug print parent
-  llvm::dbgs() << "Parent layout - sgLayout: ";
-  for (auto &l : parent.getEffectiveSgLayoutAsInt()) {
-    llvm::dbgs() << l << " ";
-  }
   if (level == xegpu::LayoutKind::Subgroup) {
     int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
     for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
@@ -1083,19 +1078,6 @@ bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
     for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
       auto coords = computeStaticDistributedCoords(id, shape);
       auto otherCoords = other.computeStaticDistributedCoords(id, shape);
-      llvm::dbgs() << "Lane comparison - id: " << id << ", coords: ";
-      for (auto &c : coords) {
-        llvm::dbgs() << "[";
-        llvm::interleaveComma(c, llvm::dbgs());
-        llvm::dbgs() << "] ";
-      }
-      llvm::dbgs() << "otherCoords: ";
-      for (auto &c : otherCoords) {
-        llvm::dbgs() << "[";
-        llvm::interleaveComma(c, llvm::dbgs());
-        llvm::dbgs() << "] ";
-      }
-      llvm::dbgs() << "\n";
       if (coords != otherCoords)
         return false;
     }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index f5cf7592b1dc5..4ccdf000849b1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1091,26 +1091,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
   LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
                                 PatternRewriter &rewriter) const override {
-    LLVM_DEBUG(DBGS() << "LoadDistribution: attempting to match warpOp: "
-                      << warpOp << "\n");
-
-    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
-    LLVM_DEBUG(DBGS() << "LoadDistribution: warpOp.getTerminator() = "
-                      << warpOp.getTerminator() << "\n");
-    LLVM_DEBUG(
-        DBGS() << "LoadDistribution: warpOp.getTerminator()->getPrevNode() = "
-               << lastNode << "\n");
-
-    auto loadGatherOp1 = dyn_cast_or_null<xegpu::LoadGatherOp>(lastNode);
-    // print loadGatherOp and its operands
-    if (loadGatherOp1) {
-      LLVM_DEBUG(DBGS() << "LoadDistribution: found LoadGatherOp: "
-                        << loadGatherOp1 << "\n");
-    } else {
-      LLVM_DEBUG(DBGS() << "LoadDistribution: no LoadGatherOp found as last op "
-                           "in warp region\n");
-    }
-
     OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
       // Check if the yield operand that was produced by the *last* scattered
       // load op to avoid sinking it before barriers (maintain memory order).
@@ -1118,42 +1098,27 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
              warpOp.getTerminator()->getPrevNode() == op;
     });
     if (!producedByLastLoad) {
-      LLVM_DEBUG(DBGS() << "LoadDistribution: no LoadGatherOp as last op in "
-                           "warp region, bailing out\n");
       return rewriter.notifyMatchFailure(
           warpOp, "The last op is not xegpu::LoadGatherOp");
     }
 
     auto loadGatherOp =
         producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
-    LLVM_DEBUG(DBGS() << "LoadDistribution: found LoadGatherOp: "
-                      << loadGatherOp << "\n");
     auto offsets = loadGatherOp.getOffsets();
     if (!offsets || !isa<VectorType>(offsets.getType()) ||
-        !isa<VectorType>(loadGatherOp.getMask().getType())) {
-      LLVM_DEBUG(DBGS() << "LoadDistribution: offsets or mask are not vector "
-                           "types, bailing out\n");
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
       return rewriter.notifyMatchFailure(
           loadGatherOp,
           "Load op must have a vector arguments for offsets and mask");
-    }
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
     VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
     VectorType resultVecTy =
         cast<VectorType>(loadGatherOp.getResult().getType());
-    LLVM_DEBUG(DBGS() << "LoadDistribution: offsetsTy=" << offsetsTy
-                      << ", maskTy=" << maskTy
-                      << ", resultVecTy=" << resultVecTy << "\n");
     // add handling leading unit dimensions support
     int chunkSize = loadGatherOp.getChunkSize().value_or(1);
     int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
-    LLVM_DEBUG(DBGS() << "LoadDistribution: chunkSize=" << chunkSize
-                      << ", effectiveVecRank=" << effectiveVecRank << "\n");
     for (int i = 0; i < resultVecTy.getRank() - effectiveVecRank; i++) {
       if (resultVecTy.getShape()[i] != 1) {
-        LLVM_DEBUG(DBGS() << "LoadDistribution: non-unit leading dim at index "
-                          << i << " (size=" << resultVecTy.getShape()[i]
-                          << "), bailing out\n");
         return rewriter.notifyMatchFailure(
             loadGatherOp, "Only unit dimensions allowed for the leading "
                           "dimensions of the load vector!");
@@ -1163,8 +1128,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     auto layoutOffsets =
         xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(1));
     auto layoutMask = xegpu::getTemporaryLayout(loadGatherOp->getOpOperand(2));
-    LLVM_DEBUG(DBGS() << "LoadDistribution: layoutOffsets=" << layoutOffsets
-                      << ", layoutMask=" << layoutMask << "\n");
 
     FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
         getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
@@ -1172,9 +1135,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
     if (failed(distOffsetsByWarpOpOrFailure) ||
         failed(distMaskByWarpOpOrFailure)) {
-      LLVM_DEBUG(
-          DBGS() << "LoadDistribution: failed to compute distributed types "
-                    "for offsets or mask, bailing out\n");
       return rewriter.notifyMatchFailure(
           loadGatherOp,
           "Some vector operands have no layouts, using defaults instead.");
@@ -1188,24 +1148,12 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
         cast<VectorType>(warpOp.getResult(operandIdx).getType());
     VectorType distOffsetsTy = distOffsetsByWarpOpOrFailure.value();
     VectorType distMaskTy = distMaskByWarpOpOrFailure.value();
-    LLVM_DEBUG(DBGS() << "LoadDistribution: operandIdx=" << operandIdx
-                      << ", distResultTy=" << distResultTy
-                      << ", distOffsetsTy=" << distOffsetsTy
-                      << ", distMaskTy=" << distMaskTy << "\n");
 
     SmallVector<Type> operandTypesToYield = {operands[0].getType(),
                                              distOffsetsTy, distMaskTy};
 
-    LLVM_DEBUG(DBGS() << "LoadDistribution: creating new warp op with "
-                      << operands.size() << " operands and "
-                      << operandTypesToYield.size() << " types to yield\n");
     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
-    LLVM_DEBUG({
-      DBGS() << "LoadDistribution: newRetIndices=[";
-      llvm::interleaveComma(newRetIndices, llvm::dbgs());
-      llvm::dbgs() << "]\n";
-    });
 
     rewriter.setInsertionPointAfter(newWarpOp);
 
@@ -1219,36 +1167,24 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     VectorType distMaskTy1D =
         VectorType::get({distMaskByWarpOpOrFailure.value().getNumElements()},
                         distMaskByWarpOpOrFailure.value().getElementType());
-    LLVM_DEBUG(DBGS() << "LoadDistribution: 1D types: loadVecTy1D="
-                      << loadVecTy1D << ", distOffsetsTy1D=" << distOffsetsTy1D
-                      << ", distMaskTy1D=" << distMaskTy1D << "\n");
 
     Value distOffsetVal = resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[1]), distOffsetsTy1D, rewriter);
     Value distmaskVal = resolveDistributedTy(
         newWarpOp.getResult(newRetIndices[2]), distMaskTy1D, rewriter);
-    LLVM_DEBUG(DBGS() << "LoadDistribution: resolved distOffsetVal="
-                      << distOffsetVal << ", distmaskVal=" << distmaskVal
-                      << "\n");
 
     SmallVector<Value> newLoadGatherOperands = {
         newWarpOp.getResult(newRetIndices[0]), distOffsetVal, distmaskVal};
 
-    LLVM_DEBUG(DBGS() << "LoadDistribution: creating distributed "
-                         "LoadGatherOp with result type "
-                      << loadVecTy1D << "\n");
     xegpu::LoadGatherOp newOp = xegpu::LoadGatherOp::create(
         rewriter, newWarpOp.getLoc(), loadVecTy1D, newLoadGatherOperands,
         loadGatherOp->getAttrs());
     xegpu::removeLayoutAttrs(newOp);
-    LLVM_DEBUG(DBGS() << "LoadDistribution: created new op: " << newOp << "\n");
     Value distributedVal = newWarpOp.getResult(operandIdx);
     // Resolve the output type and replace all uses.
     rewriter.replaceAllUsesWith(
         distributedVal,
         resolveDistributedTy(newOp.getResult(), distResultTy, rewriter));
-    LLVM_DEBUG(DBGS() << "LoadDistribution: successfully distributed "
-                         "LoadGatherOp\n");
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 687617b5150cb..ade7276e52173 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -58,18 +58,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape,
     return std::make_pair(sgShape, count);
   auto sgData = layout.getEffectiveSgDataAsInt();
   count = computeProduct(distributedShape.value()) / computeProduct(sgData);
-  // auto sgLayout = layout.getEffectiveSgLayoutAsInt();
-  // auto sgData = layout.getEffectiveSgDataAsInt();
-  // SmallVector<int64_t> distUnit =
-  //     computeElementwiseMul(sgLayout, sgData);
-  // // Clamp distUnit to the original shape to handle wrap-around cases where
-  // data
-  // // is shared among subgroups, which may cause distUnit to exceed the
-  // original
-  // // shape.
-  // for (size_t i = 0; i < distUnit.size(); ++i)
-  //   distUnit[i] = std::min(shape[i], distUnit[i]);
-  // count = computeProduct(shape) / computeProduct(distUnit);
   return std::make_pair(sgData, count);
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 9884a46a7d5b8..26a211b043ead 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -103,35 +103,6 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
   return xegpu::getDistributedVectorType(helperTdescTy);
 }
 
-// FailureOr<VectorType>
-// xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
-//                                        VectorType originalType) {
-//   if (!layout)
-//     return failure();
-//   assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
-//          "Expecting a valid layout.");
-//   SmallVector<int64_t> effectiveLaneLayout =
-//       layout.getEffectiveLaneLayoutAsInt();
-//   assert(static_cast<size_t>(originalType.getRank()) >=
-//              effectiveLaneLayout.size() &&
-//          "Rank of the original vector type should be greater or equal to the
-//          " "size of the lane layout to distribute the vector type.");
-//   SmallVector<int64_t> distributedShape(originalType.getShape());
-//   // Only distribute the last `laneLayout.size()` dimensions. The remaining
-//   // dimensions are not distributed.
-//   unsigned distributionStart =
-//       originalType.getRank() - effectiveLaneLayout.size();
-//   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
-//     if (i < distributionStart)
-//       continue;
-//     // Check if the dimension can be distributed evenly.
-//     if (dim % effectiveLaneLayout[i - distributionStart] != 0)
-//       return failure();
-//     distributedShape[i] = dim / effectiveLaneLayout[i - distributionStart];
-//   }
-//   return VectorType::get(distributedShape, originalType.getElementType());
-// }
-
 FailureOr<VectorType>
 xegpu::getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
                                        VectorType originalType) {

>From f3239ecb591b5e07a3601cad3fd735ca0fd6648e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 19 Mar 2026 21:14:32 +0000
Subject: [PATCH 6/7] address feedback

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 18 +++++++++++------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 20 ++-----------------
 2 files changed, 14 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 7917a0dde556c..733857c03ce78 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -279,10 +279,16 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                     /*retTy=*/"SmallVector<SmallVector<int64_t>>",
                     /*methodName=*/"computeStaticDistributedCoords",
                     /*args=*/(ins "int64_t":$linearId, "ArrayRef<int64_t>":$shape)>,
-    InterfaceMethod<[{Computes the per-compute-unit shape by dividing each dimension of
-                      `shape` by the corresponding layout factor (sg_layout or
-                      lane_layout). For wrap-around dimensions where the division is uneven,
-                      the tensor tile is broadcasted to all subgroups/lanes.}],
+    InterfaceMethod<[{Computes the distributed shape for each compute unit by dividing each
+                      dimension of `shape` by the corresponding layout factor (sg_layout or
+                      lane_layout).
+                      The distributed shape represents the per-compute-unit tile. Each
+                      distribution unit is defined as the combination of layout factors and
+                      per-unit data (`subshape`, e.g., sg_data or lane_data). When `shape`
+                      spans multiple distribution units, the distributed shape may contain
+                      multiple such units.
+                      For wrap-around dimensions where the division is uneven, the tensor tile
+                      is broadcast to all subgroups/lanes.}],
                     /*retTy=*/"FailureOr<SmallVector<int64_t>>",
                     /*methodName=*/"computeDistributedShape",
                     /*args=*/(ins "SmallVector<int64_t>":$shape),
@@ -303,8 +309,8 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                           "sgdata or lanedata cannot be empty for distributed shape computation");
                       SmallVector<int64_t> distributedShape(shape.size());
                       for (auto [i, dim] : llvm::enumerate(shape)) {
-                        int64_t distri_unit = layout[i]*subShape[i];
-                        if ((dim % distri_unit) == 0) {
+                        int64_t distriUnit = layout[i]*subShape[i];
+                        if ((dim % distriUnit) == 0) {
                           // Evenly divisible case, divide the dimension by the layout factor.
                           distributedShape[i] = dim / layout[i];
                           assert((distributedShape[i] % subShape[i] == 0) &&
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0fbd339146d7b..615a3b7828b77 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -409,15 +409,7 @@ LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
 
   // Delinearize the linear ID using the order attribute.
   DenseI32ArrayAttr orderAttr = getOrder();
-  SmallVector<int64_t> order;
-  if (orderAttr && !orderAttr.empty()) {
-    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
-      return static_cast<int64_t>(idx);
-    });
-  } else {
-    order =
-        llvm::to_vector(llvm::reverse(llvm::seq<int64_t>(0, layoutVec.size())));
-  }
+  SmallVector<int64_t> order = getEffectiveOrderAsInt();
   SmallVector<int64_t> delinearizedId(layoutVec.size());
   int64_t remaining = linearId;
   for (size_t i = 0; i < order.size(); ++i) {
@@ -964,15 +956,7 @@ SliceAttr::computeStaticDistributedCoords(int64_t linearId,
     parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
 
   DenseI32ArrayAttr orderAttr = parent.getOrder();
-  SmallVector<int64_t> order;
-  if (orderAttr && !orderAttr.empty()) {
-    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
-      return static_cast<int64_t>(idx);
-    });
-  } else {
-    order = llvm::to_vector(
-        llvm::reverse(llvm::seq<int64_t>(0, parentLayoutVec.size())));
-  }
+  SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
   SmallVector<int64_t> allIds(parentLayoutVec.size());
   int64_t remaining = linearId;
   for (size_t i = 0; i < order.size(); ++i) {

>From 5cf6741fcb9bcf72dd3bf67083aa260b57af553c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 19 Mar 2026 22:08:18 +0000
Subject: [PATCH 7/7] refactor duplicated coordinate logic into helpers

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 110 ++++++++----------
 .../Transforms/XeGPUSubgroupDistribute.cpp    |   3 +-
 2 files changed, 49 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 615a3b7828b77..2ad63626c4b67 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -96,6 +96,39 @@ genCoordinates(OpBuilder &builder, Location loc,
   return coordinates;
 }
 
+// Computes the static start coordinates of every `subShape` tile owned by one
+// ID when a tensor of `shape` is distributed according to `layout`.
+// `canonicalIds[i]` is the ID's position along dimension i of `layout`. The
+// tensor is covered by distribution units of size layout[i] * subShape[i]
+// (clamped to shape[i]); within each unit the ID owns the tile at offset
+// canonicalIds[i] * subShape[i]. Offsets are reduced modulo shape[i], so IDs
+// whose tile would fall past the end of a dimension wrap around to its start.
+// Returns one coordinate vector per distribution unit.
+static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
+    llvm::ArrayRef<int64_t> canonicalIds, llvm::ArrayRef<int64_t> layout,
+    llvm::ArrayRef<int64_t> subShape, llvm::ArrayRef<int64_t> shape) {
+  // Compute distribution unit shape (clamped to `shape`).
+  SmallVector<int64_t> distUnitShape(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
+
+  // Compute local offset of this ID within a distribution unit.
+  SmallVector<int64_t> localOffset(shape.size());
+  for (size_t i = 0; i < shape.size(); ++i)
+    localOffset[i] = canonicalIds[i] * subShape[i];
+
+  // Enumerate all distribution units and compute coordinates.
+  SmallVector<SmallVector<int64_t>> coordinates;
+  for (SmallVector<int64_t> unitOffs :
+       StaticTileOffsetRange(shape, distUnitShape)) {
+    SmallVector<int64_t> coord(shape.size());
+    for (size_t i = 0; i < shape.size(); ++i)
+      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
+    coordinates.push_back(std::move(coord));
+  }
+  return coordinates;
+}
+
 // Checks if the given shape can be evenly distributed based on the layout
 // and data factors provided by the LayoutAttr.
 bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
@@ -408,7 +433,6 @@ LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
   assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
 
   // Delinearize the linear ID using the order attribute.
-  DenseI32ArrayAttr orderAttr = getOrder();
   SmallVector<int64_t> order = getEffectiveOrderAsInt();
   SmallVector<int64_t> delinearizedId(layoutVec.size());
   int64_t remaining = linearId;
@@ -418,26 +442,7 @@ LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
     remaining = remaining / layoutVec[dimIdx];
   }
 
-  // Compute distribution unit shape (clamped to srcShape).
-  SmallVector<int64_t> distUnitShape(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    distUnitShape[i] = std::min(shape[i], layoutVec[i] * subShape[i]);
-
-  // Compute local offset of this ID within a distribution unit.
-  SmallVector<int64_t> localOffset(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    localOffset[i] = delinearizedId[i] * subShape[i];
-
-  // Enumerate all distribution units and compute coordinates.
-  SmallVector<SmallVector<int64_t>> coordinates;
-  for (SmallVector<int64_t> unitOffs :
-       StaticTileOffsetRange(shape, distUnitShape)) {
-    SmallVector<int64_t> coord(shape.size());
-    for (size_t i = 0; i < shape.size(); ++i)
-      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
-    coordinates.push_back(coord);
-  }
-  return coordinates;
+  return genStaticCoordinates(delinearizedId, layoutVec, subShape, shape);
 }
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
@@ -802,26 +807,27 @@ bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                   other.getEffectiveLaneLayoutAsInt() &&
               getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt());
   }
-  if (level == xegpu::LayoutKind::Subgroup) {
-    int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
-    for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+  // True iff both layouts map every id in [0, size) to identical coords.
+  auto compareCoordsForAllIds = [&](int64_t size) {
+    for (int64_t id : llvm::seq<int64_t>(0, size)) {
       auto coords = computeStaticDistributedCoords(id, shape);
       auto otherCoords = other.computeStaticDistributedCoords(id, shape);
       if (coords != otherCoords)
         return false;
     }
+    return true;
+  };
+
+  if (level == xegpu::LayoutKind::Subgroup) {
+    int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
+    return compareCoordsForAllIds(wgSize);
   }
   if (level == xegpu::LayoutKind::InstData) {
     return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
   }
   if (level == xegpu::LayoutKind::Lane) {
     int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
-    for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
-      auto coords = computeStaticDistributedCoords(id, shape);
-      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
-      if (coords != otherCoords)
-        return false;
-    }
+    return compareCoordsForAllIds(subgroupSize);
   }
   return true;
 }
@@ -955,7 +961,6 @@ SliceAttr::computeStaticDistributedCoords(int64_t linearId,
   else
     parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
 
-  DenseI32ArrayAttr orderAttr = parent.getOrder();
   SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
   SmallVector<int64_t> allIds(parentLayoutVec.size());
   int64_t remaining = linearId;
@@ -972,26 +977,7 @@ SliceAttr::computeStaticDistributedCoords(int64_t linearId,
   SmallVector<int64_t> canonicalIds =
       XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
 
-  // Compute distribution unit shape (clamped to srcShape).
-  SmallVector<int64_t> distUnitShape(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
-
-  // Compute local offset of this ID within a distribution unit.
-  SmallVector<int64_t> localOffset(shape.size());
-  for (size_t i = 0; i < shape.size(); ++i)
-    localOffset[i] = canonicalIds[i] * subShape[i];
-
-  // Enumerate all distribution units and compute coordinates.
-  SmallVector<SmallVector<int64_t>> coordinates;
-  for (SmallVector<int64_t> unitOffs :
-       StaticTileOffsetRange(shape, distUnitShape)) {
-    SmallVector<int64_t> coord(shape.size());
-    for (size_t i = 0; i < shape.size(); ++i)
-      coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
-    coordinates.push_back(coord);
-  }
-  return coordinates;
+  return genStaticCoordinates(canonicalIds, layout, subShape, shape);
 }
 
 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
@@ -1043,28 +1029,31 @@ bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
         return true;
   }
 
-  auto flattenedThis = flatten();
-  auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
-  if (level == xegpu::LayoutKind::Subgroup) {
-    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
-    for (int64_t id : llvm::seq<int64_t>(0, wgSize)) {
+  // True iff both layouts map every id in [0, size) to identical coords.
+  auto compareCoordsForAllIds = [&](int64_t size) {
+    for (int64_t id : llvm::seq<int64_t>(0, size)) {
       auto coords = computeStaticDistributedCoords(id, shape);
       auto otherCoords = other.computeStaticDistributedCoords(id, shape);
       if (coords != otherCoords)
         return false;
     }
+    return true;
+  };
+
+  // The id ranges come from the flattened slice's parent, which must be a
+  // LayoutAttr; cast<> asserts this instead of yielding a null attribute.
+  auto flattenedThis = flatten();
+  auto parent = cast<LayoutAttr>(flattenedThis.getParent());
+  if (level == xegpu::LayoutKind::Subgroup) {
+    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
+    return compareCoordsForAllIds(wgSize);
   }
   if (level == xegpu::LayoutKind::InstData) {
     return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
   }
   if (level == xegpu::LayoutKind::Lane) {
     int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
-    for (int64_t id : llvm::seq<int64_t>(0, subgroupSize)) {
-      auto coords = computeStaticDistributedCoords(id, shape);
-      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
-      if (coords != otherCoords)
-        return false;
-    }
+    return compareCoordsForAllIds(subgroupSize);
   }
   return true;
 }
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 4ccdf000849b1..6d6725617d306 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1097,10 +1097,9 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
       return isa<xegpu::LoadGatherOp>(op) &&
              warpOp.getTerminator()->getPrevNode() == op;
     });
-    if (!producedByLastLoad) {
+    if (!producedByLastLoad)
       return rewriter.notifyMatchFailure(
           warpOp, "The last op is not xegpu::LoadGatherOp");
-    }
 
     auto loadGatherOp =
         producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();



More information about the Mlir-commits mailing list