[Mlir-commits] [mlir] [MLIR][XeGPU] Enhance Layout Propagation for broadcasting both leading dimensions and inner unit dimensions (PR #185583)

Jianhui Li llvmlistbot at llvm.org
Tue Mar 10 17:47:44 PDT 2026


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/185583

>From ae1c313b5fe22cabe08de33e9e0775103e4e7185 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 6 Mar 2026 14:00:56 +0000
Subject: [PATCH 1/4] enhance broadcast layout inference

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      | 20 +++++++++++--------
 .../XeGPU/resolve-layout-conflicts.mlir       |  2 +-
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7aa186bb22224..a278b87070bf2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -143,20 +143,24 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
                                   ArrayRef<int64_t> srcShape) {
 
   SmallVector<int64_t> bcastDims;
-  auto returnLayout = resLayout;
 
   // Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
-  int dimDiff = resShape.size() - srcShape.size();
+  size_t dimDiff = resShape.size() - srcShape.size();
+  for (size_t i = 0; i < dimDiff; i++)
+    bcastDims.push_back(i);
 
-  if (dimDiff > 0) {
-    // Adding the missing leading dims
-    for (int i = 0; i < dimDiff; i++)
+  for (size_t i = 0; i < resShape.size(); i++)
+    if ((i < dimDiff) || ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1)))
       bcastDims.push_back(i);
 
-    // Create a slice layout for the source
+  auto returnLayout = resLayout.setUnitDimData(bcastDims);
+  if (dimDiff > 0) {
+    SmallVector<int64_t> sliceDims;
+    for (size_t i = 0; i < dimDiff; i++)
+      sliceDims.push_back(i);
     returnLayout = xegpu::SliceAttr::get(
-        resLayout.getContext(), resLayout,
-        DenseI64ArrayAttr::get(resLayout.getContext(), bcastDims));
+        resLayout.getContext(), returnLayout,
+        DenseI64ArrayAttr::get(resLayout.getContext(), sliceDims));
   }
   return returnLayout;
 }
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index c73a9990542de..033ad188e5347 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -115,7 +115,7 @@ func.func @elementwise_conflict_uniform() -> vector<2xf16> {
 // CHECK-LABEL: func.func @broadcast_source_conflict
 // CHECK:         %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [16]>} : () -> vector<16xf16>
 // CHECK:         %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
-// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}>
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16]>, dims = [0]>}>
 // CHECK-SAME:      : vector<16xf16>
 // CHECK:         %[[BC:.*]] = vector.broadcast %[[CVT]]
 // CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf16> to vector<16x16xf16>

>From 2c52353c52768dc1128f195f812f376d8ec35c71 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 7 Mar 2026 04:53:31 +0000
Subject: [PATCH 2/4] improve broadcast

---
 .../Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp   | 14 +++++++++-----
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp      | 12 +++++++-----
 .../Dialect/XeGPU/resolve-layout-conflicts.mlir    |  2 +-
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index a278b87070bf2..c244a27b988b6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -146,14 +146,18 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
 
   // Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
   size_t dimDiff = resShape.size() - srcShape.size();
-  for (size_t i = 0; i < dimDiff; i++)
-    bcastDims.push_back(i);
+  // for (size_t i = 0; i < dimDiff; i++)
+  //   bcastDims.push_back(i);
 
-  for (size_t i = 0; i < resShape.size(); i++)
-    if ((i < dimDiff) || ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1)))
+  auto returnLayout = resLayout;
+  for (size_t i = dimDiff; i < resShape.size(); i++) {
+    if ((srcShape[i - dimDiff] == 1) && (resShape[i] != 1))
       bcastDims.push_back(i);
+  }
+
+  if (!bcastDims.empty())
+    returnLayout = returnLayout.setUnitDimData(bcastDims);
 
