[Mlir-commits] [mlir] [MLIR][XeGPU] setUnitDim bug fix and add documentation (PR #173521)

Jianhui Li llvmlistbot at llvm.org
Mon Jan 19 22:29:12 PST 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/173521

>From daf352dd17e61c34ff69c5fa220c93cd6458f3e7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 22 Dec 2025 03:22:02 +0000
Subject: [PATCH 1/3] fix-setUnitDim

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 16 ++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 59 ++++++++++---------
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 17 +++---
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  3 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir    |  2 +-
 .../XeGPU/xegpu-wg-to-sg-unify-ops.mlir       | 38 +++++++++++-
 6 files changed, 87 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 446f64fffa468..d4d5c7d58b37e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -454,7 +454,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return 0;
     }
 
-    LayoutAttr dropSgLayoutAndData() {
+    LayoutAttr dropSgLayoutAndData() const{
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getInstData() && !getLaneLayout())
         return nullptr;
@@ -462,7 +462,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
                              getLaneLayout(), getLaneData(), getOrder());
     }
 
-    LayoutAttr dropInstData() {
+    LayoutAttr dropInstData() const{
       // avoid every field of the attribute is nullptr, which may lead to segment fault
       if (!getSgLayout() && !getLaneLayout())
         return nullptr;
@@ -501,10 +501,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     }
 
     //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
-    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
 
     //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
 
     /// Delinearizes a linear ID into its multidimensional indices
     /// based on the effective level of the layout.
@@ -653,7 +653,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return {};
     }
 
-    SliceAttr dropSgLayoutAndData() {
+    SliceAttr dropSgLayoutAndData() const{
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
       parent = parent.dropSgLayoutAndData();
@@ -662,7 +662,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return SliceAttr::get(getContext(), parent, attr.getDims());
     }
 
-    SliceAttr dropInstData() {
+    SliceAttr dropInstData() const{
       SliceAttr attr = flatten();
       auto parent = dyn_cast<LayoutAttr>(attr.getParent());
       parent = parent.dropInstData();
@@ -672,10 +672,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     }
 
     //set the layout for the sepcified unit dims: sg_data, inst_data and lane_data to 1
-    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims) const;
 
     //set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims) const;
 
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index ccf17da26c942..13bf475920075 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -400,7 +400,7 @@ bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 }
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const{
   auto sgDataOpt = getSgData();
   auto instDataOpt = getInstData();
   auto laneDataOpt = getLaneData();
@@ -441,7 +441,7 @@ DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
 }
 
 // set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const{
   auto sgLayoutOpt = getSgLayout();
   auto laneLayoutOpt = getLaneLayout();
 
@@ -587,7 +587,7 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
   // dims.
   llvm::SmallDenseSet<int64_t> thisDims(
       flattenedThis.getDims().asArrayRef().begin(),
-      flattenedThis.getDims().asArrayRef().end());
+      flattenedThis.getDims().asArrayRef().end()); 
   return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
@@ -604,55 +604,60 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 }
 
 // Helper function to adjust unit dimensions from sliced space to parent space
+// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
+// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps to
+// dim [0] in parent space; if we want to set unit dim [1] in sliced space, it maps to
+// dim [2] in parent space.
 static SetVector<int64_t>
 adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
                             ArrayRef<int64_t> sliceDims) {
-  // Reconstruct parent's non-sliced dimensions
+  // get max number from sliceDims and unitDims to determine parent space rank
+  int64_t maxDim = -1;
+  maxDim = std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
+  maxDim = std::max(maxDim, *std::max_element(unitDims.begin(), unitDims.end()));
+  int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
 
-  int64_t parentRank = sliceDims.size() + unitDims.size();
+  // get remaining dims in parent space after applying slicing with parent's slice Dims
   llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
                                              sliceDims.end());
-  SmallVector<int64_t> nonSlicedDims;
-  for (int64_t i = 0; i < parentRank; ++i) {
+  SmallVector<int64_t> remainingDims;
+  for (int64_t i = 0; i < parentSpaceRank; ++i) {
     if (!slicedDimsSet.contains(i))
-      nonSlicedDims.push_back(i);
+      remainingDims.push_back(i);
   }
 
   // Map unit dims from sliced space to parent space
   SetVector<int64_t> adjustUnitDims;
   for (auto dim : unitDims) {
-    if (dim < static_cast<int64_t>(nonSlicedDims.size())) {
-      adjustUnitDims.insert(nonSlicedDims[dim]);
-    }
+      int64_t mappedDim = remainingDims[dim];
+      adjustUnitDims.insert(mappedDim);
   }
 
   return adjustUnitDims;
 }
 
 // set the layout for unit dims: sg_data, inst_data and lane_data to 1
-DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
-  SliceAttr attr = flatten();
-  ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
-  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const{
+  DistributeLayoutAttr parentLayout = getParent();
+
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
-  SetVector<int64_t> adjustUnitDims =
-      adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+  SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
 
-  return SliceAttr::get(getContext(), parent.setUnitDimData(adjustUnitDims),
-                        attr.getDims());
+  return SliceAttr::get(getContext(), parentLayout.setUnitDimData(adjustUnitDims),
+                        getDims());
 }
 
 // set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
-DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
-  SliceAttr attr = flatten();
-  ArrayRef<int64_t> sliceDims = attr.getDims().asArrayRef();
-  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const{
+  DistributeLayoutAttr parentLayout = getParent();
+
+  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
-  SetVector<int64_t> adjustUnitDims =
-      adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+  SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
 
-  return SliceAttr::get(getContext(), parent.setUnitDimLayout(adjustUnitDims),
-                        attr.getDims());
+  return SliceAttr::get(getContext(), parentLayout.setUnitDimLayout(adjustUnitDims),
+                        getDims());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 07572a4950760..0dcabf0fccf6a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -614,20 +614,20 @@ struct WgToSgConvertLayoutOp
   LogicalResult
   matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // TODO: currently, we only support LayoutAttr
-    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
-    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());
+
+    auto input = op.getInputLayout();
+    auto target = op.getTargetLayout();
 
     if (!input || !target || !input.isForWorkgroup() ||
         !target.isForWorkgroup())
       return rewriter.notifyMatchFailure(
           op, "Input and target layouts must have subgroup layout");
 
-    DenseI32ArrayAttr inputSgLayout = input.getSgLayout();
-    DenseI32ArrayAttr inputSgData = input.getSgData();
+    SmallVector<int64_t> inputSgLayout = input.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> inputSgData = input.getEffectiveSgDataAsInt();
     DenseI32ArrayAttr inputOrder = input.getOrder();
-    DenseI32ArrayAttr targetSgLayout = target.getSgLayout();
-    DenseI32ArrayAttr targetSgData = target.getSgData();
+    SmallVector<int64_t> targetSgLayout = target.getEffectiveSgLayoutAsInt();
+    SmallVector<int64_t> targetSgData = target.getEffectiveSgDataAsInt();
     DenseI32ArrayAttr targetOrder = target.getOrder();
 
     // TODO: currently we only support for optimal case, where input and
@@ -1138,12 +1138,11 @@ struct WgToSgVectorShapeCastOp
           return false;
       return srcIdx == src.size();
     };
