[Mlir-commits] [mlir] 56b263b - [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (#144417)
llvmlistbot at llvm.org
Wed Jul 23 08:41:57 PDT 2025
Author: Nishant Patel
Date: 2025-07-23T08:41:53-07:00
New Revision: 56b263b1bdd1713dd4062bfd3b3a7fce4aad4b2c
URL: https://github.com/llvm/llvm-project/commit/56b263b1bdd1713dd4062bfd3b3a7fce4aad4b2c
DIFF: https://github.com/llvm/llvm-project/commit/56b263b1bdd1713dd4062bfd3b3a7fce4aad4b2c.diff
LOG: [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (#144417)
This PR adds a transformation pattern for the vector.broadcast op to the
xegpu-wg-to-sg-distribute pass. The pattern distributes a workgroup-level
broadcast into subgroup-sized broadcasts according to the result's
sg_layout/sg_data, keeping only the lane-level layout on the new ops.
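As a minimal sketch of the rewrite (shapes borrowed from the broadcast_dim1
test added below; the *_sg value names are illustrative only):

    // Workgroup level, before the pass: the result carries a full layout
    // with sg_layout = [2, 1] and sg_data = [12, 8].
    %broadcast = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
      : vector<24x1xf32> to vector<24x8xf32>

    // Subgroup level, after the pass: each subgroup broadcasts its own
    // sg_data-sized tile, and sg_layout/sg_data are dropped from the layout.
    %broadcast_sg = vector.broadcast %load_sg
      {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
      : vector<12x1xf32> to vector<12x8xf32>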
Added:
Modified:
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 80bb5e888bdc7..e1a3d21b7f609 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ // TODO: Currently only supports cases where the source and result ranks
+ // are the same.
+ auto srcType =
+ dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+ if (!srcType || srcType.getRank() != resultType.getRank())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ // Check if the output layout is distributable
+ SmallVector<int64_t> sgLayout;
+ if (auto sgLayoutAttr = layout.getSgLayout())
+ sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+ else
+ return failure();
+
+ if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+ return failure();
+
+ // Check if the srcShape has unit dim in dimensions being broadcasted,
+ // and the other dimensions are the same as the destination type
+ // TODO: Generalize it
+ auto srcShape = srcType.getShape();
+ for (size_t i = 0; i < srcShape.size(); ++i) {
+ if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
+ return failure();
+ }
+
+ SmallVector<Value> newBroadcastOps;
+ for (auto operand : adaptor.getOperands().front()) {
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ op.getLoc(), newResultType, operand);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
+ newBroadcastOps.push_back(newBroadcast.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+ return success();
+ }
+};
+
// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
WgToSgElementwiseOp(MLIRContext *ctx)
@@ -475,8 +534,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
- patterns.getContext());
+ UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+ WgToSgVectorBroadcastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -583,6 +642,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<vector::BroadcastOp>(
+ [=](vector::BroadcastOp op) -> bool {
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ });
+
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Only handle elementwise mappable ops
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index c6124f90e0f48..8a880068aab33 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
gpu.return
}
+ // CHECK-LABEL: broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ // CHECK-COUNT-3: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ // CHECK-NOT: vector.broadcast
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+
gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
gpu.return
}
-
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..8a81a286da23a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,38 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
+ // CHECK-LABEL: broadcast_dim1
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+
+ // CHECK-LABEL: broadcast_dim0
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
+ -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ -> vector<1x32xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
+ : vector<1x32xf32> to vector<12x32xf32>
+ gpu.return
+ }
+
gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +327,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
gpu.return
}
-
-
}
+
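The new pattern can be exercised directly with mlir-opt. Assuming the pass is
registered under the flag spelled the same way the PR names it (the RUN lines
in the updated test files are authoritative), an invocation looks like:

    mlir-opt -xegpu-wg-to-sg-distribute mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir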