-  auto returnLayout = resLayout.setUnitDimData(bcastDims);
   if (dimDiff > 0) {
     SmallVector<int64_t> sliceDims;
     for (size_t i = 0; i < dimDiff; i++)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 7f7e8d6ad7734..5abc40b3aea48 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -676,11 +676,13 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   auto srcShape = sourceTy.getShape();
   auto resShape = resultTy.getShape();
 
-  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
-  for (size_t i = 0; i < srcShape.size(); i++)
-    if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
-      broadcast.emitWarning("broadcast must either from low-rank or same-rank "
-                            "with unit-dim, mixed scenario is not supported!");
+  // size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+  // for (size_t i = 0; i < srcShape.size(); i++)
+  //   if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
+  //     broadcast.emitWarning("broadcast must either from low-rank or same-rank
+  //     "
+  //                           "with unit-dim, mixed scenario is not
+  //                           supported!");
 
   auto resultLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
diff --git a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
index 033ad188e5347..c73a9990542de 100644
--- a/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
+++ b/mlir/test/Dialect/XeGPU/resolve-layout-conflicts.mlir
@@ -115,7 +115,7 @@ func.func @elementwise_conflict_uniform() -> vector<2xf16> {
 // CHECK-LABEL: func.func @broadcast_source_conflict
 // CHECK:         %[[V0:.*]] = "some_op"() {layout_result_0 = #xegpu.layout<inst_data = [16]>} : () -> vector<16xf16>
 // CHECK:         %[[CVT:.*]] = xegpu.convert_layout %[[V0]]
-// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [1, 16]>, dims = [0]>}>
+// CHECK-SAME:      <{input_layout = #xegpu.layout<inst_data = [16]>, target_layout = #xegpu.slice<#xegpu.layout<inst_data = [16, 16]>, dims = [0]>}>
 // CHECK-SAME:      : vector<16xf16>
 // CHECK:         %[[BC:.*]] = vector.broadcast %[[CVT]]
 // CHECK-SAME:      {layout_result_0 = #xegpu.layout<inst_data = [16, 16]>} : vector<16xf16> to vector<16x16xf16>

>From a00bd8a30e4f80824123f09ccbabda261ebfe35a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 9 Mar 2026 23:21:16 +0000
Subject: [PATCH 3/4] polish the implementation: change warning to assert

---
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp        |  2 ++
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp   | 17 ++++++++++-------
 .../Transforms/XeGPUSubgroupDistribute.cpp      |  9 +++------
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp  | 11 ++++++-----
 .../XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir      |  4 ++--
 .../Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir | 14 +++++++-------
 6 files changed, 30 insertions(+), 27 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 2f232bcba3496..b94c4b37fce0a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -155,6 +155,8 @@ xegpu::inferBroadcastSourceLayout(xegpu::DistributeLayoutAttr resLayout,
       bcastDims.push_back(i);
   }
 
+  // Preserve sg_layout and lane_layout for unit dims so they can be propagated
+  // to the producer op and potentially reused by the multi-reduction op.
   if (!bcastDims.empty())
     returnLayout = returnLayout.setUnitDimData(bcastDims);
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 8141e9b87ac95..28da5226f1413 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -676,13 +676,16 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   auto srcShape = sourceTy.getShape();
   auto resShape = resultTy.getShape();
 
-  // size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
-  // for (size_t i = 0; i < srcShape.size(); i++)
-  //   if ((srcShape[i] == 1) && (resShape[i + dimDiff] != 1))
-  //     broadcast.emitWarning("broadcast must either from low-rank or same-rank
-  //     "
-  //                           "with unit-dim, mixed scenario is not
-  //                           supported!");
+  size_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+  if (dimDiff == 0) {
+    bool hasUnitDim =
+        llvm::any_of(srcShape, [](int64_t dim) { return dim == 1; });
+    Operation *srcOp = broadcast.getSource().getDefiningOp();
+    bool produceByShapeCast = srcOp || isa<vector::ShapeCastOp>(srcOp);
+    assert(
+        hasUnitDim && produceByShapeCast &&
+        "When broadcasting from unit-dim, the producer op must be shape_cast!");
+  }
 
   auto resultLayoutAttr =
       dyn_cast<xegpu::DistributeLayoutAttr>(resLayoutInfo.get());
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 38bc95d39c2c6..d7db33a0913b8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1520,12 +1520,9 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
         auto broadcastUnitDimsSet = broadcastOp.computeBroadcastedUnitDims();
         SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                                broadcastUnitDimsSet.end());
-        bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
-        if (!isEqualTo)
-          return rewriter.notifyMatchFailure(
-              warpOp, "For same-rank broadcast, source must be identical to "
-                      "adjusted result layouts with unit dims.");
-        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+        assert(sourceLayout.isEqualTo(
+                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
+               "The sg_data for unit dimensions should be set as 1");
         sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
       }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 139a30e76854f..5a380685d9e7b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1173,11 +1173,12 @@ struct WgToSgVectorShapeCastOp
 
       if (!sourceLayout.isSliceOf(layout))
         return rewriter.notifyMatchFailure(
-            op, "The ShapeCast op only expands dimensions, the result layout "
-                "must be a slice of the input layout, or vice versa.");
-      layoutToDistribute = layoutToDistribute.setUnitDimData(expandedUnitDims);
-      layoutToDistribute =
-          layoutToDistribute.setUnitDimLayout(expandedUnitDims);
+            op, "The ShapeCast op only expands dimensions, the input layout "
+                "must be a slice of the result layout.");
+
+      assert(layoutToDistribute.isEqualTo(
+                 layoutToDistribute.setUnitDimData(expandedUnitDims)) &&
+             "The sg_data for unit dimensions should be set as 1");
     }
 
     SmallVector<int64_t> sgShape =
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
index 28865c2ec50bb..068dd6d865ead 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -157,8 +157,8 @@ gpu.module @test_distribution {
     %block_id_x = gpu.block_id x
     %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
     %1 = xegpu.load_nd %0[%block_id_x, 0] {layout = #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>}  : !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>> -> vector<256x128xf32>
-    %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
-    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
+    %2 = vector.multi_reduction <maximumf>, %1, %cst_0 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1], inst_data = [8, 1]>, dims = [1]>} [1] : vector<256x128xf32> to vector<256xf32>
+    %3 = vector.shape_cast %2 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1], inst_data = [8, 1]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1], inst_data = [8, 1]>, dims = [1]>} : vector<256xf32> to vector<256x1xf32>
     %4 = vector.broadcast %3 {layout_result_0 =  #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>} : vector<256x1xf32>to vector<256x128xf32>
     %9 = xegpu.create_nd_tdesc %arg0 : memref<4096x128xf32> -> !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
     xegpu.store_nd %4, %9[%block_id_x, 0] : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.block_tdesc_attr<boundary_check = false>, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 128], inst_data = [8, 16]>>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 9a7ea4d1c2f43..9e0ae881c8a7e 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -948,22 +948,22 @@ gpu.module @test_distribution {
       : !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32]>>
       -> vector<256x256xf32>
 
