[Mlir-commits] [mlir] 5236af8 - [MLIR][XeGPU] Extend propagation and sg_to_lane distribution pass support broadcast with low rank and scalar source input (#170409)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 9 08:48:31 PST 2025


Author: Jianhui Li
Date: 2025-12-09T08:48:27-08:00
New Revision: 5236af88e5ed0a3449b2292ef02be28b8722b172

URL: https://github.com/llvm/llvm-project/commit/5236af88e5ed0a3449b2292ef02be28b8722b172
DIFF: https://github.com/llvm/llvm-project/commit/5236af88e5ed0a3449b2292ef02be28b8722b172.diff

LOG: [MLIR][XeGPU] Extend propagation and sg_to_lane distribution pass support broadcast with low rank and scalar source input (#170409)

This PR extends XeGPU layout propagation and distribution for the
vector.broadcast operation.
It relaxes the layout-propagation restrictions to allow low-rank and
scalar source inputs, and adds a pattern to the sg-to-wi distribution
pass to support the lowering.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
    mlir/test/Dialect/XeGPU/propagate-layout.mlir
    mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
    mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 93c5187b00756..eae0bd4e68a84 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -223,6 +223,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
     InterfaceMethod<"Derive a new layout by dropping InstData",
                     "xegpu::DistributeLayoutAttr",
                     "dropInstData">,
+    InterfaceMethod<"Derive a new layout with sg_data, inst_data and lane_data set to 1 for the specified unit dims",
+                    "xegpu::DistributeLayoutAttr",
+                    "setUnitDimData",
+                    /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
+    InterfaceMethod<"Derive a new layout with sg_lane and lane_layout set to 1 for the specified unit dims",
+                    "xegpu::DistributeLayoutAttr",
+                    "setUnitDimLayout",
+                    /*args=*/(ins "const llvm::SetVector<int64_t>": $unitDims)>,
     InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
                       indices based on the effective layout level.}],
                     "FailureOr<SmallVector<Value>>",
@@ -283,9 +291,14 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
                       }
                       return true;
                     }]>,
-    InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
+    InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
                     /*retTy=*/"bool",
                     /*methodName=*/"isSliceOf",
+                    /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>,
+
+    InterfaceMethod</*desc=*/[{Check if this layout is identical to another layout.}],
+                    /*retTy=*/"bool",
+                    /*methodName=*/"isEqualTo",
                     /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
   ];
 }
@@ -487,6 +500,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
       return {};
     }
 
+    // Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+
+    // Set sg_layout and lane_layout to 1 for the specified unit dims.
+    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+
     /// Delinearizes a linear ID into its multidimensional indices
     /// based on the effective level of the layout.
     FailureOr<SmallVector<Value>>
@@ -501,6 +520,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
+    
+    /// Check if this is identical to some other layout.
+    bool isEqualTo(const xegpu::DistributeLayoutAttr &other); 
 
   }];
 
@@ -649,6 +671,12 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
       return SliceAttr::get(getContext(), parent, attr.getDims());
     }
 
+    // Set sg_data, inst_data and lane_data to 1 for the specified unit dims.
+    DistributeLayoutAttr setUnitDimData(SetVector<int64_t> unitDims);
+
+    // Set sg_layout and lane_layout to 1 for the specified unit dims.
+    DistributeLayoutAttr setUnitDimLayout(SetVector<int64_t> unitDims);
+
     /// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
     /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
     /// it will coalese two slice operations and return a simplified SliceAttr
@@ -670,7 +698,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
 
     /// Check if this is slice of some other layout.
     bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
-
+    
+    /// Check if this is identical to some other layout.
+    bool isEqualTo(const xegpu::DistributeLayoutAttr &other); 
   }];
 
   let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

diff  --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 4fe1087d18879..b54d620c3c0c3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -405,7 +405,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint, 
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint,
                        OptionalAttr<DistributeLayoutAttr>:$layout);
 
   let results = (outs XeGPU_ValueType: $value);

