[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution patterns for vector step, shape_cast & broadcast from sg-to-wi (PR #185960)

Nishant Patel llvmlistbot at llvm.org
Fri Mar 27 09:54:40 PDT 2026


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/185960

>From 00908d454a43281a5bad94c82ce324983455499f Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 9 Mar 2026 18:57:23 +0000
Subject: [PATCH 1/5] Add pattern for vector.step

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 74 ++++++++++++++++++-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 43 +++++++++++
 2 files changed, 115 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 5cd766ed2813e..2ecf6898a2ad5 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -828,6 +828,72 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   }
 };
 
+/// Distribute a vector::StepOp with a sliced result layout.
+/// The sliced layout must have exactly 1 effective lane dimension.
+/// We completely resolve the vector::StepOp by computing the lane_data-sized
+/// subranges.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the step op lacks subgroup layout");
+    auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
+    if (!sliceLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the result layout must be a slice layout");
+    if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expecting 1 dim in the effective result layout");
+
+    auto loc = op.getLoc();
+    auto stepResultVecTy = op.getResult().getType();
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType newVecTy = wiShapeOrFailure.value();
+
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
+        rewriter, loc, laneId, stepResultVecTy.getShape());
+    if (failed(laneDataBlockCoords))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute lane data block coordinates");
+
+    auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
+    auto laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
+    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
+           newVecTy.getNumElements() / laneDataBlockLength);
+    SmallVector<Value> stepVals;
+    // For each lane_data block, reconstruct its sub-range
+    // from the range of SG-level vector.step.
+    for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
+      auto laneDataBlockStartCoord = laneDataBlockCoords[0];
+      stepVals.push_back(laneDataBlockStartCoord);
+      for (int i = 1; i < laneDataBlockLength; ++i) {
+        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
+        stepVals.push_back(arith::AddIOp::create(
+            rewriter, loc, laneDataBlockStartCoord, offset));
+      }
+    }
+    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
+           "Expecting the number of step values to match the number of "
+           "elements in the vector");
+    auto stepOpVal =
+        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
+    rewriter.replaceOp(op, stepOpVal);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -1029,10 +1095,14 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
       [=](vector::MultiDimReductionOp op) -> bool {
         return !isValidSubgroupMultiReductionOp(op);
       });
+  // vector::StepOp is legal only if its result has no temporary layout.
+  target.addDynamicallyLegalOp<vector::StepOp>([=](vector::StepOp op) -> bool {
+    return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+  });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
                SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
-               SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
-      typeConverter, patterns.getContext());
+               SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix,
+               SgToWiVectorStep>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 13ca46c3dbb50..fdafb5984b349 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -525,3 +525,46 @@ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layou
   gpu.return
 }
 }
+
+// -----
+// vector.step with a simple slice layout (16 elements, 1 per lane)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice
+// CHECK:         %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
+// CHECK:         %[[REM:.*]] = arith.remui %[[LANE_ID]], %[[C16]] : index
+// CHECK:         %[[REM2:.*]] = arith.remui %[[REM]], %[[C16]]{{.*}} : index
+// CHECK:         %[[VEC:.*]] = vector.from_elements %[[REM2]] : vector<1xindex>
+gpu.func @vector_step_slice() {
+  %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<16xindex>
+  gpu.return
+}
+}
+
+// -----
+// vector.step with unit-size vector (1 element, always 0)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice_unit
+// CHECK:         %[[VEC:.*]] = vector.from_elements %{{.*}} : vector<1xindex>
+gpu.func @vector_step_slice_unit() {
+  %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>} : vector<1xindex>
+  gpu.return
+}
+}
+
+// -----
+// vector.step with multi-distribution (lane_layout=[2,4,2], lane_data=[1,2,1])
+// Each lane holds 4 elements as 2 blocks of 2 elements (lane_data=2 in dim 1).
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice_multi_dist
+// CHECK:         %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
+// CHECK:         %[[VEC:.*]] = vector.from_elements %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xindex>
+gpu.func @vector_step_slice_multi_dist() {
+  %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>} : vector<16xindex>
+  gpu.return
+}
+}