-
     xegpu::DistributeLayoutAttr layoutToDistribute = layout;
 
     if (checkOnlyExpandUnitDims(srcShape, wgShape)) {
       xegpu::DistributeLayoutAttr sourceLayout =
-          xegpu::getDistributeLayoutAttr(op.getSource());
+          xegpu::getTemporaryLayout(op->getOpOperand(0));
 
       auto usedByBroadcastOp = [](vector::ShapeCastOp op) {
         return llvm::all_of(op.getResult().getUsers(), [](Operation *user) {
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index d3906e37ffbf1..10fe06ddb756f 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -186,8 +186,7 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
     return layout;
   }
 
-  auto layout = getDistributeLayoutAttr(opr.get());
-  return layout;
+  return nullptr;
 }
 
 // Returns the permanent layout attribute for the given result if it's
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index ad346307437e4..5508e6e938f67 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -155,7 +155,7 @@ gpu.module @test_distribution {
     %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
     %1 = xegpu.load_nd %0[%block_id_x, 0]  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
     %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
-    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256xf32> to vector<256x1xf32>
+    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
     %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
     %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
     xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index da6ad976d3730..d7145ab91ac2f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -462,7 +462,7 @@ gpu.module @test_distribution {
     %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
     %muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
     //CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
-    %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
+    %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex> to vector<1x1x1x128xindex>
     gpu.return
   }
 
@@ -655,4 +655,40 @@ gpu.module @test_distribution {
       -> vector<256x128xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: distribute_nested_slice
+  // CHECK: %[[V0:.*]] = vector.shape_cast %{{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
+  // CHECK: %[[V1:.*]] = vector.broadcast %[[V0]] : vector<32x1x32x1xf32> to vector<32x16x32x16xf32>
+  // CHECK: %[[V2:.*]] = vector.shape_cast %[[V1]] : vector<32x16x32x16xf32> to vector<32x16x32x16x1xf32>
+  // CHECK: %[[V3:.*]] = vector.broadcast %[[V2]] : vector<32x16x32x16x1xf32> to vector<32x16x32x16x16xf32>
+  // CHECK: %[[V4:.*]] = vector.shape_cast %[[V3]] : vector<32x16x32x16x16xf32> to vector<32x16x1x32x16x16xf32>
+  // CHECK: %[[V5:.*]] = vector.broadcast %[[V4]] : vector<32x16x1x32x16x16xf32> to vector<32x16x16x32x16x16xf32>
+  gpu.func @distribute_nested_slice(%src: memref<256x256xf32>) {
+
+    %tdesc = xegpu.create_nd_tdesc %src : memref<256x256xf32>
+      -> !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>>
+
+    %load =  xegpu.load_nd %tdesc[0, 0] {layout = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>}
+      : !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>>
+      -> vector<256x256xf32>
+
+    %load2 = xegpu.convert_layout %load <{input_layout = #xegpu.layout<sg_layout = [8, 8],  sg_data = [32, 32]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>}> : vector<256x256xf32>
+
+    %scast = vector.shape_cast %load2 {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>} : vector<256x256xf32> to vector<256x1x256x1xf32>
+
+    %bcast = vector.broadcast %scast {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x1x256x1xf32> to vector<256x16x256x16xf32>
+
+    %scast1 = vector.shape_cast %bcast {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x16x256x16xf32> to vector<256x16x256x16x1xf32>
+
+    %bcast1 = vector.broadcast %scast1 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>}  : vector<256x16x256x16x1xf32> to vector<256x16x256x16x16xf32>
+
+    %scast2 = vector.shape_cast %bcast1 {layout_result_0 =
+        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x16xf32> to vector<256x16x1x256x16x16xf32>
+
+    %bcast2 = vector.broadcast %scast2 {layout_result_0 =
+        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 =
+        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>} : vector<256x16x1x256x16x16xf32> to vector<256x16x16x256x16x16xf32>
+    gpu.return
+  }
+
 }

>From 1fd70d9648c446be3a552b81ed5dacec20b78b7f Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 16 Jan 2026 00:53:32 +0000
Subject: [PATCH 2/3] address feedback

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 32 +++++++++++++---------
 1 file changed, 19 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 13bf475920075..abcfadf9b6fc6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -587,7 +587,7 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
   // dims.
   llvm::SmallDenseSet<int64_t> thisDims(
       flattenedThis.getDims().asArrayRef().begin(),
-      flattenedThis.getDims().asArrayRef().end()); 
+      flattenedThis.getDims().asArrayRef().end());
   return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
@@ -603,18 +603,22 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
           (flattenedThis.getDims() == flattenedOther.getDims()));
 }
 
-// Helper function to adjust unit dimensions from sliced space to parent space
+// Helper function to adjust dimensions from sliced space to parent space
 // say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
-// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps to
-// dim [0] in parent space; if we want to set unit dim [1] in sliced space, it maps to
-// dim [2] in parent space.
+// shape has rank 2. Dim 0 of the sliced space maps to dim 0 of the parent
+// space, and dim 1 of the sliced space maps to dim 2 of the parent space
+// (parent dims 1 and 3 are the sliced-out ones).
 static SetVector<int64_t>
-adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
-                            ArrayRef<int64_t> sliceDims) {
+mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
+                           ArrayRef<int64_t> sliceDims) {
   // get max number from sliceDims and unitDims to determine parent space rank
+  // the recovered parent space from sliceDims/unitDims is not necessary the
+  // actual parent rank. As long as the parent space rank covers both maximum
+  // number of sliceDims and unitDims, the algorithm works.
   int64_t maxDim = -1;
   maxDim = std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
-  maxDim = std::max(maxDim, *std::max_element(unitDims.begin(), unitDims.end()));
+  maxDim =
+      std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
   int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
 
   // get remaining dims in parent space after applying slicing with parent's slice Dims
@@ -628,9 +632,9 @@ adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
 
   // Map unit dims from sliced space to parent space
   SetVector<int64_t> adjustUnitDims;
-  for (auto dim : unitDims) {
-      int64_t mappedDim = remainingDims[dim];
-      adjustUnitDims.insert(mappedDim);
+  for (auto dim : dimsToMap) {
+    int64_t mappedDim = remainingDims[dim];
+    adjustUnitDims.insert(mappedDim);
   }
 
   return adjustUnitDims;
@@ -642,7 +646,8 @@ DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) cons
 
   ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
-  SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+  SetVector<int64_t> adjustUnitDims =
+      mapSlicedDimsToParentSpace(unitDims, sliceDims);
 
   return SliceAttr::get(getContext(), parentLayout.setUnitDimData(adjustUnitDims),
                         getDims());
@@ -654,7 +659,8 @@ DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) co
 
   ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
 
-  SetVector<int64_t> adjustUnitDims = adjustUnitDimsWithSliceDims(unitDims, sliceDims);
+  SetVector<int64_t> adjustUnitDims =
+      mapSlicedDimsToParentSpace(unitDims, sliceDims);
 
   return SliceAttr::get(getContext(), parentLayout.setUnitDimLayout(adjustUnitDims),
                         getDims());

>From 5b1e04f5f176ba86031481118b673997234e03bd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 20 Jan 2026 06:18:51 +0000
Subject: [PATCH 3/3] improve comments

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 80853f437a830..646c6dec862ae 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -624,10 +624,9 @@ bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
 static SetVector<int64_t>
 mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
                            ArrayRef<int64_t> sliceDims) {
-  // get max number from sliceDims and unitDims to determine parent space rank
-  // the recovered parent space from sliceDims/unitDims is not necessary the
-  // actual parent rank. As long as the parent space rank covers both maximum
-  // number of sliceDims and unitDims, the algorithm works.
+  // Use an upper bound on the parent rank: the largest entry of dimsToMap
+  // or sliceDims, plus sliceDims.size() + 1. That leaves enough non-sliced
+  // dims to index every entry of dimsToMap. Both inputs must be non-empty.
   int64_t maxDim = -1;
   maxDim = std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
   maxDim =



More information about the Mlir-commits mailing list