[Mlir-commits] [mlir] [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (PR #144417)
Nishant Patel
llvmlistbot at llvm.org
Thu Jul 17 11:55:16 PDT 2025
https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/144417
>From f1509d2ebd1de2665a23b08f9f7d57742a9cf137 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:07 +0000
Subject: [PATCH 1/9] Add pattern for broadcast
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 61 ++++++++++++++++++-
.../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 14 +++++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 16 ++++-
3 files changed, 89 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..6cbe0d5c0f350 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
@@ -314,13 +315,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
+struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+
+ // Only handle broadcasts to vectors with XeGPU layout attribute
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return rewriter.notifyMatchFailure(
+ op, "Result does not have a valid layout attribute for subgroup distribution");
+
+ // Extract sgShape from layout
+ SmallVector<int64_t> sgShape;
+ if (auto sgDataAttr = layout.getSgData()) {
+ sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+ } else {
+ auto sgLayoutArr = layout.getSgLayout();
+ ArrayRef<int64_t> shape = resultType.getShape();
+ sgShape.reserve(shape.size());
+ for (size_t i = 0; i < shape.size(); ++i) {
+ assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+ sgShape.push_back(shape[i] / sgLayoutArr[i]);
+ }
+ }
+
+ VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+ SmallVector<Value> newBroadcasts;
+
+ // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
+ for (Value unused : adaptor.getOperands().front()) {
+ // All subgroups get the same broadcasted value
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+ newBroadcasts.push_back(newBroadcast.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ WgToSgVectorBroadcastOp>(
patterns.getContext());
}
} // namespace xegpu
@@ -369,6 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return true;
+ auto layout = xegpu::getLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..759de3a219bea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -102,4 +102,18 @@ gpu.module @test_round_robin_assignment {
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
gpu.return
}
+
+ // CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..94130236e3714 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -169,4 +169,18 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
: vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
gpu.return
}
-}
+
+// CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+}
\ No newline at end of file
>From c5cd2743719786db4a42afbf37ec2da45faef3ed Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:25 +0000
Subject: [PATCH 2/9] Add pattern for broadcast
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 41 ++++++++++---------
1 file changed, 21 insertions(+), 20 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6cbe0d5c0f350..cc43433b524af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -316,8 +316,8 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
};
/// This pattern transforms vector.broadcast ops to work at subgroup level.
-/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
-struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+struct WgToSgVectorBroadcastOp
+ : public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
LogicalResult
@@ -325,13 +325,12 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
ConversionPatternRewriter &rewriter) const override {
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
if (!resultType)
- return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+ return failure();
// Only handle broadcasts to vectors with XeGPU layout attribute
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
- return rewriter.notifyMatchFailure(
- op, "Result does not have a valid layout attribute for subgroup distribution");
+ return failure();
// Extract sgShape from layout
SmallVector<int64_t> sgShape;
@@ -347,15 +346,17 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
}
}
- VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
SmallVector<Value> newBroadcasts;
- // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
- for (Value unused : adaptor.getOperands().front()) {
- // All subgroups get the same broadcasted value
+ // The operand is always a scalar or lower-rank vector, so just broadcast
+ // for each subgroup
+ for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
- op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
- xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+ op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
newBroadcasts.push_back(newBroadcast.getResult());
}
@@ -371,8 +372,7 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- WgToSgVectorBroadcastOp>(
- patterns.getContext());
+ WgToSgVectorBroadcastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -420,13 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
- target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
- auto resultType = dyn_cast<VectorType>(op.getResult().getType());
- if (!resultType)
- return true;
- auto layout = xegpu::getLayoutAttr(op.getResult());
- return isLegal(layout);
- });
+ target.addDynamicallyLegalOp<vector::BroadcastOp>(
+ [=](vector::BroadcastOp op) -> bool {
+ auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+ if (!resultType)
+ return true;
+          xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ return isLegal(layout);
+ });
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
>From 803a565d17c2c13880d0316c39017dacd9cfb5da Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 15:47:42 +0000
Subject: [PATCH 3/9] Clean up
---
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ----
1 file changed, 4 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c3e851cc9840c..0065ee6cc424c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -341,19 +341,15 @@ struct WgToSgVectorBroadcastOp
if (!resultType)
return failure();
- // Only handle broadcasts to vectors with XeGPU layout attribute
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();
- // Extract sgShape from layout
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
SmallVector<Value> newBroadcasts;
- // The operand is always a scalar or lower-rank vector, so just broadcast
- // for each subgroup
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
>From 2c97ee7a0e7b5f01608b5fb4d4e3d5bd20cff355 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 19:20:15 +0000
Subject: [PATCH 4/9] Add CHECKS
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 14 ++++----------
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 4 ++++
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 5 +++--
3 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0065ee6cc424c..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -338,8 +338,6 @@ struct WgToSgVectorBroadcastOp
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();
- if (!resultType)
- return failure();
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
@@ -348,17 +346,17 @@ struct WgToSgVectorBroadcastOp
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
- SmallVector<Value> newBroadcasts;
+ SmallVector<Value> newBroadcastOps;
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
- newBroadcasts.push_back(newBroadcast.getResult());
+ newBroadcastOps.push_back(newBroadcast.getResult());
}
- rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+ rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
return success();
}
};
@@ -556,11 +554,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
- auto resultType = dyn_cast<VectorType>(op.getResult().getType());
- if (!resultType)
- return true;
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
- return isLegal(layout);
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 2ff12c5968b8c..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -111,6 +111,10 @@ gpu.module @test_round_robin_assignment {
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
+ // CHECK-COUNT-3: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ // CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7925095ab90f5..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,8 +170,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
-
-// CHECK-LABEL: test_broadcast
+ // CHECK-LABEL: test_broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @test_broadcast(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
@@ -179,6 +178,8 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
>From 692ae9eef25cfc6c42337c8eb0aa3f814bbe409e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 18 Jun 2025 00:02:02 +0000
Subject: [PATCH 5/9] add check
---
.../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 11 +++++++++--
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 96c7032d6b812..8d63fcdf302c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -343,14 +343,21 @@ struct WgToSgVectorBroadcastOp
if (!layout || !layout.getSgLayout())
return failure();
+ // TODO: Currently only supports cases where the source and result ranks
+ // are the same.
+ auto srcType =
+ dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+ if (!srcType || srcType.getRank() != resultType.getRank())
+ return failure();
+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
SmallVector<Value> newBroadcastOps;
- for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
+ for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
- op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+ op.getLoc(), newResultType, operand);
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
>From 1d17537d9573de2c84b4377362b54a6373acdab1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 7 Jul 2025 22:12:19 +0000
Subject: [PATCH 6/9] Add test case for dim0
---
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 20 ++++++++++++++++++--
1 file changed, 18 insertions(+), 2 deletions(-)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index f60358f188e72..8a81a286da23a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,9 +170,9 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
- // CHECK-LABEL: broadcast
+ // CHECK-LABEL: broadcast_dim1
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
- gpu.func @broadcast(%src: memref<24x1xf32>) {
+ gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
@@ -186,6 +186,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
+ // CHECK-LABEL: broadcast_dim0
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
+ gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
+ -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+ -> vector<1x32xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
+ : vector<1x32xf32> to vector<12x32xf32>
+ gpu.return
+ }
+
gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
>From 425d677422058763479809b4cf499b8f3322c92c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 8 Jul 2025 19:16:44 +0000
Subject: [PATCH 7/9] add check
---
.../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dbb97f230c873..730449ba929b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -357,6 +357,15 @@ struct WgToSgVectorBroadcastOp
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
+ // Check if the srcShape has unit dim in dimensions being broadcasted,
+ // and the other dimensions are the same as the destination type
+ // TODO: Generalize it
+ auto srcShape = srcType.getShape();
+ for (size_t i = 0; i < srcShape.size(); ++i) {
+ if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
+ return failure();
+ }
+
SmallVector<Value> newBroadcastOps;
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
>From 00ffa572e870a3120051a100e2d6a9ce77b75f74 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 15 Jul 2025 22:12:28 +0000
Subject: [PATCH 8/9] Temp commit to check isDiscardable
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 39 +++++++++++++++++++
1 file changed, 39 insertions(+)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 730449ba929b4..60041225005f7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,6 +34,35 @@ using namespace mlir;
namespace {
+bool isDistributable(ArrayRef<int64_t> sgLayout, ArrayRef<int64_t> sgData,
+ ArrayRef<int64_t> wgShape) {
+ // Check rank consistency
+ if (sgLayout.size() != sgData.size() || sgLayout.size() != wgShape.size())
+ return false;
+
+ for (size_t i = 0; i < sgLayout.size(); ++i) {
+ int64_t subgroupCount = sgLayout[i];
+ int64_t subgroupData = sgData[i];
+ int64_t shape = wgShape[i];
+
+ // Each subgroup must have positive data size
+ if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0)
+ return false;
+
+ // Total data assigned to all subgroups in this dimension
+ int64_t totalSubgroupData = subgroupCount * subgroupData;
+
+ // Subgroups must not collectively exceed the shape
+ if (totalSubgroupData > shape)
+ return false;
+
+ // Each subgroup's data must evenly divide the shape
+ if (shape % subgroupData != 0)
+ return false;
+ }
+ return true;
+}
+
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
int count = 1;
@@ -357,6 +386,16 @@ struct WgToSgVectorBroadcastOp
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
+ // Check if the output layout is distributable
+ SmallVector<int64_t> sgLayout;
+ if (auto sgLayoutAttr = layout.getSgLayout())
+ sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+ else
+ return failure();
+
+ if (!isDistributable(sgLayout, sgShape, wgShape))
+ return failure();
+
// Check if the srcShape has unit dim in dimensions being broadcasted,
// and the other dimensions are the same as the destination type
// TODO: Generalize it
>From 8467c29eb78412297a1dfd00bdbd7c8ea128a911 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 17 Jul 2025 18:54:14 +0000
Subject: [PATCH 9/9] Add check for output layout
---
.../Transforms/XeGPUWgToSgDistribute.cpp | 31 +------------------
1 file changed, 1 insertion(+), 30 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 60041225005f7..05a517e511ade 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,35 +34,6 @@ using namespace mlir;
namespace {
-bool isDistributable(ArrayRef<int64_t> sgLayout, ArrayRef<int64_t> sgData,
- ArrayRef<int64_t> wgShape) {
- // Check rank consistency
- if (sgLayout.size() != sgData.size() || sgLayout.size() != wgShape.size())
- return false;
-
- for (size_t i = 0; i < sgLayout.size(); ++i) {
- int64_t subgroupCount = sgLayout[i];
- int64_t subgroupData = sgData[i];
- int64_t shape = wgShape[i];
-
- // Each subgroup must have positive data size
- if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0)
- return false;
-
- // Total data assigned to all subgroups in this dimension
- int64_t totalSubgroupData = subgroupCount * subgroupData;
-
- // Subgroups must not collectively exceed the shape
- if (totalSubgroupData > shape)
- return false;
-
- // Each subgroup's data must evenly divide the shape
- if (shape % subgroupData != 0)
- return false;
- }
- return true;
-}
-
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
int count = 1;
@@ -393,7 +364,7 @@ struct WgToSgVectorBroadcastOp
else
return failure();
- if (!isDistributable(sgLayout, sgShape, wgShape))
+ if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
return failure();
// Check if the srcShape has unit dim in dimensions being broadcasted,
More information about the Mlir-commits
mailing list