-    %load2 = xegpu.convert_layout %load <{input_layout = #xegpu.layout<sg_layout = [8, 8],  sg_data = [32, 32]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>}> : vector<256x256xf32>
+    %load2 = xegpu.convert_layout %load <{input_layout = #xegpu.layout<sg_layout = [8, 8],  sg_data = [32, 32]>, target_layout = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>, dims=[1, 3]>}> : vector<256x256xf32>
 
-    %scast = vector.shape_cast %load2 {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, dims=[1, 3]>} : vector<256x256xf32> to vector<256x1x256x1xf32>
+    %scast = vector.shape_cast %load2 {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>, dims=[1, 3]>} : vector<256x256xf32> to vector<256x1x256x1xf32>
 
-    %bcast = vector.broadcast %scast {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x1x256x1xf32> to vector<256x16x256x16xf32>
+    %bcast = vector.broadcast %scast {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>, dims=[4]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 1, 16, 32, 1, 1]>, dims=[2]>, dims=[4]>} : vector<256x1x256x1xf32> to vector<256x16x256x16xf32>
 
-    %scast1 = vector.shape_cast %bcast {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, dims=[4]>} : vector<256x16x256x16xf32> to vector<256x16x256x16x1xf32>
+    %scast1 = vector.shape_cast %bcast {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>, dims=[4]>} : vector<256x16x256x16xf32> to vector<256x16x256x16x1xf32>
 
-    %bcast1 = vector.broadcast %scast1 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>}  : vector<256x16x256x16x1xf32> to vector<256x16x256x16x16xf32>
+    %bcast1 = vector.broadcast %scast1 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 1]>, dims=[2]>}  : vector<256x16x256x16x1xf32> to vector<256x16x256x16x16xf32>
 
     %scast2 = vector.shape_cast %bcast1 {layout_result_0 =
-        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x16xf32> to vector<256x16x1x256x16x16xf32>
+        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 1, 32, 16, 16]>, layout_operand_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 1, 32, 16, 16]>, dims=[2]>} : vector<256x16x256x16x16xf32> to vector<256x16x1x256x16x16xf32>
 
     %bcast2 = vector.broadcast %scast2 {layout_result_0 =
         #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>, layout_operand_0 =
-        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 16, 32, 16, 16]>} : vector<256x16x1x256x16x16xf32> to vector<256x16x16x256x16x16xf32>
+        #xegpu.layout<sg_layout = [8, 1, 1, 8, 1, 1], sg_data = [32, 16, 1, 32, 16, 16]>} : vector<256x16x1x256x16x16xf32> to vector<256x16x16x256x16x16xf32>
     gpu.return
   }
 