diff  --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 7ab2e612ed890..1a19ab5fd970b 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -390,6 +390,86 @@ LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
   return genCoordinates(builder, loc, ids, layout, subShape, shape);
 }
 
+bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  // A SliceAttr (or any other non-LayoutAttr) fails the cast below and so
+  // never compares equal; two LayoutAttrs are compared by attribute equality.
+  auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other);
+  return otherLayout && *this == otherLayout;
+}
+
+// Derive a layout in which sg_data, inst_data and lane_data are forced to 1
+// on every dimension listed in `unitDims`. Fields that are not present on
+// this layout stay absent, and a dim beyond a field's rank is skipped for
+// that field.
+DistributeLayoutAttr LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+  // Copy an optional array attribute into a mutable vector (empty if unset).
+  auto toVector = [](DenseI32ArrayAttr attr) {
+    SmallVector<int32_t> values;
+    if (attr)
+      values = llvm::to_vector(attr.asArrayRef());
+    return values;
+  };
+
+  // Convert back to an attribute; an empty vector maps to a null attribute.
+  auto toAttr = [this](ArrayRef<int32_t> values) {
+    return values.empty() ? DenseI32ArrayAttr()
+                          : DenseI32ArrayAttr::get(getContext(), values);
+  };
+
+  SmallVector<int32_t> sgData = toVector(getSgData());
+  SmallVector<int32_t> instData = toVector(getInstData());
+  SmallVector<int32_t> laneData = toVector(getLaneData());
+
+  // Each field is bounds-checked against its own size, so a dim may be
+  // applied to one field and skipped for another.
+  for (int64_t dim : unitDims) {
+    for (SmallVector<int32_t> *field : {&sgData, &instData, &laneData}) {
+      if (dim < static_cast<int64_t>(field->size()))
+        (*field)[dim] = 1;
+    }
+  }
+
+  // sg_layout, lane_layout and order are carried over unchanged.
+  DenseI32ArrayAttr newSgData = toAttr(sgData);
+  DenseI32ArrayAttr newInstData = toAttr(instData);
+  DenseI32ArrayAttr newLaneData = toAttr(laneData);
+
+  return LayoutAttr::get(getContext(), getSgLayout(), newSgData, newInstData,
+                         getLaneLayout(), newLaneData, getOrder());
+}
+
+// Derive a layout in which sg_layout and lane_layout are forced to 1 on every
+// dimension in `unitDims`; absent fields stay absent.
+DistributeLayoutAttr LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+  // Copy an optional array attribute into a mutable vector (empty if unset).
+  auto toVector = [](DenseI32ArrayAttr attr) {
+    SmallVector<int32_t> values;
+    if (attr)
+      values = llvm::to_vector(attr.asArrayRef());
+    return values;
+  };
+  // Convert back to an attribute; an empty vector maps to a null attribute.
+  auto toAttr = [this](ArrayRef<int32_t> values) {
+    return values.empty() ? DenseI32ArrayAttr()
+                          : DenseI32ArrayAttr::get(getContext(), values);
+  };
+
+  SmallVector<int32_t> sgLayout = toVector(getSgLayout());
+  SmallVector<int32_t> laneLayout = toVector(getLaneLayout());
+
+  for (int64_t dim : unitDims) {
+    for (SmallVector<int32_t> *field : {&sgLayout, &laneLayout}) {
+      if (dim < static_cast<int64_t>(field->size()))
+        (*field)[dim] = 1;
+    }
+  }
+
+  // sg_data, inst_data, lane_data and order are carried over unchanged.
+  return LayoutAttr::get(getContext(), toAttr(sgLayout), getSgData(),
+                         getInstData(), toAttr(laneLayout), getLaneData(),
+                         getOrder());
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_SliceAttr
 //===----------------------------------------------------------------------===//
@@ -510,6 +590,69 @@ bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
                       [&](int64_t dim) { return thisDims.contains(dim); });
 }
 
+bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
+  // Only another SliceAttr can match; a failed cast must not be dereferenced.
+  auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
+  if (!otherSlice)
+    return false;
+  // Compare the flattened forms of both slices.
+  SliceAttr lhs = flatten();
+  SliceAttr rhs = otherSlice.flatten();
+  return lhs.getParent() == rhs.getParent() && lhs.getDims() == rhs.getDims();
+}
+
+// Helper function to adjust unit dimensions from sliced space to parent space:
+// sliced-space dim k is the k-th parent dim that is not listed in `sliceDims`.
+static SetVector<int64_t>
+adjustUnitDimsWithSliceDims(const SetVector<int64_t> &unitDims,
+                            ArrayRef<int64_t> sliceDims) {
+  llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
+                                             sliceDims.end());
+  // The sliced rank is unknown here, so size the lookup table from the
+  // largest requested unit dim rather than from unitDims.size().
+  size_t numNeeded = 0;
+  for (int64_t dim : unitDims)
+    if (dim >= 0 && static_cast<size_t>(dim) >= numNeeded)
+      numNeeded = static_cast<size_t>(dim) + 1;
+  // Collect the first `numNeeded` parent dims that survive the slicing.
+  SmallVector<int64_t> nonSlicedDims;
+  for (int64_t i = 0; nonSlicedDims.size() < numNeeded; ++i)
+    if (!slicedDimsSet.contains(i))
+      nonSlicedDims.push_back(i);
+
+  SetVector<int64_t> adjustUnitDims;
+  for (int64_t dim : unitDims)
+    if (dim >= 0)
+      adjustUnitDims.insert(nonSlicedDims[dim]);
+  return adjustUnitDims;
+}
+
+// Set sg_data, inst_data and lane_data to 1 for the specified unit dims,
+// which are given in the sliced space of this attribute.
+DistributeLayoutAttr SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) {
+  // Operate on the flattened slice; its parent is expected to be a LayoutAttr.
+  SliceAttr flat = flatten();
+  auto parentLayout = dyn_cast<LayoutAttr>(flat.getParent());
+  // Re-express the unit dims in the parent's index space, update the parent,
+  // then re-apply the same slice dims on top of it.
+  DistributeLayoutAttr newParent = parentLayout.setUnitDimData(
+      adjustUnitDimsWithSliceDims(unitDims, flat.getDims().asArrayRef()));
+  return SliceAttr::get(getContext(), newParent, flat.getDims());
+}
+
+// Set sg_layout and lane_layout to 1 for the specified unit dims, which are
+// given in the sliced space of this attribute.
+DistributeLayoutAttr SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) {
+  // Operate on the flattened slice; its parent is expected to be a LayoutAttr.
+  SliceAttr flat = flatten();
+  auto parentLayout = dyn_cast<LayoutAttr>(flat.getParent());
+  // Re-express the unit dims in the parent's index space, update the parent,
+  // then re-apply the same slice dims on top of it.
+  DistributeLayoutAttr newParent = parentLayout.setUnitDimLayout(
+      adjustUnitDimsWithSliceDims(unitDims, flat.getDims().asArrayRef()));
+  return SliceAttr::get(getContext(), newParent, flat.getDims());
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_RangeAttr
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
index 59a1ad9dbe189..dc9eb96c169b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp
@@ -580,23 +580,39 @@ void LayoutInfoPropagation::visitVectorBroadCastOp(
   // Only consider vector to vector broadcasts for now.
   VectorType resultTy = broadcast.getResultVectorType();
   VectorType sourceTy = dyn_cast<VectorType>(broadcast.getSourceType());
-  if (!sourceTy) {
-    broadcast.emitWarning("Expecting source type to be a vector type.");
+  // skip layout propagation for non-vector source operand.
+  if (!sourceTy)
     return;
-  }
 
-  // Only consider nD -> nD broadcast.
+  // Handle the broadcast from low-rank to high-rank (e.g., 1D to 2D) case.
   if (sourceTy.getRank() != resultTy.getRank()) {
-    broadcast.emitWarning("Expecting source and result to have same rank.");
+    auto sourceDims = sourceTy.getShape();
+    auto resultDims = resultTy.getShape();
+    SmallVector<int64_t> bcastDims;
+    auto dimDiff = resultTy.getRank() - sourceTy.getRank();
+    // adding the missing leading dims
+    for (int i = 0; i < dimDiff; i++)
+      bcastDims.push_back(i);
+
+    // For the remaining result dims, a source dim of 1 that maps to a non-unit
+    // result dim is a broadcasted dim.
+    for (size_t i = 0; i < sourceDims.size(); i++)
+      if ((sourceDims[i] == 1) && (resultDims[i + dimDiff] != 1))
+        bcastDims.push_back(i + dimDiff);
+
+    // create a slice layout for the source
+    xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
+        broadcast->getContext(),
+        cast<xegpu::DistributeLayoutAttr>(resultLayout.get()),
+        DenseI64ArrayAttr::get(broadcast->getContext(), bcastDims));
+
+    propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
   }
+
   SetVector<int64_t> broadcastUnitDims = broadcast.computeBroadcastedUnitDims();
-  if (broadcastUnitDims.size() != 1) {
-    broadcast.emitWarning("Expecting source type to be nD vector only with "
-                          "one broadcasted dimension.");
-    return;
-  }
-  // Propagate the result layout to the source operand.
+  resultLayout = cast<xegpu::DistributeLayoutAttr>(resultLayout.get())
+                     .setUnitDimData(broadcastUnitDims);
   propagateIfChanged(operands[0], operands[0]->meet(resultLayout));
 }
 
