[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution patterns for vector step, shape_cast & broadcast from sg-to-wi (PR #185960)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 17 15:06:01 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
This PR adds distribution patterns for `vector.step`, `vector.shape_cast`, and `vector.broadcast` to the new subgroup-to-workitem (sg-to-wi) distribution pass.
---
Full diff: https://github.com/llvm/llvm-project/pull/185960.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp (+224-1)
- (modified) mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir (+160)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
index 5cd766ed2813e..15e93717f9846 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSgToWiDistributeExperimental.cpp
@@ -828,6 +828,220 @@ struct SgToWiStoreScatter : public OpConversionPattern<xegpu::StoreScatterOp> {
}
};
+/// Distribute a subgroup-level vector::StepOp whose result has a slice layout.
+/// The sliced layout must have exactly 1 effective lane dimension.
+/// The op is fully resolved: the result is rebuilt with vector.from_elements
+/// from the lane_data block coordinates computed for the current lane id.
+struct SgToWiVectorStep : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(op->getResult(0));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the step op lacks subgroup layout");
+    auto sliceLayout = dyn_cast<xegpu::SliceAttr>(resultLayout);
+    if (!sliceLayout)
+      return rewriter.notifyMatchFailure(
+          op, "the result layout must be a slice layout");
+    if (sliceLayout.getEffectiveLaneLayoutAsInt().size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "expecting 1 dim in the effective result layout");
+
+    auto loc = op.getLoc();
+    auto stepResultVecTy = op.getResult().getType();
+    auto wiShapeOrFailure =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, stepResultVecTy);
+    if (failed(wiShapeOrFailure))
+      return rewriter.notifyMatchFailure(
+          op, "unable to compute workitem vector type from the layout");
+    VectorType newVecTy = wiShapeOrFailure.value();
+
+    Value laneId = gpu::LaneIdOp::create(rewriter, loc, rewriter.getIndexType(),
+                                         /*upperBound=*/mlir::IntegerAttr());
+    auto laneDataBlockCoords = resultLayout.computeDistributedCoords(
+        rewriter, loc, laneId, stepResultVecTy.getShape());
+    if (failed(laneDataBlockCoords))
+      return rewriter.notifyMatchFailure(
+          op, "failed to compute lane data block coordinates");
+
+    auto laneDataBlockCoordsVec = laneDataBlockCoords.value();
+    int64_t laneDataBlockLength = resultLayout.getEffectiveLaneDataAsInt()[0];
+    assert(static_cast<int64_t>(laneDataBlockCoordsVec.size()) ==
+               newVecTy.getNumElements() / laneDataBlockLength &&
+           "Expecting one coordinate per lane_data block");
+    SmallVector<Value> stepVals;
+    // Expand each lane_data block into start, start + 1, ..., start + len - 1.
+    for (auto &blockCoords : laneDataBlockCoordsVec) {
+      Value blockStart = blockCoords[0];
+      stepVals.push_back(blockStart);
+      for (int64_t i = 1; i < laneDataBlockLength; ++i) {
+        auto offset = arith::ConstantIndexOp::create(rewriter, loc, i);
+        stepVals.push_back(
+            arith::AddIOp::create(rewriter, loc, blockStart, offset));
+      }
+    }
+    assert(static_cast<int64_t>(stepVals.size()) == newVecTy.getNumElements() &&
+           "Expecting the number of step values to match the number of "
+           "elements in the vector");
+    auto stepOpVal =
+        vector::FromElementsOp::create(rewriter, loc, newVecTy, stepVals);
+    rewriter.replaceOp(op, stepOpVal);
+    return success();
+  }
+};
+
+/// Distributes a subgroup-level vector.shape_cast to workitem level by
+/// re-creating it on the adapted source with the lane-distributed result type.
+struct SgToWiVectorShapeCast : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The result must carry a subgroup-level layout for this pattern to apply.
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(op->getOpResult(0));
+    if (!(layout && layout.isForSubgroup()))
+      return rewriter.notifyMatchFailure(
+          op, "the result vector of the shape_cast op lacks subgroup layout");
+
+    auto wiResultTy = xegpu::getDistVecTypeBasedOnLaneLayout(
+        layout, op.getResultVectorType());
+    if (failed(wiResultTy))
+      return rewriter.notifyMatchFailure(
+          op, "failed to get distributed vector type for result");
+    auto wiShapeCast = vector::ShapeCastOp::create(
+        rewriter, op.getLoc(), wiResultTy.value(), adaptor.getSource());
+    rewriter.replaceOp(op, wiShapeCast);
+    return success();
+  }
+};
+
+/// This pattern distributes a subgroup-level `vector.broadcast` op to
+/// workitem-level. The pattern supports three cases:
+///
+/// 1) Broadcast a low-rank vector to high-rank vector: The low-rank input
+/// vector is expected to have a slice layout of the result (a mismatch only
+/// emits a warning). If the adapted source already has the distributed result
+/// type, this lowers to a no-op; otherwise it stays a distributed broadcast.
+///
+/// 2) Broadcast a same-rank vector with identical layouts for source and
+/// target: The source vector must have unit dimensions, and lane_data must
+/// be unit size for those unit dims. This is expected to lower to a no-op.
+///
+/// 3) Broadcast a scalar with no layout: This always lowers to a broadcast
+/// from scalar to distributed result type.
+///
+/// Example 1 (low-rank to high-rank broadcast):
+/// ```
+/// %0 = "some_op"() {layout_result_0 =
+/// #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
+/// dims = [0]>} : () -> vector<16xf16>
+/// %1 = vector.broadcast %0 {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+/// : vector<16xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+/// %0 = "some_op"() : () -> vector<1xf16>
+/// %1 = vector.broadcast %0 : vector<1xf16> to vector<16x1xf16>
+/// ```
+///
+/// Example 2 (same-rank broadcast, no-op):
+/// ```
+/// %0 = "some_op"() {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+/// : () -> vector<16x1xf16>
+/// %1 = vector.broadcast %0 {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+/// : vector<16x1xf16> to vector<16x16xf16>
+/// ```
+/// is distributed to (no-op, source already matches distributed result type):
+/// ```
+/// %0 = "some_op"() : () -> vector<16x1xf16>
+/// // broadcast is eliminated, %0 is used directly
+/// ```
+///
+/// Example 3 (scalar to vector broadcast):
+/// ```
+/// %0 = "some_op"() : () -> f16
+/// %1 = vector.broadcast %0 {layout_result_0 =
+/// #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+/// : f16 to vector<16x16xf16>
+/// ```
+/// is distributed to:
+/// ```
+/// %0 = "some_op"() : () -> f16
+/// %1 = vector.broadcast %0 : f16 to vector<16x1xf16>
+/// ```
+struct SgToWiBroadcast : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr resultLayout =
+        xegpu::getTemporaryLayout(cast<OpResult>(op.getResult()));
+    if (!resultLayout || !resultLayout.isForSubgroup())
+      return rewriter.notifyMatchFailure(
+          op, "result does not have subgroup distribute layout");
+
+    VectorType destType = op.getResultVectorType();
+    VectorType sourceType = dyn_cast<VectorType>(op.getSourceType());
+
+    xegpu::DistributeLayoutAttr sourceLayout =
+        xegpu::getTemporaryLayout(op->getOpOperand(0));
+
+    if (sourceType) {
+      int64_t rankDiff = destType.getRank() - sourceType.getRank();
+      if (rankDiff > 0) {
+        // Case 1: Low-rank to high-rank broadcast. A layout mismatch is only
+        // reported as a warning; distribution still proceeds.
+        if (!sourceLayout || !sourceLayout.isSliceOf(resultLayout))
+          op.emitWarning(
+              "broadcast source layout must be a slice of result layout");
+      } else if (rankDiff == 0) {
+        // Case 2: Same-rank broadcast.
+        if (!sourceLayout || !sourceLayout.isEqualTo(resultLayout))
+          return rewriter.notifyMatchFailure(
+              op, "for same-rank broadcast, source layout must be equal to "
+                  "result layout");
+        auto broadcastUnitDimsSet = op.computeBroadcastedUnitDims();
+        SmallVector<int64_t> broadcastUnitDims(broadcastUnitDimsSet.begin(),
+                                               broadcastUnitDimsSet.end());
+        resultLayout = resultLayout.setUnitDimData(broadcastUnitDims);
+      }
+    } else {
+      // Case 3: Scalar to vector broadcast.
+      if (sourceLayout)
+        return rewriter.notifyMatchFailure(
+            op, "broadcast from scalar must not have a layout attribute");
+    }
+
+    auto destDistType =
+        xegpu::getDistVecTypeBasedOnLaneLayout(resultLayout, destType);
+    if (failed(destDistType))
+      return rewriter.notifyMatchFailure(
+          op, "failed to distribute the result vector type");
+
+    Value source = adaptor.getSource();
+    // If the adapted source already matches the dest dist type, it's a no-op.
+    if (source.getType() == destDistType.value()) {
+      rewriter.replaceOp(op, source);
+      return success();
+    }
+
+    auto newOp = vector::BroadcastOp::create(rewriter, op.getLoc(),
+                                             destDistType.value(), source);
+    rewriter.replaceOp(op, newOp);
+    return success();
+  }
+};
+
struct XeGPUSgToWiDistributeExperimentalPass
: public xegpu::impl::XeGPUSgToWiDistributeExperimentalBase<
XeGPUSgToWiDistributeExperimentalPass> {
@@ -1029,10 +1243,19 @@ void xegpu::populateXeGPUSgToWiDistributeTypeConversionAndLegality(
[=](vector::MultiDimReductionOp op) -> bool {
return !isValidSubgroupMultiReductionOp(op);
});
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ [=](Operation *op) -> bool {
+ return !xegpu::getTemporaryLayout(dyn_cast<OpResult>(op->getResult(0)));
+ });
+ target.addDynamicallyLegalOp<vector::BroadcastOp>(
+ [=](vector::BroadcastOp op) -> bool {
+ return !xegpu::getTemporaryLayout(op->getResult(0));
+ });
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
patterns.add<SgToWiCreateNdDesc, SgToWiLoadNd, SgToWiStoreNd, SgToWiDpas,
SgToWiElementWise, SgToWiArithConstant, SgToWiPrefetchNd,
SgToWiLoadGather, SgToWiStoreScatter, SgToWiVectorReduction,
- SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix>(
+ SgToWiMultiDimReduction, SgToWiLoadMatrix, SgToWiStoreMatrix,
+ SgToWiVectorStep, SgToWiVectorShapeCast, SgToWiBroadcast>(
typeConverter, patterns.getContext());
}
diff --git a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
index 13ca46c3dbb50..2cfefcd94a32e 100644
--- a/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
+++ b/mlir/test/Dialect/XeGPU/sg-to-wi-experimental-unit.mlir
@@ -525,3 +525,163 @@ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layou
gpu.return
}
}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[REM:.*]] = arith.remui %[[LANE_ID]], %[[C16]] : index
+// CHECK: %[[REM2:.*]] = arith.remui %[[REM]], %[[C16]]{{.*}} : index
+// CHECK: %[[VEC:.*]] = vector.from_elements %[[REM2]] : vector<1xindex>
+gpu.func @vector_step_slice() {
+ %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 2]>} : vector<16xindex>
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice_unit
+// CHECK: %[[VEC:.*]] = vector.from_elements %{{.*}} : vector<1xindex>
+gpu.func @vector_step_slice_unit() {
+ %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 1, 1, 16], lane_data = [1, 1, 1, 1]>, dims = [0, 1, 3]>} : vector<1xindex>
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_step_slice_multi_dist
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[VEC:.*]] = vector.from_elements %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : vector<4xindex>
+gpu.func @vector_step_slice_multi_dist() {
+ %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [2, 4, 2], lane_data = [1, 2, 1]>, dims = [0, 2]>} : vector<16xindex>
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing
+// CHECK: %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<1x1xf32>
+gpu.func @vector_shapecast_rank_increasing() {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>}
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_reducing
+// CHECK: %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1x1xf32> to vector<1xf32>
+gpu.func @vector_shapecast_rank_reducing() {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+ : () -> (vector<1x16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>
+ }
+ : vector<1x16xf32> to vector<16xf32>
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_shapecast_rank_increasing_without_slicing_layout
+// CHECK: %[[SC:.*]] = vector.shape_cast %{{.*}} : vector<1xf32> to vector<1x1xf32>
+gpu.func @vector_shapecast_rank_increasing_without_slicing_layout() {
+ %cst = "some_op"()
+ {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+ : () -> (vector<16xf32>)
+ %cast = vector.shape_cast %cst
+ {
+ layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+ }
+ : vector<16xf32> to vector<1x16xf32>
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_1d_to_2d
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16xf16> to vector<1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<1xf16> to vector<16x1xf16>
+gpu.func @vector_broadcast_1d_to_2d(%laneid: index) {
+ %0 = "some_op"() {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>} : () -> vector<16xf16>
+ %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [0]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16xf16> to vector<16x16xf16>
+ "some_use"(%1) : (vector<16x16xf16>) -> ()
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_3d
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC]] : vector<16x16xf16> to vector<16x1xf16>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[CAST]] : vector<16x1xf16> to vector<1x16x1xf16>
+gpu.func @vector_broadcast_2d_to_3d(%laneid: index) {
+ %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x16xf16>
+ %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 1, 16], lane_data = [1, 1, 1]>} : vector<16x16xf16> to vector<1x16x16xf16>
+ "some_use"(%1) : (vector<1x16x16xf16>) -> ()
+ gpu.return
+}
+}
+
+// -----
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_2d_to_2d_noop
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK-NOT: vector.broadcast
+gpu.func @vector_broadcast_2d_to_2d_noop(%laneid: index) {
+ %0 = "some_op"() {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : () -> vector<16x1xf16>
+ %1 = vector.broadcast %0 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<16x1xf16> to vector<16x16xf16>
+ "some_use"(%1) : (vector<16x16xf16>) -> ()
+ gpu.return
+}
+}
+
+// -----
+// Scalar to vector broadcast (with layout)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<16x1xf16>
+gpu.func @vector_broadcast_scalar_to_vector(%laneid: index) {
+ %0 = "some_op"() : () -> f16
+ %1 = vector.broadcast %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : f16 to vector<16x16xf16>
+ "some_use"(%1) : (vector<16x16xf16>) -> ()
+ gpu.return
+}
+}
+
+// -----
+// Scalar to vector broadcast (no layout - uniform, should remain unchanged)
+gpu.module @xevm_module {
+// CHECK-LABEL: gpu.func @vector_broadcast_scalar_to_vector_uniform
+// CHECK: %[[SRC:.*]] = "some_op"()
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<16x16xf16>
+// CHECK: "some_use"(%[[BCAST]])
+gpu.func @vector_broadcast_scalar_to_vector_uniform(%laneid: index) {
+ %0 = "some_op"() : () -> f16
+ %1 = vector.broadcast %0 : f16 to vector<16x16xf16>
+ "some_use"(%1) : (vector<16x16xf16>) -> ()
+ gpu.return
+}
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/185960
More information about the Mlir-commits
mailing list