>From 55daece69d3519cb095d8db2605a38ef3c61e389 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 10 Mar 2026 06:07:59 +0000
Subject: [PATCH 4/4] add tests

---
 .../XeGPU/propagate-layout-inst-data.mlir     | 21 +++++++++++++++
 .../XeGPU/propagate-layout-subgroup.mlir      | 21 +++++++++++++++
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 27 +++++++++++++++++++
 3 files changed, 69 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
index c510a1d5f0fdf..503fb25deb151 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-inst-data.mlir
@@ -334,3 +334,24 @@ gpu.module @test{
     return
   }
 }
+
+// -----
+gpu.module @test{
+  // CHECK-LABEL: broadcast_both_leadingdims_innerdims
+  // CHECK: arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 1, 16]>} dense<true> : vector<2x2x6x32xi1>
+  // CHECK: arith.constant {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 1, 16]>} dense<1.000000e+00> : vector<2x2x6x32xf32>
+  // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<inst_data = [1, 1, 1, 1]>, dims = [0, 1]>, dims = [1]>} : vector<6xindex>
+  // CHECK: vector.shape_cast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<inst_data = [1, 1, 1, 1]>, dims = [0, 1]>} : vector<6xindex> to vector<6x1xindex>
+  // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [1, 1, 1, 16]>} : vector<6x1xindex> to vector<2x2x6x32xindex>
+  gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
+    %cst = arith.constant dense<true> : vector<2x2x6x32xi1>
+    %cst_0 = arith.constant dense<1.000000e+00> : vector<2x2x6x32xf32>
+    %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
+    %0 = arith.index_cast %intptr : index to i64
+    %1 = vector.step : vector<6xindex>
+    %2 = vector.shape_cast %1 : vector<6xindex> to vector<6x1xindex>
+    %3 = vector.broadcast %2 : vector<6x1xindex> to vector<2x2x6x32xindex>
+    xegpu.store %cst_0, %0[%3], %cst <{layout = #xegpu.layout<inst_data = [1, 1, 1, 16]>}> : vector<2x2x6x32xf32>, i64, vector<2x2x6x32xindex>, vector<2x2x6x32xi1>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 39fd815b1b380..9ee3de4490727 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -165,6 +165,27 @@ gpu.module @test {
   }
 }
 