@@ -917,7 +933,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
   } else {
 
     // The layout is strictly determined by the payload type.
-    auto payloadTy = dyn_cast<VectorType>(load.getValueType());
+    VectorType payloadTy = load.getValueType();
     if (!payloadTy) {
       load.emitWarning("Not propagating, non-vector payload supplied.");
       return;
@@ -987,7 +1003,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
     // Currently, for 2D StoreScatterOp we expect that the height dimension of
     // the tensor descriptor is equal to the subgroup size. This is ensured by
     // the op verifier.
-    auto payloadTy = dyn_cast<VectorType>(storeScatter.getValueType());
+    VectorType payloadTy = storeScatter.getValueType();
     if (!payloadTy) {
       storeScatter.emitWarning("Not propagating, non-vector payload supplied.");
       return;

diff  --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 0d1c5eeeff711..ca81c3cd7be42 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -99,7 +99,6 @@ getDistVecTypeBasedOnLaneLayout(xegpu::DistributeLayoutAttr layout,
   for (auto [i, dim] : llvm::enumerate(originalType.getShape())) {
     if (i < distributionStart)
       continue;
-
     // Check if the dimension can be distributed evenly.
     if (dim % effectiveLaneLayout[i - distributionStart] != 0)
       return failure();
@@ -1424,6 +1423,166 @@ struct VectorMultiReductionDistribution : public gpu::WarpDistributionPattern {
   }
 };
 
+/// This pattern distributes the `vector.broadcast` operation across lanes in a
+/// warp. The pattern supports three use cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+/// vector
+///    must have a slice layout of the result. If the distributed source and
+///    target vector types are identical, this lowers to a no-op; otherwise, it
+///    remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+/// target:
+///    The source vector must have unit dimensions, and lane_data must be unit
+///    size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast from
+///    scalar to distributed result type.
+///
+/// Example 1 (lowering to a broadcast with distributed types):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///   %2 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>}
+///     : vector<32xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [0]> } : () -> (vector<32xf32>)
+///   gpu.yield %0 : vector<32xf32>
+/// }
+/// %2 = vector.broadcast %r#0 : vector<1xf32> to vector<8x1xf32>
+/// ```
+/// Example 2 (no-op):
+/// ```
+/// %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x32xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///   %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///      1]>}: vector<8xf32> to vector<8x1xf32>
+///   %2 = vector.broadcast %1
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8x1xf32> to vector<8x32xf32>
+///   gpu.yield %2 : vector<8x32xf32>
+/// }
+/// ```
+/// is lowered to:
+/// ```
+/// %r:1 = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<8x1xf32>) {
+///   %0 = "some_def"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 32], lane_data = [1, 1]>,
+///     dims = [1]> } : () -> (vector<8xf32>)
+///   %1 = vector.shape_cast %0
+///     {layout_result_0 = #xegpu.layout<lane_layout = [1, 32], lane_data = [1,
+///     1]>}: vector<8xf32> to vector<8x1xf32>
+///   gpu.yield %1 : vector<8x1xf32>
+/// }
+/// // The broadcast is implicit through layout transformation (no-op)
+///  "some_use"(%r#0)
+/// ```
+struct VectorBroadcastDistribution : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *yieldOperand =
+        getWarpResult(warpOp, llvm::IsaPred<vector::BroadcastOp>);
+    if (!yieldOperand)
+      return failure();
+    auto broadcastOp =
+        cast<vector::BroadcastOp>(yieldOperand->get().getDefiningOp());
+    unsigned operandIdx = yieldOperand->getOperandNumber();
+
+    VectorType sourceType = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    VectorType destType =
+        dyn_cast<VectorType>(broadcastOp.getResult().getType());
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getDistributeLayoutAttr(broadcastOp.getResult());
+
+    FailureOr<VectorType> sourceDistType;
+    Type sourceElemOrDistType;
+    if (sourceType) {
+      // Case 1 and 2: source is a vector type. Both layouts are dereferenced
+      // below, so reject the match up front when either one is missing.
+      if (!sourceLayout || !resultLayout)
+        return rewriter.notifyMatchFailure(
+            warpOp, "Vector broadcast requires source and result layouts.");
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: source is lower-rank than result.
+        bool isSliceOf = sourceLayout.isSliceOf(resultLayout);
+        if (!isSliceOf)
+          return rewriter.notifyMatchFailure(
+              warpOp,
+              "Broadcast input layout must be a slice of result layout.");
+      }
+      // case 2: source and result have same rank
+      if (rankDiff == 0) {
+        SetVector<int64_t> broadcastUnitDims =
+            broadcastOp.computeBroadcastedUnitDims();
+        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+        bool isEqualTo = sourceLayout.isEqualTo(resultLayout);
+        if (!isEqualTo)
+          return rewriter.notifyMatchFailure(
+              warpOp, "For same-rank broadcast, source must be identical to "
+                      "adjusted result layouts with unit dims.");
+        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
+      }
+      sourceDistType =
+          getDistVecTypeBasedOnLaneLayout(sourceLayout, sourceType);
+      if (failed(sourceDistType)) {
+        return rewriter.notifyMatchFailure(
+            warpOp, "Failed to distribute the source vector type.");
+      }
+      sourceElemOrDistType = sourceDistType.value();
+    } else {
+      // Case 3: source is a scalar type.
+      if (sourceLayout) {
+        return rewriter.notifyMatchFailure(
+            warpOp, "Broadcast from scalar must not have a layout attribute.");
+      }
+      sourceElemOrDistType = broadcastOp.getSourceType();
+    }
+    FailureOr<VectorType> destDistType =
+        getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType)) {
+      return rewriter.notifyMatchFailure(
+          warpOp, "Failed to distribute the dest vector type.");
+    }
+
+    SmallVector<size_t> newRetIndices;
+    auto newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, {broadcastOp.getSource()}, sourceElemOrDistType,
+        newRetIndices);
+    Value distributedSource = newWarpOp.getResult(newRetIndices[0]);
+    // When the distributed source and result types already match, the
+    // broadcast is a no-op and the yielded source is used directly.
+    Value newBroadcast = distributedSource;
+    if (sourceElemOrDistType != destDistType.value()) {
+      rewriter.setInsertionPointAfter(newWarpOp);
+      newBroadcast =
+          vector::BroadcastOp::create(rewriter, newWarpOp.getLoc(),
+                                      destDistType.value(), distributedSource);
+    }
+
+    rewriter.replaceAllUsesWith(newWarpOp.getResult(operandIdx), newBroadcast);
+    return success();
+  }
+};
+
 /// Distribute a `vector.shape_cast` op feeding into yield op of an enclosing
 /// `gpu.warp_execute_on_lane_0` region.
 struct VectorShapeCastDistribution : public gpu::WarpDistributionPattern {
@@ -1865,7 +2024,7 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
   // patterns. Therefore, assign higher benefit.
   patterns
       .add<VectorShapeCastDistribution, VectorExtractStridedSliceDistribution,
-           VectorInsertStridedSliceDistribution>(
+           VectorInsertStridedSliceDistribution, VectorBroadcastDistribution>(
           patterns.getContext(),
           /*pattern benefit=*/highPatternBenefit);
 }

diff  --git a/mlir/test/Dialect/XeGPU/propagate-layout.mlir b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
index f8b59b87a122b..48e77d867508b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout.mlir
@@ -640,3 +640,61 @@ func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc
   return
 }
 }
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_1d_to_2d_broadcast_along_row(
+// CHECK-SAME:    %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME:    %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK:         %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}>  {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME:      !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT:    %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME:       {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT:    %[[BROADCAST:.*]] = vector.broadcast %[[REDUCE]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+func.func @vector_broadcast_1d_to_2d_broadcast_along_row(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+  %5 = vector.broadcast %4 : vector<16xf16> to vector<16x16xf16>
+  xegpu.store_nd %5, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_2d_to_2d_along_column(
+// CHECK:            %[[REDUCE:.*]] = vector.multi_reduction <add>
+// CHECK-SAME:       {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT:    %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+// CHECK-NEXT:    vector.broadcast %[[SHAPECAST]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+
+func.func @vector_broadcast_2d_to_2d_along_column(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0  : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+  %5 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+  %6 = vector.broadcast %5 : vector<16x1xf16> to vector<16x16xf16>
+  xegpu.store_nd %6, %arg1  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+}
+
+// -----
+gpu.module @test {
+// CHECK-LABEL: func.func @vector_broadcast_scalar_to_vector(
+// CHECK:         %[[CST:.*]] = arith.constant 0.{{.*}} : f16
+// CHECK-NEXT:    %[[BROADCAST:.*]] = vector.broadcast %[[CST]]
+// CHECK-SAME:       {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+
+func.func @vector_broadcast_scalar_to_vector(%arg0: !xegpu.tensor_desc<16x16xf16>) {
+  %cst = arith.constant 0.0000 : f16
+  %6 = vector.broadcast %cst : f16 to vector<16x16xf16>
+  xegpu.store_nd %6, %arg0  : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+}
\ No newline at end of file

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
index 44ec21359593f..216f3d19cff94 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir
@@ -920,4 +920,69 @@ gpu.func @vector_insert_strided_slice_unsupported_offset(%laneid: index) {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane
+// CHECK-SAME: (%[[ARG0:.*]]: index) {
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<1xf16>)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST_INNER:.*]] = vector.broadcast %[[DEF]]
+// CHECK: gpu.yield %[[BCAST_INNER]], %[[DEF]]
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[R]]#1 : vector<1xf16> to vector<16x1xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func  @vector_broadcast_1d_to_2d_broadcast_within_lane(%laneid: index) {
+
+  %r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<16x1xf16>) {
+
+    %1 = "some_def"() : () -> vector<16xf16>
+
+    %2 = vector.broadcast %1 {
+      layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    } : vector<16xf16> to vector<16x16xf16>
+
+    gpu.yield %2 : vector<16x16xf16>
+  }
+  "some_use"(%r) : (vector<16x1xf16>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, vector<16x1xf16>)
+// CHECK:   %[[DEF:.*]] = "some_def"() : () -> vector<16x1xf16>
+// CHECK:   %[[BCAST:.*]] = vector.broadcast %[[DEF]]
+// CHECK-SAME: : vector<16x1xf16> to vector<16x16xf16>
+// CHECK:   gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, vector<16x1xf16>
+// CHECK: "some_use"(%[[R]]#1) : (vector<16x1xf16>) -> ()
+gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: index) {
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+    %1 = "some_def"() : () -> vector<16x1xf16>
+    %2 = vector.broadcast %1 {
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    } : vector<16x1xf16> to vector<16x16xf16>
+    gpu.yield %2: vector<16x16xf16>
+  }
+  "some_use"(%0) : (vector<16x1xf16>) -> ()
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK: %[[R:.*]]:2 = gpu.warp_execute_on_lane_0(%[[ARG0]])[16] -> (vector<16x1xf16>, f16)
+// CHECK: %[[DEF:.*]] = "some_def"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[DEF]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+// CHECK: gpu.yield %[[BCAST]], %[[DEF]] : vector<16x16xf16>, f16
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[R]]#1 : f16 to vector<16x1xf16>
+// CHECK: "some_use"(%[[RESULT]])
+gpu.func
+@vector_shape_cast_scalar_to_vector(%arg0: index) {
+  %0 = gpu.warp_execute_on_lane_0(%arg0)[16] -> (vector<16x1xf16>) {
+    %1 = "some_def"() : () -> f16
+    %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+    gpu.yield %2 : vector<16x16xf16>
+  }
+  "some_use"(%0) : (vector<16x1xf16>) -> ()
+  gpu.return
+}
+
 }

