[Mlir-commits] [mlir] [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (PR #144417)

Nishant Patel llvmlistbot at llvm.org
Thu Jul 17 11:55:16 PDT 2025


https://github.com/nbpatel updated https://github.com/llvm/llvm-project/pull/144417

>From f1509d2ebd1de2665a23b08f9f7d57742a9cf137 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:07 +0000
Subject: [PATCH 1/9] Add pattern for broadcast

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 61 ++++++++++++++++++-
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 14 +++++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 16 ++++-
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..6cbe0d5c0f350 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -314,13 +315,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
+struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+
+    // Only handle broadcasts to vectors with XeGPU layout attribute
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Result does not have a valid layout attribute for subgroup distribution");
+
+    // Extract sgShape from layout
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      auto sgLayoutArr = layout.getSgLayout();
+      ArrayRef<int64_t> shape = resultType.getShape();
+      sgShape.reserve(shape.size());
+      for (size_t i = 0; i < shape.size(); ++i) {
+        assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(shape[i] / sgLayoutArr[i]);
+      }
+    }
+
+    VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newBroadcasts;
+
+    // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
+    for (Value unused : adaptor.getOperands().front()) {
+      // All subgroups get the same broadcasted value
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+      newBroadcasts.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               WgToSgVectorBroadcastOp>(
       patterns.getContext());
 }
 } // namespace xegpu
@@ -369,6 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return true;
+    auto layout = xegpu::getLayoutAttr(op.getResult());
+    return isLegal(layout);
+  });
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..759de3a219bea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -102,4 +102,18 @@ gpu.module @test_round_robin_assignment {
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
+
+  // CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..94130236e3714 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -169,4 +169,18 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
-}
+
+// CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+}
\ No newline at end of file

>From c5cd2743719786db4a42afbf37ec2da45faef3ed Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:25 +0000
Subject: [PATCH 2/9] Add pattern for broadcast

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 41 ++++++++++---------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6cbe0d5c0f350..cc43433b524af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -316,8 +316,8 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 };
 
 /// This pattern transforms vector.broadcast ops to work at subgroup level.
-/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
-struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
   using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
 
   LogicalResult
@@ -325,13 +325,12 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
                   ConversionPatternRewriter &rewriter) const override {
     auto resultType = dyn_cast<VectorType>(op.getResult().getType());
     if (!resultType)
-      return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+      return failure();
 
     // Only handle broadcasts to vectors with XeGPU layout attribute
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
     if (!layout || !layout.getSgLayout())
-      return rewriter.notifyMatchFailure(
-          op, "Result does not have a valid layout attribute for subgroup distribution");
+      return failure();
 
     // Extract sgShape from layout
     SmallVector<int64_t> sgShape;
@@ -347,15 +346,17 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
       }
     }
 
-    VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
     SmallVector<Value> newBroadcasts;
 
-    // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
-    for (Value unused : adaptor.getOperands().front()) {
-      // All subgroups get the same broadcasted value
+    // The operand is always a scalar or lower-rank vector, so just broadcast
+    // for each subgroup
+    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
-          op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
-      xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+        op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                 layout.dropSgLayoutAndData());
       newBroadcasts.push_back(newBroadcast.getResult());
     }
 
@@ -371,8 +372,7 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               WgToSgVectorBroadcastOp>(
-      patterns.getContext());
+               WgToSgVectorBroadcastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -420,13 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
-    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
-    if (!resultType)
-      return true;
-    auto layout = xegpu::getLayoutAttr(op.getResult());
-    return isLegal(layout);
-  });
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+        if (!resultType)
+          return true;
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 

>From 803a565d17c2c13880d0316c39017dacd9cfb5da Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 15:47:42 +0000
Subject: [PATCH 3/9] Clean up

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c3e851cc9840c..0065ee6cc424c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -341,19 +341,15 @@ struct WgToSgVectorBroadcastOp
     if (!resultType)
       return failure();
 
-    // Only handle broadcasts to vectors with XeGPU layout attribute
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
     if (!layout || !layout.getSgLayout())
       return failure();
 
-    // Extract sgShape from layout
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
     SmallVector<Value> newBroadcasts;
 
