[Mlir-commits] [mlir] [MLIR][XeGPU] Extend propagation and sg_to_lane distribution pass support broadcast with low rank and scalar source input (PR #170409)

Jianhui Li llvmlistbot at llvm.org
Fri Dec 5 19:42:22 PST 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/170409

>From 662e38c0be162653c5f91f6b1a7d5a27e541195c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 3 Dec 2025 00:27:14 +0000
Subject: [PATCH 1/6] broadcast 1d/scalar to 2d, propagation and sg
 distribution

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  14 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  21 +++
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  26 +++-
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 128 +++++++++++++++++-
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  61 +++++++++
 5 files changed, 244 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 93c5187b00756..2103b169b5c00 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -283,9 +283,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       }
                       return true;
                     }]>,
-    InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
+    InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
+
+    InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isIdentical",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
   ];
 }
@@ -501,6 +506,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
+    
+    /// Check if this is identical to some other layout.
+    bool isIdentical(const xegpu::DistributeLayoutAttr &other); 
 
   }];
 
@@ -670,7 +678,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
-
+    
+    /// Check if this is identical to some other layout.
+    bool isIdentical(const xegpu::DistributeLayoutAttr &other); 
   }];
 
   let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index fb5d1e758dbd1..efcb7f2b5e4c2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -391,6 +391,13 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   return genCoordinates(builder, loc, ids, layout, subShape, shape);
 }
 