diff  --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 22177f8f6a15f..e5e3d2a1c1ad5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -330,3 +330,64 @@ gpu.module @xevm_module{
     gpu.return
   }
 }
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d_broadcast_within_lane({{.*}}) {
+gpu.module @xevm_module{
+   gpu.func  @vector_broadcast_1d_to_2d_broadcast_within_lane(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} dense<0.000000e+00> : vector<16xf16>
+    %tdesc0 = xegpu.create_nd_tdesc %arg0 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    %0 = xegpu.load_nd %tdesc0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+    %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} [0] : vector<16x16xf16> to vector<16xf16>
+    // CHECK: %[[BCAST:.*]] = vector.broadcast %{{.*}} : f16 to vector<16xf16>
+    %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+    xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case({{.*}}) {
+gpu.module @xevm_module{
+   gpu.func  @vector_broadcast_2d_to_2d_across_lane_lower_to_noop_case(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %mask = vector.constant_mask [16] {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: vector<16xi1>
+    %1 = xegpu.load %arg0[%c0], %mask {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>}: memref<16xf16>, index, vector<16xi1> -> vector<16xf16>
+    
+    %11 = vector.shape_cast %1 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x1xf16>
+    %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+    // CHECK-NOT: vector.broadcast
+    // CHECK-NOT: vector.shape_cast
+ 
+    %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    // CHECK: xegpu.store_nd {{.*}}, {{.*}}[{{.*}}, {{.*}}]
+    // CHECK-SAME: : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
+
+    xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @vector_shape_cast_scalar_to_vector({{.*}}) {
+gpu.module @xevm_module{
+   gpu.func  @vector_shape_cast_scalar_to_vector(%arg0: memref<16xf16>, %arg1: memref<16x16xf16>) {
+    %c0 = arith.constant 0 : index
+    %9 = gpu.block_id  x
+    %10 = arith.index_cast %9 : index to i16
+    %11 = arith.bitcast %10 : i16 to f16
+    // CHECK: vector.broadcast {{.*}} : f16 to vector<16xf16>
+    %2 = vector.broadcast %11 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+    %tdesc1 = xegpu.create_nd_tdesc %arg1 : memref<16x16xf16>
+      -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    xegpu.store_nd %2, %tdesc1[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}> : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+    gpu.return
+  }
+}
+
+


        


More information about the Mlir-commits mailing list