>From 48016c8fe2e63c4930360c1e96276c2ac5a2f9ff Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 9 Mar 2026 20:33:06 +0000
Subject: [PATCH 2/5] Add pattern for vector.shape_cast

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 44 ++++++++++++--
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 58 +++++++++++++++++--
 2 files changed, 93 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 2ecf6898a2ad5..fececc954f0c8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -894,6 +894,39 @@ struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
   }
 };
 
+/// This pattern distributes a subgroup-level ShapeCast op to workitem-level.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!sourceLayout || !resultLayout)
+      return rewriter.notifyMatchFailure(
+          op,
+          "the source or result of shape_cast op lacks distribution layout");
+    if (!sourceLayout.isForSubgroup() || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the source or result layout is not for subgroup");
+
+    auto resultDistTypeOrFailure = xegpu::getDistVecTypeBasedOnLaneLayout(
+        resultLayout, op.getResultVectorType());
+    if (failed(resultDistTypeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "failed to get distributed vector type for result");
+
+    Value source = adaptor.getSource();
+    auto newShapeCast = vector::ShapeCastOp::create(
+        rewriter, op.getLoc(), resultDistTypeOrFailure.value(), source);
+    rewriter.replaceOp(op, newShapeCast);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -1095,14 +1128,15 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
       [=](vector::MultiDimReductionOp op) -> bool {
         return !isValidSubgroupMultiReductionOp(op);
       });
-  // vector::StepOp is legal only if its result has no temporary layout.
-  target.addDynamicallyLegalOp<vector::StepOp>([=](vector::StepOp op) -> bool {
-    return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
-  });
+  target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+      [=](Operation *op) -> bool {
+        return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
+      });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
                SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
                SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix,
-               SgToWiVectorStep>(typeConverter, patterns.getContext());
+               SgToWiVectorStep, SgToWiVectorShapeCast>(typeConverter,
+                                                        patterns.getContext());
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index fdafb5984b349..0717e9b96d079 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -527,7 +527,6 @@ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layou
 }
 
 // -----