+bool LayoutAttr::isIdentical(const xegpu::DistributeLayoutAttr &other) {
+  if (dyn_cast<xegpu::SliceAttr>(other))
+    return false;
+
+  return *this == dyn_cast<xegpu::LayoutAttr>(other);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -511,6 +518,20 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
 
+bool SliceAttr::isIdentical(const xegpu::DistributeLayoutAttr &other) {
+  if (dyn_cast<xegpu::LayoutAttr>(other))
+    return false;
+
+  auto flattenedThis = flatten();
+  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
+
+  if ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+      (flattenedThis.getDims() == flattenedOther.getDims())) {
+    return true;
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f2b0e71c9397f..cfa88250f9f14 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -588,9 +588,33 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
 
   // Only consider nD -> nD broadcast.
   if (sourceTy.getRank() != resultTy.getRank()) {
-    broadcast.emitWarning("Expecting source and result to have same rank.");
+    //  broadcast.emitWarning("Expecting source and result to have same rank.");
+
+    auto sourceDims = sourceTy.getShape();
+    auto resultDims = resultTy.getShape();
+    // adding the missing leading missing dims
+    SmallVector<int64_t> bcastDims;
+    int64_t dimDiff = resultTy.getRank() - sourceTy.getRank();
+    for (int i = 0; i < dimDiff; i++) {
+      bcastDims.push_back(i);
+    }
+
+    // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
+    // broadcasted dim
+    for (size_t i = 0; i < sourceDims.size(); i++) {
+      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
+        bcastDims.push_back(i + dimDiff);
+    }
+
+    // create a slice layout for the source
+    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+        broadcast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
+        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+
+    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
   }
+
   SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
   if (broadcastUnitDims.size() != 1) {
     broadcast.emitWarning("Expecting source type to be nD vector only with "
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 0d1c5eeeff711..e06536b828385 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1424,6 +1424,128 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+/// This pattern distributes the `vector.broadcast` operation across lanes in a
+/// warp. The pattern supports three use cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+/// vector
+///    must have a slice layout of the result. If the distributed source and
+///    target vector types are identical, this lowers to a no-op; otherwise, it
+///    remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+/// target:
+///    The source vector must have unit dimensions, and lane_layout must be unit
+///    size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
+///    scalar to distributed result type.
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///   dims = [0]> } : () -> (vector<32xf32>) %1 = vector.broadcast %0
+///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///   1]>}: vector<32xf32> to vector<8x32xf32> gpu.yield %1 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///   dims = [0]> } : () -> (vector<32xf32>) gpu.yield %0 : vector<32xf32>
+/// }
+/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+///
+/// Example 2 (no-op):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///   dims = [1]> } : () -> (vector<8xf32>) %1 = vector.shape_cast %0
+///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///   1]>}: vector<8xf32> to vector<8x1xf32> %2 = vector.broadcast %1
+///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///   1]>}: vector<8x1xf32> to vector<8x32xf32> gpu.yield %1 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///   dims = [1]> } : () -> (vector<8xf32>) %1 = vector.shape_cast %0
+///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///   1]>}: vector<8xf32> to vector<8x1xf32> gpu.yield %0 : vector<8x1xf32>
+/// }
+/// // The broadcast is implicit through layout transformation (no-op)
+///  %2 = vector.broadcast %r#0 : vector<8x1xf32> to vector<8x1xf32>
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto broadcastOp =
+        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandIdx = yieldOperand->getOperandNumber();
+
+    // Get the input layout - must be a slice layout
+    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp.getSource());
+    if (sourceType) {
+      if (!sourceLayout || !isa<xegpu::SliceAttr>(sourceLayout))
+        return rewriter.notifyMatchFailure(
+            warpOp,
+            "Broadcast input must be scalar or have a slice layout attribute.");
+      // also the sourceLayout must be a slice of the broadcast result layout
+      xegpu::DistributeLayoutAttr resultLayout =
+          xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+      assert(resultLayout && "Broadcast result must have layout attribute.");
+      if (!sourceLayout.isSliceOf(resultLayout) ||
+          sourceLayout.isIdentical(resultLayout))
+        return rewriter.notifyMatchFailure(
+            warpOp, "Broadcast input layout must be a slice of result layout.");
+    }
+    // Get the distributed source type based on layout
+    FailureOr<VectorType> sourceDistTypeOrFailure =
+        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+    if (failed(sourceDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Failed to distribute the source vector type.");
+
+    // Yield the source from the warp op - broadcast is a no-op
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {broadcastOp.getSource()},
+        {sourceDistTypeOrFailure.value()}, newRetIndices);
+
+    // Replace the broadcast result with the distributed source
+    Value distributedVal = newWarpOp.getResult(newRetIndices[0]);
+    Value newBroadcast = distributedVal;
+    // if sourceDistType is same as original warp result type, no need to
+    //  re-create broadcast op
+    if (distributedVal.getType() != warpOp.getResult(operandIdx).getType()) {
+      // generate broadcast op outside warp op to have correct type
+      rewriter.setInsertionPointAfter(newWarpOp);
+      newBroadcast = vector::BroadcastOp::create(
+          rewriter, newWarpOp.getLoc(),
+          cast<VectorType>(warpOp.getResult(operandIdx).getType()),
+          distributedVal);
+    }
+
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
+    return success();
+  }
+};
+
 /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
 /// `gpu.warp_execute_on_lane_0` region.
 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
@@ -1855,9 +1977,9 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
                LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
                GpuBarrierDistribution, VectorMultiReductionDistribution,
-               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
-               VectorBitcastDistribution, LoadMatrixDistribution,
-               StoreMatrixDistribution,
+               VectorBroadcastDistribution, LoadDistribution, StoreDistribution,
+               VectorTransposeDistribution, VectorBitcastDistribution,
+               LoadMatrixDistribution, StoreMatrixDistribution,
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 8fd3cca5594cb..e3b362f62b4f2 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -330,3 +330,64 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) {
+gpu.module @xevm_module{
+   gpu.func  @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16>
+    %tdesc0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+    // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16>
+      %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+    xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) {
+gpu.module @xevm_module{
+   gpu.func  @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [16] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: vector<16xi1>
+    %1 = xegpu.load %arg0[%c0], %mask {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: memref<16xf16>, index, vector<16xi1> -> vector<16xf16>
+    
+    %11 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+    %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+    // CHECK-NOT: vector.broadcast
+    // CHECK-NOT: vector.shape_cast
+ 
+    %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: xegpu.store_nd {{.*}}, {{.*}}[{{.*}}, {{.*}}]
+    // CHECK-SAME: : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+
+    xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) {
+gpu.module @xevm_module{
+   gpu.func  @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %9 = gpu.block_id  x
+    %10 = arith.index_cast %9 : index to i16
+    %11 = arith.bitcast %10 : i16 to f16
+    // CHECK: vector.broadcast {{.*}} : f16 to vector<16xf16>
+    %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+    %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+

>From a5e068b1378ee4c18565f8d17777894bf10e99b8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 3 Dec 2025 02:30:11 +0000
Subject: [PATCH 2/6] add propagation tests

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  9 +--
 mlir/test/Dialect/XeGPU/propagate-layout.mlir | 58 +++++++++++++++++++
 2 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index cfa88250f9f14..a36b2cc55a0ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -581,15 +581,12 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   // Only consider vector to vector broadcasts for now.
   VectorType resultTy = broadcast.getResultVectorType();
   VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
-  if (!sourceTy) {
-    broadcast.emitWarning("Expecting source type to be a vector type.");
+  // skip layout propagation for non-vector source operand.
+  if (!sourceTy)
     return;
-  }
 
-  // Only consider nD -> nD broadcast.
+  // Handling broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
   if (sourceTy.getRank() != resultTy.getRank()) {
-    //  broadcast.emitWarning("Expecting source and result to have same rank.");
-
     auto sourceDims = sourceTy.getShape();
     auto resultDims = resultTy.getShape();
     // adding the missing leading missing dims
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index f8b59b87a122b..48e77d867508b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -640,3 +640,61 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
   return
 }
 }
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK:         %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:      !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME:       {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT:    %[[BROADCAST:.*]] = vector.broadcast %[[REDUCE]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+func.func @vector_broadcast_1d_to_2d_broadcast_along_row(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+  %5 = vector.broadcast %4 : vector<16xf16> to vector<16x16xf16>
+  xegpu.store_nd %5, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_2d_to_2d_along_column(
+// CHECK:            %[[REDUCE:.*]] = vector.multi_reduction <add>
+// CHECK-SAME:       {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT:    %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+// CHECK-NEXT:    vector.broadcast %[[SHAPECAST]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+
+func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+  %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+  %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16>
+  xegpu.store_nd %6, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_scalar_to_vector(
+// CHECK:         %[[CST:.*]] = arith.constant 0.{{.*}} : f16
+// CHECK-NEXT:    %[[BROADCAST:.*]] = vector.broadcast %[[CST]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+
+func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16>) {
+  %cst = arith.constant 0.0000 : f16
+  %6 = vector.broadcast %cst : f16 to vector<16x16xf16>
+  xegpu.store_nd %6, %arg0  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+}
\ No newline at end of file

>From 23df287f44f33aae7b6afcf8f66839959888ab00 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 3 Dec 2025 03:45:12 +0000
Subject: [PATCH 3/6] fix minor bug

---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp |  4 +--
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 25 +++++++++++--------
 2 files changed, 17 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index a36b2cc55a0ad..efc8fe9f371c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -940,7 +940,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   } else {
 
     // The layout is strictly determined by the payload type.
-    auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+    auto payloadTy = load.getValueType();
     if (!payloadTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
@@ -1010,7 +1010,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     // Currently, for 2D StoreScatterOp we expect that the height dimension of
     // the tensor descriptor is equal to the subgroup size. This is ensured by
     // the op verifier.
-    auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+    auto payloadTy = storeScatter.getValueType();
     if (!payloadTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index e06536b828385..86b572dd552ad 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1464,22 +1464,27 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
 /// ```
 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
 ///   %0 = "some_def"() {layout_result_0 =
-///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
-///   dims = [1]> } : () -> (vector<8xf32>) %1 = vector.shape_cast %0
-///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
-///   1]>}: vector<8xf32> to vector<8x1xf32> %2 = vector.broadcast %1
-///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
-///   1]>}: vector<8x1xf32> to vector<8x32xf32> gpu.yield %1 : vector<8x32xf32>
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///   %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///      1]>}: vector<8xf32> to vector<8x1xf32>
+///   %2 = vector.broadcast %1
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8x1xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
 /// }
 /// ```
 /// is lowered to:
 /// ```
 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
 ///   %0 = "some_def"() {layout_result_0 =
-///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
-///   dims = [1]> } : () -> (vector<8xf32>) %1 = vector.shape_cast %0
-///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
-///   1]>}: vector<8xf32> to vector<8x1xf32> gpu.yield %0 : vector<8x1xf32>
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///   %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8xf32> to vector<8x1xf32>
+///   gpu.yield %1 : vector<8x1xf32>
 /// }
 /// // The broadcast is implicit through layout transformation (no-op)
 ///  %2 = vector.broadcast %r#0 : vector<8x1xf32> to vector<8x1xf32>

>From 03a47382ea13811f3df0ed1c0896df13689a1131 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 4 Dec 2025 03:18:29 +0000
Subject: [PATCH 4/6] address feedback

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  6 +--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 11 ++---
 .../XeGPU/Transforms/XeGPUPropagateLayout.cpp | 17 ++++----
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 42 ++++++++++---------
 4 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 2103b169b5c00..3d28c1fe9bfab 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -290,7 +290,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
 
     InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
                     /*retTy=*/"bool",
-                    /*methodName=*/"isIdentical",
+                    /*methodName=*/"isEqualTo",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
   ];
 }
@@ -508,7 +508,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
     
     /// Check if this is identical to some other layout.
-    bool isIdentical(const xegpu::DistributeLayoutAttr &other); 
+    bool isEqualTo(const xegpu::DistributeLayoutAttr &other); 
 
   }];
 
@@ -680,7 +680,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
     
     /// Check if this is identical to some other layout.
-    bool isIdentical(const xegpu::DistributeLayoutAttr &other); 
+    bool isEqualTo(const xegpu::DistributeLayoutAttr &other); 
   }];
 
   let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 30c6238a19d87..af8d8ec7156c7 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -390,7 +390,7 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   return genCoordinates(builder, loc, ids, layout, subShape, shape);
 }
 
-bool LayoutAttr::isIdentical(const xegpu::DistributeLayoutAttr &other) {
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
   if (dyn_cast<xegpu::SliceAttr>(other))
     return false;
 
@@ -517,18 +517,15 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
 
-bool SliceAttr::isIdentical(const xegpu::DistributeLayoutAttr &other) {
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
   if (dyn_cast<xegpu::LayoutAttr>(other))
     return false;
 
   auto flattenedThis = flatten();
   auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
 
-  if ((flattenedThis.getParent() == flattenedOther.getParent()) &&
-      (flattenedThis.getDims() == flattenedOther.getDims())) {
-    return true;
-  }
-  return false;
+  return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
+          (flattenedThis.getDims() == flattenedOther.getDims()));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index f8f8bf4a49623..3cc76315dda84 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -588,23 +588,22 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   if (sourceTy.getRank() != resultTy.getRank()) {
     auto sourceDims = sourceTy.getShape();
     auto resultDims = resultTy.getShape();
-    // adding the missing leading missing dims
     SmallVector<int64_t> bcastDims;
-    int64_t dimDiff = resultTy.getRank() - sourceTy.getRank();
-    for (int i = 0; i < dimDiff; i++) {
+    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
+    // adding the missing leading dims
+    for (int i = 0; i < dimDiff; i++)
       bcastDims.push_back(i);
-    }
 
     // for the rest dims in the resultTy, if sourceTy dim is 1, then it's
     // broadcasted dim
-    for (size_t i = 0; i < sourceDims.size(); i++) {
+    for (size_t i = 0; i < sourceDims.size(); i++)
       if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
         bcastDims.push_back(i + dimDiff);
-    }
 
     // create a slice layout for the source
     xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-        broadcast->getContext(), cast<xegpu::LayoutAttr>(resultLayout.get()),
+        broadcast->getContext(),
+        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
         DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
 
     propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
@@ -938,7 +937,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   } else {
 
     // The layout is strictly determined by the payload type.
-    auto payloadTy = load.getValueType();
+    VectorType payloadTy = load.getValueType();
     if (!payloadTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
@@ -1008,7 +1007,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     // Currently, for 2D StoreScatterOp we expect that the height dimension of
     // the tensor descriptor is equal to the subgroup size. This is ensured by
     // the op verifier.
-    auto payloadTy = storeScatter.getValueType();
+    VectorType payloadTy = storeScatter.getValueType();
     if (!payloadTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 86b572dd552ad..bbf94be26398e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -1445,18 +1445,21 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
 /// ```
 /// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
 ///   %0 = "some_def"() {layout_result_0 =
-///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
-///   dims = [0]> } : () -> (vector<32xf32>) %2 = vector.broadcast %1
-///   {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
-///   1]>}: vector<32xf32> to vector<8x32xf32> gpu.yield %1 : vector<8x32xf32>
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///   %2 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+///     : vector<32xf32> to vector<8x32xf32>
+///     gpu.yield %1 : vector<8x32xf32>
 /// }
 /// ```
 /// is lowered to:
 /// ```
 /// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
 ///   %0 = "some_def"() {layout_result_0 =
-///   #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
-///   dims = [0]> } : () -> (vector<32xf32>) gpu.yield %0 : vector<32xf32>
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///   gpu.yield %0 : vector<32xf32>
 /// }
 /// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
 ///
@@ -1484,10 +1487,10 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
 ///   %1 = vector.shape_cast %0
 ///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
 ///     1]>}: vector<8xf32> to vector<8x1xf32>
-///   gpu.yield %0 : vector<8x1xf32>
+///   gpu.yield %1 : vector<8x1xf32>
 /// }
 /// // The broadcast is implicit through layout transformation (no-op)
-///  %2 = vector.broadcast %r#0 : vector<8x1xf32> to vector<8x1xf32>
+///  "some_use"(%r#0)
 /// ```
 struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
   using gpu::WarpDistributionPattern::WarpDistributionPattern;
@@ -1514,10 +1517,11 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
       xegpu::DistributeLayoutAttr resultLayout =
           xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
       assert(resultLayout && "Broadcast result must have layout attribute.");
-      if (!sourceLayout.isSliceOf(resultLayout) ||
-          sourceLayout.isIdentical(resultLayout))
+      if (!sourceLayout.isSliceOf(resultLayout) &&
+          !sourceLayout.isEqualTo(resultLayout))
         return rewriter.notifyMatchFailure(
-            warpOp, "Broadcast input layout must be a slice of result layout.");
+            warpOp, "Broadcast input layout must be either a slice of or equal "
+                    "to result layout.");
     }
     // Get the distributed source type based on layout
     FailureOr<VectorType> sourceDistTypeOrFailure =
@@ -1533,17 +1537,17 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
         {sourceDistTypeOrFailure.value()}, newRetIndices);
 
     // Replace the broadcast result with the distributed source
-    Value distributedVal = newWarpOp.getResult(newRetIndices[0]);
-    Value newBroadcast = distributedVal;
+    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+    Value newBroadcast = distributedSource;
     // if sourceDistType is same as orignial warp result type, no need to
     //  re-create broadcast op
-    if (distributedVal.getType() != warpOp.getResult(operandIdx).getType()) {
+    if (distributedSource.getType() != warpOp.getResult(operandIdx).getType()) {
       // generate broadcast op outside warp op to have correct type
       rewriter.setInsertionPointAfter(newWarpOp);
       newBroadcast = vector::BroadcastOp::create(
           rewriter, newWarpOp.getLoc(),
           cast<VectorType>(warpOp.getResult(operandIdx).getType()),
-          distributedVal);
+          distributedSource);
     }
 
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
@@ -1982,9 +1986,9 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   patterns.add<CreateNdDescDistribution, StoreNdDistribution,
                LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
                GpuBarrierDistribution, VectorMultiReductionDistribution,
-               VectorBroadcastDistribution, LoadDistribution, StoreDistribution,
-               VectorTransposeDistribution, VectorBitcastDistribution,
-               LoadMatrixDistribution, StoreMatrixDistribution,
+               LoadDistribution, StoreDistribution, VectorTransposeDistribution,
+               VectorBitcastDistribution, LoadMatrixDistribution,
+               StoreMatrixDistribution,
                MemrefExtractAlignedPointerAsIndexDistribution>(
       patterns.getContext(),
       /*pattern benefit=*/regularPatternBenefit);
@@ -1992,7 +1996,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   // patterns. Therefore, assign higher benefit.
   patterns
       .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
-           VectorInsertStridedSliceDistribution>(
+           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
           patterns.getContext(),
           /*pattern benefit=*/highPatternBenefit);
 }

>From ed38b401db7a6ee1d2d1eb8783e8691e8c1df153 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 5 Dec 2025 19:29:22 +0000
Subject: [PATCH 5/6] remove trailing space

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index d93ffb70881bd..bd4851f90f856 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -404,7 +404,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint, 
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
                        OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let results = (outs XeGPU_ValueType: $value);

>From e5397f1b105664258f68247da17feb3cc8375af2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 6 Dec 2025 03:42:01 +0000
Subject: [PATCH 6/6] add unit tests and fix bugs

---
 .../Transforms/XeGPUSubgroupDistribute.cpp    | 68 +++++++++++--------
 .../XeGPU/subgroup-distribute-unit.mlir       | 65 ++++++++++++++++++
 .../Dialect/XeGPU/subgroup-distribute.mlir    |  2 +-
 3 files changed, 106 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index bbf94be26398e..6ddbb486a76b1 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -80,7 +80,8 @@ static constexpr unsigned highPatternBenefit = 2;
 /// | 2x32x16               | [1, 16]     | 2x32x1                   |
 static FailureOr<VectorType>
 getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
-                                VectorType originalType) {
+                                VectorType originalType,
+                                bool allowUnitDim = false) {
   if (!layout)
     return failure();
   assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
@@ -99,7 +100,10 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
     if (i < distributionStart)
       continue;
-
+    if (allowUnitDim && dim == 1) {
+      distributedShape[i] = dim;
+      continue;
+    }
     // Check if the dimension can be distributed evenly.
     if (dim % effectiveLaneLayout[i - distributionStart] != 0)
       return failure();
@@ -1504,50 +1508,58 @@ struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
         cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
     unsigned operandIdx = yieldOperand->getOperandNumber();
 
-    // Get the input layout - must be a slice layout
     VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType destType =
+        dyn_cast<VectorType>(broadcastOp.getResult().getType());
     xegpu::DistributeLayoutAttr sourceLayout =
-        xegpu::getDistributeLayoutAttr(broadcastOp.getSource());
+        xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
     if (sourceType) {
-      if (!sourceLayout || !isa<xegpu::SliceAttr>(sourceLayout))
-        return rewriter.notifyMatchFailure(
-            warpOp,
-            "Broadcast input must be scalar or have a slice layout attribute.");
-      // also the sourceLayout must be a slice of the broadcast result layout
-      xegpu::DistributeLayoutAttr resultLayout =
-          xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
       assert(resultLayout && "Broadcast result must have layout attribute.");
-      if (!sourceLayout.isSliceOf(resultLayout) &&
-          !sourceLayout.isEqualTo(resultLayout))
+      bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
+      bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
+
+      if (!isSliceOf && !isEqualTo) {
         return rewriter.notifyMatchFailure(
             warpOp, "Broadcast input layout must be either a slice of or equal "
                     "to result layout.");
+      }
     }
-    // Get the distributed source type based on layout
-    FailureOr<VectorType> sourceDistTypeOrFailure =
-        getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
-    if (failed(sourceDistTypeOrFailure))
+
+    FailureOr<VectorType> sourceDistType = getDistVecTypeBasedOnLaneLayout(
+        sourceLayout, sourceType, /*allowUnitDim=*/true);
+    FailureOr<VectorType> destDistType =
+        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if ((sourceType != nullptr) && (failed(sourceDistType))) {
       return rewriter.notifyMatchFailure(
           warpOp, "Failed to distribute the source vector type.");
+    }
+    if (failed(destDistType))
+      return rewriter.notifyMatchFailure(
+          warpOp, "Failed to distribute the dest vector type.");
+
+    Type sourceElemOrDistType;
+    if (sourceType)
+      sourceElemOrDistType = sourceDistType.value();
+    else
+      sourceElemOrDistType = broadcastOp.getSourceType();
 
-    // Yield the source from the warp op - broadcast is a no-op
     SmallVector<size_t> newRetIndices;
     auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
-        rewriter, warpOp, {broadcastOp.getSource()},
-        {sourceDistTypeOrFailure.value()}, newRetIndices);
+        rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
+        newRetIndices);
 
-    // Replace the broadcast result with the distributed source
     Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+
     Value newBroadcast = distributedSource;
-    // if sourceDistType is same as orignial warp result type, no need to
-    //  re-create broadcast op
-    if (distributedSource.getType() != warpOp.getResult(operandIdx).getType()) {
+
+    if (sourceElemOrDistType != destDistType.value()) {
       // generate broadcast op outside warp op to have correct type
       rewriter.setInsertionPointAfter(newWarpOp);
-      newBroadcast = vector::BroadcastOp::create(
-          rewriter, newWarpOp.getLoc(),
-          cast<VectorType>(warpOp.getResult(operandIdx).getType()),
-          distributedSource);
+      newBroadcast =
+          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
+                                      destDistType.value(), distributedSource);
     }
 
     rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 44ec21359593f..216f3d19cff94 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -920,4 +920,69 @@ gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane
+// CHECK-SAME: (%[[ARG0:.*]]: index) {
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<1xf16>)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST_INNER:.*]] = vector.broadcast %[[DEF]]
+// CHECK: gpu.yield %[[BCAST_INNER]], %[[DEF]]
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[R]]#1 : vector<1xf16> to vector<16x1xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func  @vector_broadcast_1d_to_2d_broadcast_within_lane(%laneid: index) {
+
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+
+    %1 = "some_def"() : () -> vector<16xf16>
+
+    %2 = vector.broadcast %1 {
+      layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    } : vector<16xf16> to vector<16x16xf16>
+
+    gpu.yield %2 : vector<16x16xf16>
+  }
+  "some_use"(%r) : (vector<16x1xf16>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<16x1xf16>)
+// CHECK:   %[[DEF:.*]] = "some_def"() : () -> vector<16x1xf16>
+// CHECK:   %[[BCAST:.*]] = vector.broadcast %[[DEF]]
+// CHECK-SAME: : vector<16x1xf16> to vector<16x16xf16>
+// CHECK:   gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, vector<16x1xf16>
+// CHECK: "some_use"(%[[R]]#1) : (vector<16x1xf16>) -> ()
+gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: index) {
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+    %1 = "some_def"() : () -> vector<16x1xf16>
+    %2 = vector.broadcast %1 {
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    } : vector<16x1xf16> to vector<16x16xf16>
+    gpu.yield %2: vector<16x16xf16>
+  }
+  "some_use"(%0) : (vector<16x1xf16>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, f16)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x1xf16>
+// CHECK: "some_use"(%[[RESULT]])
+gpu.func
+@vector_shape_cast_scalar_to_vector(%arg0: index) {
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+    %1 = "some_def"() : () -> f16
+    %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+    gpu.yield %2 : vector<16x16xf16>
+  }
+  "some_use"(%0) : (vector<16x1xf16>) -> ()
+  gpu.return
+}
+
 }
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 05bb41f7bd034..e5e3d2a1c1ad5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -344,7 +344,7 @@ gpu.module @xevm_module{
     %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
     %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
     // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16>
-      %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+    %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
     xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
     gpu.return
   }



More information about the Mlir-commits mailing list