[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute vector.step & vector.shape_cast op from wg to sg (PR #155443)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Aug 28 08:15:09 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nishant Patel (nbpatel)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/155443.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+118-7)
- (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir (+53)
``````````diff
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0b7fe81facfce..059641af2219a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -507,8 +507,15 @@ struct WgToSgVectorBroadcastOp
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
newResultType, operand);
- xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
- layout.dropSgLayoutAndData());
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+ if (sliceAttr.isForSubgroup())
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
+ sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+ dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), newLayout);
+ }
newBroadcastOps.push_back(newBroadcast.getResult());
}
@@ -566,6 +573,10 @@ struct WgToSgElementwiseOp : public ConversionPattern {
if (auto layout = dyn_cast<xegpu::LayoutAttr>(attr.getValue())) {
if (auto newLayout = layout.dropSgLayoutAndData())
state.addAttribute(attr.getName(), newLayout);
+ } else if (auto sliceAttr =
+ dyn_cast<xegpu::SliceAttr>(attr.getValue())) {
+ if (sliceAttr.isForSubgroup())
+ state.addAttribute(attr.getName(), sliceAttr.dropSgLayoutAndData());
} else {
state.addAttribute(attr.getName(), attr.getValue());
}
@@ -756,8 +767,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
- if (auto newLayout = layout.dropSgLayoutAndData())
- xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+ if (sliceAttr.isForSubgroup())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+ sliceAttr.dropSgLayoutAndData());
+ } else if (auto layoutAttr =
+ dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+ if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+ xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+ }
SmallVector<Value> newConsts(count, cstOp);
rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -765,6 +783,90 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
}
};
+// Distributes a workgroup-level vector.step op to subgroup level.
+//
+// Each subgroup materializes a vector.step of its own (per-subgroup) shape and
+// adds to it the starting offset of its chunk within the workgroup-level
+// vector, so each chunk holds the same indices the original op produced for
+// that slice. One result value is produced per offset tuple returned by the
+// layout's getOffsets(), and the original op is replaced by all of them.
+struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
+  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    // Only layouts that still describe workgroup-level distribution apply.
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType type = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = type.getShape();
+    // `.first` is the per-subgroup shape as a plain SmallVector; there is no
+    // "absent" state to test for, so it is used directly.
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+
+    Value sgId =
+        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+    auto maybeOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+    if (failed(maybeOffsets))
+      return failure();
+
+    VectorType newTy = type.cloneWith(sgShape, type.getElementType());
+    // Shared 0..n-1 base; every chunk shifts it by its own starting offset.
+    Value base = vector::StepOp::create(rewriter, loc, newTy);
+    SmallVector<Value> newOps;
+    // Bind by reference: each element is itself a SmallVector<Value>.
+    for (const auto &offsets : *maybeOffsets) {
+      // vector.step yields a 1-D vector, so only the first offset is needed.
+      Value bcast =
+          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+      newOps.push_back(add);
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newOps});
+    return success();
+  }
+};
+
+// Distributes a workgroup-level vector.shape_cast op to subgroup level.
+//
+// The result type is shrunk to the per-subgroup shape derived from the result
+// layout, and one shape_cast is emitted per distributed source value. This is
+// only a valid rewrite when every distributed source chunk holds exactly as
+// many elements as the distributed result chunk; otherwise the pattern fails
+// to match instead of producing an invalid shape_cast.
+struct WgToSgVectorShapeCastOp
+    : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // vector.shape_cast always produces a vector; no dyn_cast needed.
+    VectorType resultType = op.getResultVectorType();
+
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getDistributeLayoutAttr(op.getResult());
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // Validate all distributed sources before creating any op, so a failed
+    // match leaves no new ops behind. A shape_cast must preserve the element
+    // count; a mismatch means the source was not distributed compatibly.
+    for (Value src : adaptor.getSource()) {
+      auto srcType = dyn_cast<VectorType>(src.getType());
+      if (!srcType ||
+          srcType.getNumElements() != newResultType.getNumElements())
+        return rewriter.notifyMatchFailure(
+            op, "distributed source and result element counts differ");
+    }
+
+    SmallVector<Value> newShapeCastOps;
+    for (Value src : adaptor.getSource()) {
+      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                      newResultType, src);
+      // Attach the layout with its subgroup-level fields dropped, when that
+      // leaves a non-empty layout.
+      // NOTE(review): if SliceAttr::isForSubgroup() is simply the negation of
+      // isForWorkgroup() (verify), this slice branch can never fire under the
+      // isForWorkgroup() guard above, so slice layouts never propagate here.
+      if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+        if (sliceAttr.isForSubgroup())
+          xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+                                         sliceAttr.dropSgLayoutAndData());
+      } else if (auto layoutAttr =
+                     dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+        if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+          xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
+                                         newLayout);
+      }
+      newShapeCastOps.push_back(newShapeCast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
+    return success();
+  }
+};
+
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
LogicalResult
@@ -826,8 +928,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
- WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp>(
- patterns.getContext());
+ WgToSgArithConstantOp, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp,
+ WgToSgVectorStepOp, WgToSgVectorShapeCastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -949,7 +1051,16 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
auto vecType = dyn_cast<VectorType>(op.getType());
if (!vecType)
return true;
- return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
+
+ auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
+ target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp>(
+ [=](Operation *op) -> bool {
+ // Check for either a SliceAttr or LayoutAttr on the result.
+ auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
+ return isLegal(layout);
});
target.addDynamicallyLegalOp<vector::BroadcastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 32157a7911f62..7601274ba4969 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -2,6 +2,7 @@
//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 4)>
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
+//CHECK: #map2 = affine_map<()[s0] -> (s0 floordiv 8)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -321,4 +322,56 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}
+
+ // CHECK-LABEL: vector_step_op_slice_attr
+ // A 1-D step whose layout slices dim 1 out of a [4, 8] subgroup layout:
+ // each subgroup owns a 32-element chunk starting at (sgId floordiv 8) * 32,
+ // taken modulo 128, added to a local 0..31 step.
+ gpu.func @vector_step_op_slice_attr() {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IDY:%.+]] = affine.apply #map2()[[[sgId]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_step_op_layout_attr
+ // A 1-D step with a plain layout (sg_layout = [16], sg_data = [8]): each
+ // subgroup owns an 8-element chunk starting at sgId * 8, taken modulo 128,
+ // added to a local 0..7 step.
+ gpu.func @vector_step_op_layout_attr() {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[c16:%.+]] = arith.constant 16 : index
+ //CHECK: [[c8:%.+]] = arith.constant 8 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[sgId]], [[c8]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<8xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<8xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<8xindex>
+ %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [8]>}: vector<128xindex>
+ gpu.return
+ }
+
+ // CHECK-LABEL: constant_with_slice_attr
+ // Splat constant whose layout slices dims [1, 2, 3] out of a 4-D subgroup
+ // layout, presumably leaving only dim 0 (sg_layout 4, sg_data 1); each
+ // subgroup is expected to get a one-element splat of the value 10.
+ gpu.func @constant_with_slice_attr() {
+ //CHECK: [[cst:%.+]] = arith.constant dense<10> : vector<1xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 1]>, dims = [1, 2, 3]>} dense<10> : vector<4xindex>
+ gpu.return
+ }
+
+ // CHECK-LABEL: vector_shape_cast
+ // The load's 2-D layout (sg_data = [32, 32]) gives each subgroup a 32x32
+ // chunk; the shape_cast result layout (sg_data = [8, 4, 8, 4]) also gives
+ // 1024-element chunks, so each subgroup reshapes its own chunk in place.
+ gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc[0, 0]
+ : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ -> vector<256x128xf32>
+ //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<8x4x8x4xf32>
+ %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [2, 4, 2, 2], sg_data = [8, 4, 8, 4]>} : vector<256x128xf32> to vector<16x16x16x8xf32>
+ gpu.return
+ }
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/155443
More information about the Mlir-commits
mailing list