-// vector.step with a simple slice layout (16 elements, 1 per lane)
 gpu.module @xevm_module {
 // CHECK-LABEL: gpu.func @vector_step_slice
 // CHECK:         %[[LANE_ID:.*]] = gpu.lane_id
@@ -542,7 +541,6 @@ gpu.func @vector_step_slice() {
 }
 
 // -----
-// vector.step with unit-size vector (1 element, always 0)
 gpu.module @xevm_module {
 // CHECK-LABEL: gpu.func @vector_step_slice_unit
 // CHECK:         %[[VEC:.*]] = vector.from_elements %{{.*}} : vector<1xindex>
@@ -553,8 +551,6 @@ gpu.func @vector_step_slice_unit() {
 }
 
 // -----
-// vector.step with multi-distribution (lane_layout=[2,4,2], lane_data=[1,2,1])
-// Each lane holds 4 elements as 2 blocks of 2 elements (lane_data=2 in dim 1).
 gpu.module @xevm_module {
 // CHECK-LABEL: gpu.func @vector_step_slice_multi_dist
 // CHECK:         %[[LANE_ID:.*]] = gpu.lane_id
@@ -568,3 +564,57 @@ gpu.func @vector_step_slice_multi_dist() {
   gpu.return
 }
 }
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
+// CHECK:         %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<1x1xf32>
+gpu.func @vector_shapecast_rank_increasing() {
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+    : () -> (vector<16xf32>)
+  %cast = vector.shape_cast %cst
+    {
+      layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<16xf32> to vector<1x16xf32>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing
+// CHECK:         %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1x1xf32> to vector<1xf32>
+gpu.func @vector_shapecast_rank_reducing() {
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+    : () -> (vector<1x16xf32>)
+  %cast = vector.shape_cast %cst
+    {
+      layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+      layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+    }
+    : vector<1x16xf32> to vector<16xf32>
+  gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK:         %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<1x1xf32>
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout() {
+  %cst = "some_op"()
+    {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+    : () -> (vector<16xf32>)
+  %cast = vector.shape_cast %cst
+    {
+      layout_operand_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>,
+      layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+    }
+    : vector<16xf32> to vector<1x16xf32>
+  gpu.return
+}
+}

>From 71ef2f807c86ae9f4a1cc3eb23e217487184c8c1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 12 Mar 2026 22:03:21 +0000
Subject: [PATCH 3/5] Add pattern for vector.broadcast

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 129 +++++++++++++++++-
 .../XeGPU/sg-to-wi-experimental-unit.mlir     |  75 ++++++++++
 2 files changed, 202 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 5cd766ed2813e..a51a8dbb21cab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -828,6 +828,127 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   }
 };
 
+/// This pattern distributes a subgroup-level `vector.broadcast` op to
+/// workitem-level. The pattern supports three cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+///    vector must have a slice layout of the result. If the distributed source
+///    and target vector types are identical, this lowers to a no-op; otherwise,
+///    it remains a broadcast but operates on distributed vectors.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+///    target: The source vector must have unit dimensions, and lane_data must
+///    be unit size for those unit dims. This always lowers to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
+///    from scalar to distributed result type.
+///
+/// Example 1 (low-rank to high-rank broadcast):
+/// ```
+///   %0 = "some_op"() {layout_result_0 =
+///     #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+///     dims = [0]>} : () -> vector<16xf16>
+///   %1 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : vector<16xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+///   %0 = ... : vector<1xf16>
+///   %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
+/// ```
+///
+/// Example 2 (same-rank broadcast, no-op):
+/// ```
+///   %0 = "some_op"() {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : () -> vector<16x1xf16>
+///   %1 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : vector<16x1xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to (no-op, source already matches distributed result type):
+/// ```
+///   %0 = ... : vector<16x1xf16>
+///   // broadcast is eliminated, %0 is used directly
+/// ```
+///
+/// Example 3 (scalar to vector broadcast):
+/// ```
+///   %0 = "some_op"() : () -> f16
+///   %1 = vector.broadcast %0 {layout_result_0 =
+///     #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+///     : f16 to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+///   %0 = ... : f16
+///   %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
+/// ```
+struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "result does not have subgroup distribute layout");
+
+    VectorType destType = op.getResultVectorType();
+    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+
+    if (sourceType) {
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: Low-rank to high-rank broadcast.
+        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
+          op.emitWarning(
+              "broadcast source layout must be a slice of result layout");
+      } else if (rankDiff == 0) {
+        // Case 2: Same-rank broadcast.
+        if (!sourceLayout || !sourceLayout.isEqualTo(resultLayout))
+          return rewriter.notifyMatchFailure(
+              op, "for same-rank broadcast, source layout must be equal to "
+                  "result layout");
+        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
+        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
+                                               broadcastUnitDimsSet.end());
+        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+        sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
+      }
+    } else {
+      // Case 3: Scalar to vector broadcast.
+      if (sourceLayout)
+        return rewriter.notifyMatchFailure(
+            op, "broadcast from scalar must not have a layout attribute");
+    }
+
+    auto destDistType =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type");
+
+    Value source = adaptor.getSource();
+    // If the adapted source already matches the dest dist type, it's a no-op.
+    if (source.getType() == destDistType.value()) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
+                                             destDistType.value(), source);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
 struct XeGPUSgToWiDistributeExperimentalPass
     : public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
           XeGPUSgToWiDistributeExperimentalPass> {
@@ -1029,10 +1150,14 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
       [=](vector::MultiDimReductionOp op) -> bool {
         return !isValidSubgroupMultiReductionOp(op);
       });
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return !xegpu::getTemporaryLayout(op->getResult(0));
+      });
   target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
   patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
                SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
                SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
