[Mlir-commits] [mlir] [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (PR #144417)
Nishant Patel
llvmlistbot at llvm.org
Mon Jun 16 12:34:23 PDT 2025
https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/144417
None
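For reference, a minimal before/after sketch of the rewrite this pattern performs, adapted from the test cases added in this PR (the subgroup-level value names are illustrative):

  // Workgroup level: the broadcast result carries an xegpu layout whose
  // sg_layout/sg_data describe how the 24x8 result is split across subgroups.
  %broadcast = vector.broadcast %load
    {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
    : vector<24x1xf32> to vector<24x8xf32>

  // Subgroup level after the pass: one broadcast per subgroup-sized operand
  // chunk, with sg_layout/sg_data dropped and only the lane layout retained.
  %broadcast_sg = vector.broadcast %load_sg
    {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
    : vector<12x1xf32> to vector<12x8xf32>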
From f1509d2ebd1de2665a23b08f9f7d57742a9cf137 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:07 +0000
Subject: [PATCH 1/4] Add pattern for broadcast
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 61 ++++++++++++++++++-
.../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 14 +++++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 16 ++++-
3 files changed, 89 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..6cbe0d5c0f350 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -314,13 +315,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
+struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+
+ // Only handle broadcasts to vectors with XeGPU layout attribute
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return rewriter.notifyMatchFailure(
+ op, "Result does not have a valid layout attribute for subgroup distribution");
+
+ // Extract sgShape from layout
+ SmallVector<int64_t> sgShape;
+ if (auto sgDataAttr = layout.getSgData()) {
+ sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+ } else {
+ auto sgLayoutArr = layout.getSgLayout();
+ ArrayRef<int64_t> shape = resultType.getShape();
+ sgShape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i) {
+ assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+ sgShape.push_back(shape[i] / sgLayoutArr[i]);
+ }
+ }
+
+ VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+ SmallVector<Value> newBroadcasts;
+
+ // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
+ for (Value unused : adaptor.getOperands().front()) {
+ // All subgroups get the same broadcasted value
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+ newBroadcasts.push_back(newBroadcast.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ WgToSgVectorBroadcastOp>(
patterns.getContext());
}
} // namespace xegpu
@@ -369,6 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return true;
+ auto layout = xegpu::getLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..759de3a219bea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -102,4 +102,18 @@ gpu.module @test_round_robin_assignment {
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
gpu.return
}
+
+ // CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..94130236e3714 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -169,4 +169,18 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
: vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
gpu.return
}
-}
+
+// CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+}
\ No newline at end of file
From c5cd2743719786db4a42afbf37ec2da45faef3ed Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:25 +0000
Subject: [PATCH 2/4] Add pattern for broadcast
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 41 ++++++++++---------
1 file changed, 21 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6cbe0d5c0f350..cc43433b524af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -316,8 +316,8 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
};
/// This pattern transforms vector.broadcast ops to work at subgroup level.
-/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
-struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+struct WgToSgVectorBroadcastOp
+ : public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
LogicalResult
@@ -325,13 +325,12 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
ConversionPatternRewriter &rewriter) const override {
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
if (!resultType)
- return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+ return failure();
// Only handle broadcasts to vectors with XeGPU layout attribute
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
- return rewriter.notifyMatchFailure(
- op, "Result does not have a valid layout attribute for subgroup distribution");
+ return failure();
// Extract sgShape from layout
SmallVector<int64_t> sgShape;
@@ -347,15 +346,17 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
}
}
- VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
SmallVector<Value> newBroadcasts;
- // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
- for (Value unused : adaptor.getOperands().front()) {
- // All subgroups get the same broadcasted value
+ // The operand is always a scalar or lower-rank vector, so just broadcast
+ // for each subgroup
+ for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
- op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
- xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+ op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcasts.push_back(newBroadcast.getResult());
}
@@ -371,8 +372,7 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- WgToSgVectorBroadcastOp>(
- patterns.getContext());
+ WgToSgVectorBroadcastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -420,13 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
- auto resultType = dyn_cast<VectorType>(op.getResult().getType());
- if (!resultType)
- return true;
- auto layout = xegpu::getLayoutAttr(op.getResult());
- return isLegal(layout);
- });
+ target.addDynamicallyLegalOp<vector::BroadcastOp>(
+ [=](vector::BroadcastOp op) -> bool {
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return true;
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
From 803a565d17c2c13880d0316c39017dacd9cfb5da Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 15:47:42 +0000
Subject: [PATCH 3/4] Clean up
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ----
1 file changed, 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c3e851cc9840c..0065ee6cc424c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -341,19 +341,15 @@ struct WgToSgVectorBroadcastOp
if (!resultType)
return failure();
- // Only handle broadcasts to vectors with XeGPU layout attribute
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();
- // Extract sgShape from layout
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
SmallVector<Value> newBroadcasts;
- // The operand is always a scalar or lower-rank vector, so just broadcast
- // for each subgroup
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
From 2c97ee7a0e7b5f01608b5fb4d4e3d5bd20cff355 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 19:20:15 +0000
Subject: [PATCH 4/4] Add CHECKS
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 14 ++++----------
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 4 ++++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 5 +++--
3 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0065ee6cc424c..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -338,8 +338,6 @@ struct WgToSgVectorBroadcastOp
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- if (!resultType)
- return failure();
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
@@ -348,17 +346,17 @@ struct WgToSgVectorBroadcastOp
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
- SmallVector<Value> newBroadcasts;
+ SmallVector<Value> newBroadcastOps;
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
- newBroadcasts.push_back(newBroadcast.getResult());
+ newBroadcastOps.push_back(newBroadcast.getResult());
}
- rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+ rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
return success();
}
};
@@ -556,11 +554,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
- auto resultType = dyn_cast<VectorType>(op.getResult().getType());
- if (!resultType)
- return true;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- return isLegal(layout);
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 2ff12c5968b8c..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -111,6 +111,10 @@ gpu.module @test_round_robin_assignment {
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
+ // CHECK-COUNT-3: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ // CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7925095ab90f5..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,8 +170,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
-
-// CHECK-LABEL: test_broadcast
+ // CHECK-LABEL: test_broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @test_broadcast(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
@@ -179,6 +178,8 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>