+// -----
+gpu.module @test {
+  // CHECK-LABEL: broadcast_both_leadingdims_innerdims
+  gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
+    // CHECK: arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<true> : vector<2x2x6x32xi1>
+    %cst = arith.constant dense<true> : vector<2x2x6x32xi1>
+    // CHECK: arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} dense<1.000000e+00> : vector<2x2x6x32xf32>
+    %cst_0 = arith.constant dense<1.000000e+00> : vector<2x2x6x32xf32>
+    %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
+    %0 = arith.index_cast %intptr : index to i64
+    // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [0, 1]>, dims = [1]>} : vector<6xindex>
+    %1 = vector.step : vector<6xindex>
+    // CHECK: vector.shape_cast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [0, 1]>} : vector<6xindex> to vector<6x1xindex>
+    %2 = vector.shape_cast %1 : vector<6xindex> to vector<6x1xindex>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : vector<6x1xindex> to vector<2x2x6x32xindex>
+    %3 = vector.broadcast %2 : vector<6x1xindex> to vector<2x2x6x32xindex>
+    xegpu.store %cst_0, %0[%3], %cst <{layout = #xegpu.layout<sg_layout = [2, 2, 6, 1], sg_data = [1, 1, 1, 32]>}> : vector<2x2x6x32xf32>, i64, vector<2x2x6x32xindex>, vector<2x2x6x32xi1>
+    gpu.return
+  }
+}
+
 // -----
 gpu.module @test {
   // CHECK-LABEL: for_loop_dpas
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index 43c01445876cb..ddd2d22108d1f 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -638,6 +638,33 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
 }
 }
 // -----
+gpu.module @test{
+  // CHECK-LABEL: broadcast_both_leadingdims_innerdims
+  // CHECK: arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>} dense<1.000000e+00> : vector<1x1x1x16xf32>
+  // CHECK: arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>} dense<true> : vector<1x1x1x16xi1>
+  // CHECK: vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<1xindex>
+  // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : index to vector<1xindex>
+  // CHECK: arith.addi {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<1xindex>
+  // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>} : vector<1xindex> to vector<1x1x1x16xindex>
+  gpu.func @broadcast_both_leadingdims_innerdims(%arg0: memref<32x2x192xf32>, %arg1: memref<32x2x192xf32>, %arg2: memref<32x2x192xf32>) kernel attributes {known_block_size = array<i32: 768, 1, 1>, known_grid_size = array<i32: 16, 1, 1>} {
+    %cst = arith.constant dense<1.000000e+00> : vector<1x1x1x16xf32>
+    %cst_0 = arith.constant dense<true> : vector<1x1x1x16xi1>
+    %c6 = arith.constant 6 : index
+    %intptr = memref.extract_aligned_pointer_as_index %arg2 : memref<32x2x192xf32> -> index
+    %0 = arith.index_cast %intptr : index to i64
+    %1 = gpu.subgroup_id : index
+    %2 = arith.remui %1, %c6 : index
+    %3 = arith.remui %2, %c6 : index
+    %4 = vector.step : vector<1xindex>
+    %5 = vector.broadcast %3 : index to vector<1xindex>
+    %6 = arith.addi %4, %5 : vector<1xindex>
+    %7 = vector.broadcast %6 : vector<1xindex> to vector<1x1x1x16xindex>
+    xegpu.store %cst, %0[%7], %cst_0 <{chunk_size = 1 : i64}> : vector<1x1x1x16xf32>, i64, vector<1x1x1x16xindex>, vector<1x1x1x16xi1>
+    xegpu.store %cst, %0[%7], %cst_0 <{chunk_size = 1 : i64}> : vector<1x1x1x16xf32>, i64, vector<1x1x1x16xindex>, vector<1x1x1x16xi1>
+    gpu.return
+  }
+}
+// -----
 gpu.module @test {
 // CHECK-LABEL: func.func @vector_shape_cast_expand_non_unit_dims(
 // CHECK: %[[LOAD:.*]] = xegpu.load %arg0[%[[STEP:.*]]], %[[CST:.*]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<1024xf16>, vector<1024xindex>, vector<1024xi1> -> vector<1024xf16>



More information about the Mlir-commits mailing list