-               SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
-      typeConverter, patterns.getContext());
+               SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix,
+               SgToWiBroadcast>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 13ca46c3dbb50..01b902723caf8 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -525,3 +525,78 @@ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layou
   gpu.return
 }
 }
+
+// -----
+// 1D to 2D broadcast (low-rank to high-rank, within-lane)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16xf16> to vector<1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<1xf16> to vector<16x1xf16>
+gpu.func @vector_broadcast_1d_to_2d(%laneid: index) {
+  %0 = "some_op"() {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : () -> vector<16xf16>
+  %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+// 2D to 3D broadcast (low-rank to high-rank)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_3d
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x16xf16> to vector<16x1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<16x1xf16> to vector<1x16x1xf16>
+gpu.func @vector_broadcast_2d_to_3d(%laneid: index) {
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x16xf16>
+  %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>} : vector<16x16xf16> to vector<1x16x16xf16>
+  "some_use"(%1) : (vector<1x16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+// 2D to 2D same-rank broadcast (noop case, broadcast across lane dim)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_noop
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK-NOT: vector.broadcast
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x1xf16> to vector<16x16xf16>
+// CHECK: "some_use"(%[[CAST]])
+gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
+  %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
+  %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+// Scalar to vector broadcast (with layout)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<16x1xf16>
+gpu.func @vector_broadcast_scalar_to_vector(%laneid: index) {
+  %0 = "some_op"() : () -> f16
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}
+
+// -----
+// Scalar to vector broadcast (no layout - uniform, should remain unchanged)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector_uniform
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<16x16xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func @vector_broadcast_scalar_to_vector_uniform(%laneid: index) {
+  %0 = "some_op"() : () -> f16
+  %1 = vector.broadcast %0 : f16 to vector<16x16xf16>
+  "some_use"(%1) : (vector<16x16xf16>) -> ()
+  gpu.return
+}
+}

>From 3305aaeaadf54ea6e3c7346f0acc623ce8526f45 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 23 Mar 2026 22:30:06 +0000
Subject: [PATCH 4/5] remove operand layouts

---
 mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index ce0009b95a515..3a24affe6ab3c 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -836,7 +836,7 @@ gpu.module @xevm_module {
 // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<1xf16> to vector<16x1xf16>
 gpu.func @vector_broadcast_1d_to_2d(%laneid: index) {
   %0 = "some_op"() {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : () -> vector<16xf16>
-  %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
   "some_use"(%1) : (vector<16x16xf16>) -> ()
   gpu.return
 }
@@ -850,7 +850,7 @@ gpu.module @xevm_module {
 // CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<16x1xf16> to vector<1x16x1xf16>
 gpu.func @vector_broadcast_2d_to_3d(%laneid: index) {
   %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x16xf16>
-  %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>} : vector<16x16xf16> to vector<1x16x16xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>} : vector<16x16xf16> to vector<1x16x16xf16>
   "some_use"(%1) : (vector<1x16x16xf16>) -> ()
   gpu.return
 }
@@ -863,7 +863,7 @@ gpu.module @xevm_module {
 // CHECK-NOT: vector.broadcast
 gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
   %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
-  %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+  %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
   "some_use"(%1) : (vector<16x16xf16>) -> ()
   gpu.return
 }

>From 46f101e1ee9aa449d73f42d3d3226c0400a1affd Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 25 Mar 2026 00:22:22 +0000
Subject: [PATCH 5/5] Address Feedback

---
 .../XeGPUSgToWiDistributeExperimental.cpp     | 26 ++++++++-----------
 .../XeGPU/sg-to-wi-experimental-unit.mlir     | 12 +++++----
 2 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index fb6b7e73c705c..8c60ced4ed38e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -842,8 +842,8 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
   }
 };
 
-/// Distribute a vector::StepOp with a sliced result layout.
-/// The sliced layout must have exactly 1 effective lane dimension.
+/// Distribute a vector::StepOp to workitem-level.
+/// The layout must have exactly 1 effective lane dimension.
 /// We completely resolve the vector::StepOp by computing the lane_data-sized
 /// subranges.
 struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
@@ -857,13 +857,6 @@ struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
     if (!resultLayout || !resultLayout.isForSubgroup())
       return rewriter.notifyMatchFailure(
           op, "the result vector of the step op lacks subgroup layout");
-    auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
-    if (!sliceLayout)
-      return rewriter.notifyMatchFailure(
-          op, "the result layout must be a slice layout");
-    if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
-      return rewriter.notifyMatchFailure(
-          op, "expecting 1 dim in the effective result layout");
 
     auto loc = op.getLoc();
     auto stepResultVecTy = op.getResult().getType();
@@ -888,7 +881,12 @@ struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
            newVecTy.getNumElements() / laneDataBlockLength);
     SmallVector<Value> stepVals;
     // For each lane_data block, reconstruct its sub-range