-    // The operand is always a scalar or lower-rank vector, so just broadcast
-    // for each subgroup
     for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
           op.getLoc(), newResultType, adaptor.getOperands().front()[i]);

>From 2c97ee7a0e7b5f01608b5fb4d4e3d5bd20cff355 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 19:20:15 +0000
Subject: [PATCH 4/9] Add CHECKS

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp     | 14 ++++----------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir     |  4 ++++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir        |  5 +++--
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0065ee6cc424c..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -338,8 +338,6 @@ struct WgToSgVectorBroadcastOp
                   ConversionPatternRewriter &rewriter) const override {
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
-    if (!resultType)
-      return failure();
 
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
     if (!layout || !layout.getSgLayout())
@@ -348,17 +346,17 @@ struct WgToSgVectorBroadcastOp
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
-    SmallVector<Value> newBroadcasts;
 
+    SmallVector<Value> newBroadcastOps;
     for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
           op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
       xegpu::setLayoutAttr(newBroadcast->getResult(0),
                            layout.dropSgLayoutAndData());
-      newBroadcasts.push_back(newBroadcast.getResult());
+      newBroadcastOps.push_back(newBroadcast.getResult());
     }
 
-    rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
     return success();
   }
 };
@@ -556,11 +554,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
-        auto resultType = dyn_cast<VectorType>(op.getResult().getType());
-        if (!resultType)
-          return true;
-        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-        return isLegal(layout);
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 2ff12c5968b8c..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -111,6 +111,10 @@ gpu.module @test_round_robin_assignment {
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
       -> vector<24x1xf32>
+    // CHECK-COUNT-3: vector.broadcast {{.*}}
+    // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+    // CHECK-NOT: vector.broadcast
     %broadcast = vector.broadcast %load 
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
       : vector<24x1xf32> to vector<24x8xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7925095ab90f5..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,8 +170,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
-
-// CHECK-LABEL: test_broadcast
+  // CHECK-LABEL: test_broadcast
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
   gpu.func @test_broadcast(%src: memref<24x1xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
@@ -179,6 +178,8 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
       -> vector<24x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
     %broadcast = vector.broadcast %load 
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
       : vector<24x1xf32> to vector<24x8xf32>

>From 692ae9eef25cfc6c42337c8eb0aa3f814bbe409e Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 18 Jun 2025 00:02:02 +0000
Subject: [PATCH 5/9] add check

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp        | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 96c7032d6b812..8d63fcdf302c2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -343,14 +343,21 @@ struct WgToSgVectorBroadcastOp
     if (!layout || !layout.getSgLayout())
       return failure();
 
+    // TODO: Currently only supports cases where the source and result ranks
+    // are the same.
+    auto srcType =
+        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+    if (!srcType || srcType.getRank() != resultType.getRank())
+      return failure();
+
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
     SmallVector<Value> newBroadcastOps;
-    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
+    for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
-          op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+          op.getLoc(), newResultType, operand);
       xegpu::setLayoutAttr(newBroadcast->getResult(0),
                            layout.dropSgLayoutAndData());
       newBroadcastOps.push_back(newBroadcast.getResult());

>From 1d17537d9573de2c84b4377362b54a6373acdab1 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 7 Jul 2025 22:12:19 +0000
Subject: [PATCH 6/9] Add test case for dim0

---
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index f60358f188e72..8a81a286da23a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,9 +170,9 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
-  // CHECK-LABEL: broadcast
+  // CHECK-LABEL: broadcast_dim1
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
-  gpu.func @broadcast(%src: memref<24x1xf32>) {
+  gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
       -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
@@ -186,6 +186,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: broadcast_dim0
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
+  gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
+      -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+      -> vector<1x32xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
+      : vector<1x32xf32> to vector<12x32xf32>
+    gpu.return
+  }
+
   gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
     //CHECK: [[c0:%.+]] = arith.constant 0 : index
     //CHECK: [[c128:%.+]] = arith.constant 128 : index

>From 425d677422058763479809b4cf499b8f3322c92c Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 8 Jul 2025 19:16:44 +0000
Subject: [PATCH 7/9] add check

---
 .../Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp   | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index dbb97f230c873..730449ba929b4 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -357,6 +357,15 @@ struct WgToSgVectorBroadcastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
+    // Check if the srcShape has unit dim in dimensions being broadcasted,
+    // and the other dimensions are the same as the destination type
+    // TODO: Generalize it
+    auto srcShape = srcType.getShape();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
+        return failure();
+    }
+
     SmallVector<Value> newBroadcastOps;
     for (auto operand : adaptor.getOperands().front()) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(

>From 00ffa572e870a3120051a100e2d6a9ce77b75f74 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Tue, 15 Jul 2025 22:12:28 +0000
Subject: [PATCH 8/9] Temp commit to check isDiscardable

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 39 +++++++++++++++++++
 1 file changed, 39 insertions(+)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 730449ba929b4..60041225005f7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,6 +34,35 @@ using namespace mlir;
 
 namespace {
 
+bool isDistributable(ArrayRef<int64_t> sgLayout, ArrayRef<int64_t> sgData,
+                     ArrayRef<int64_t> wgShape) {
+  // Check rank consistency
+  if (sgLayout.size() != sgData.size() || sgLayout.size() != wgShape.size())
+    return false;
+
+  for (size_t i = 0; i < sgLayout.size(); ++i) {
+    int64_t subgroupCount = sgLayout[i];
+    int64_t subgroupData = sgData[i];
+    int64_t shape = wgShape[i];
+
+    // Each subgroup must have positive data size
+    if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0)
+      return false;
+
+    // Total data assigned to all subgroups in this dimension
+    int64_t totalSubgroupData = subgroupCount * subgroupData;
+
+    // Subgroups must not collectively exceed the shape
+    if (totalSubgroupData > shape)
+      return false;
+
+    // Each subgroup's data must evenly divide the shape
+    if (shape % subgroupData != 0)
+      return false;
+  }
+  return true;
+}
+
 static std::pair<SmallVector<int64_t>, int>
 getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   int count = 1;
@@ -357,6 +386,16 @@ struct WgToSgVectorBroadcastOp
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
 
+    // Check if the output layout is distributable
+    SmallVector<int64_t> sgLayout;
+    if (auto sgLayoutAttr = layout.getSgLayout())
+      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    else
+      return failure();
+
+    if (!isDistributable(sgLayout, sgShape, wgShape))
+      return failure();
+
     // Check if the srcShape has unit dim in dimensions being broadcasted,
     // and the other dimensions are the same as the destination type
     // TODO: Generalize it

>From 8467c29eb78412297a1dfd00bdbd7c8ea128a911 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Thu, 17 Jul 2025 18:54:14 +0000
Subject: [PATCH 9/9] Add check for output layout

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 31 +------------------
 1 file changed, 1 insertion(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 60041225005f7..05a517e511ade 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -34,35 +34,6 @@ using namespace mlir;
 
 namespace {
 
-bool isDistributable(ArrayRef<int64_t> sgLayout, ArrayRef<int64_t> sgData,
-                     ArrayRef<int64_t> wgShape) {
-  // Check rank consistency
-  if (sgLayout.size() != sgData.size() || sgLayout.size() != wgShape.size())
-    return false;
-
-  for (size_t i = 0; i < sgLayout.size(); ++i) {
-    int64_t subgroupCount = sgLayout[i];
-    int64_t subgroupData = sgData[i];
-    int64_t shape = wgShape[i];
-
-    // Each subgroup must have positive data size
-    if (subgroupData <= 0 || subgroupCount <= 0 || shape <= 0)
-      return false;
-
-    // Total data assigned to all subgroups in this dimension
-    int64_t totalSubgroupData = subgroupCount * subgroupData;
-
-    // Subgroups must not collectively exceed the shape
-    if (totalSubgroupData > shape)
-      return false;
-
-    // Each subgroup's data must evenly divide the shape
-    if (shape % subgroupData != 0)
-      return false;
-  }
-  return true;
-}
-
 static std::pair<SmallVector<int64_t>, int>
 getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
   int count = 1;
@@ -393,7 +364,7 @@ struct WgToSgVectorBroadcastOp
     else
       return failure();
 
-    if (!isDistributable(sgLayout, sgShape, wgShape))
+    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
       return failure();
 
     // Check if the srcShape has unit dim in dimensions being broadcasted,



More information about the Mlir-commits mailing list