-    // from the range of SG-level vector.step.
+    // from the range of SG-level vector.step. Example: vector.step
+    // {slice<layout<lane_layout=[2,4,2], lane_data=[1,2,1]>, dims=[0,2]>} :
+    // vector<16xindex>
+    // Each logical lane holds 4 elements as 2 blocks of 2 elements each.
+    // The blocks are round-robin distributed, so logical lane id 0
+    // holds values [0,1, 8,9].
     for (auto &laneDataBlockCoords : laneDataBlockCoordsVec) {
       auto laneDataBlockStartCoord = laneDataBlockCoords[0];
       stepVals.push_back(laneDataBlockStartCoord);
@@ -1145,14 +1143,12 @@ struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
               "broadcast source layout must be a slice of result layout");
       } else if (rankDiff == 0) {
         // Case 2: Same-rank broadcast.
-        if (!sourceLayout || !sourceLayout.isEqualTo(resultLayout))
-          return rewriter.notifyMatchFailure(
-              op, "for same-rank broadcast, source layout must be equal to "
-                  "result layout");
         auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
         SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
                                                broadcastUnitDimsSet.end());
-        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+        assert(sourceLayout.isEqualTo(
+                   sourceLayout.setUnitDimData(broadcastUnitDims)) &&
+               "The sg_data for unit dimensions should be set as 1");
         sourceLayout = sourceLayout.setUnitDimLayout(broadcastUnitDims);
       }
     } else {
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 3a24affe6ab3c..77795f6b9696e 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -766,11 +766,13 @@ gpu.func @vector_step_slice_unit() {
 gpu.module @xevm_module {
 // CHECK-LABEL: gpu.func @vector_step_slice_multi_dist
 // CHECK:         %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:     %[[C8:.*]] = arith.constant 8 : index
-// CHECK-DAG:     %[[C16:.*]] = arith.constant 16 : index
-// CHECK-DAG:     %[[C2:.*]] = arith.constant 2 : index
-// CHECK:         %[[VEC:.*]] = vector.from_elements %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xindex>
+// CHECK:         %[[MULI:.*]] = arith.muli %{{.*}}, %{{.*}} : index
+// CHECK:         %[[V0:.*]] = arith.remui %[[MULI]], %{{.*}} : index
+// CHECK:         %[[SUM1:.*]] = arith.addi %[[MULI]], %{{.*}} : index
+// CHECK:         %[[V2:.*]] = arith.remui %[[SUM1]], %{{.*}} : index
+// CHECK:         %[[V1:.*]] = arith.addi %[[V0]], %{{.*}} : index
+// CHECK:         %[[V3:.*]] = arith.addi %[[V2]], %{{.*}} : index
+// CHECK:         %[[VEC:.*]] = vector.from_elements %[[V0]], %[[V1]], %[[V2]], %[[V3]] : vector<4xindex>
 gpu.func @vector_step_slice_multi_dist() {
   %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>} : vector<16xindex>
   gpu.return



More information about the Mlir-commits mailing list