[Mlir-commits] [mlir] [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (PR #144417)

Nishant Patel llvmlistbot at llvm.org
Mon Jun 16 12:34:23 PDT 2025


https://github.com/nbpatel created https://github.com/llvm/llvm-project/pull/144417

None

>From f1509d2ebd1de2665a23b08f9f7d57742a9cf137 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:07 +0000
Subject: [PATCH 1/4] Add pattern for broadcast

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 61 ++++++++++++++++++-
 .../test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir | 14 +++++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir   | 16 ++++-
 3 files changed, 89 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 3bf76af674ba0..6cbe0d5c0f350 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -314,13 +315,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
+struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+
+    // Only handle broadcasts to vectors with XeGPU layout attribute
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return rewriter.notifyMatchFailure(
+          op, "Result does not have a valid layout attribute for subgroup distribution");
+
+    // Extract sgShape from layout
+    SmallVector<int64_t> sgShape;
+    if (auto sgDataAttr = layout.getSgData()) {
+      sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
+    } else {
+      auto sgLayoutArr = layout.getSgLayout();
+      ArrayRef<int64_t> shape = resultType.getShape();
+      sgShape.reserve(shape.size());
+      for (size_t i = 0; i < shape.size(); ++i) {
+        assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
+        sgShape.push_back(shape[i] / sgLayoutArr[i]);
+      }
+    }
+
+    VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+    SmallVector<Value> newBroadcasts;
+
+    // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
+    for (Value unused : adaptor.getOperands().front()) {
+      // All subgroups get the same broadcasted value
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+      newBroadcasts.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+    return success();
+  }
+};
+
 } // namespace
 
 namespace mlir {
 namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
-               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(
+               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+               WgToSgVectorBroadcastOp>(
       patterns.getContext());
 }
 } // namespace xegpu
@@ -369,6 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
+    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+    if (!resultType)
+      return true;
+    auto layout = xegpu::getLayoutAttr(op.getResult());
+    return isLegal(layout);
+  });
+
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 
   xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index bee026eb2084d..759de3a219bea 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -102,4 +102,18 @@ gpu.module @test_round_robin_assignment {
       : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
     gpu.return
   }
+
+  // CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7e89ada934071..94130236e3714 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -169,4 +169,18 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
-}
+
+// CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+}

>From c5cd2743719786db4a42afbf37ec2da45faef3ed Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Wed, 11 Jun 2025 21:15:25 +0000
Subject: [PATCH 2/4] Add pattern for broadcast

---
 .../Transforms/XeGPUWgToSgDistribute.cpp      | 41 ++++++++++---------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 6cbe0d5c0f350..cc43433b524af 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -316,8 +316,8 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
 };
 
 /// This pattern transforms vector.broadcast ops to work at subgroup level.
-/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
-struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
   using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
 
   LogicalResult
@@ -325,13 +325,12 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
                   ConversionPatternRewriter &rewriter) const override {
     auto resultType = dyn_cast<VectorType>(op.getResult().getType());
     if (!resultType)
-      return rewriter.notifyMatchFailure(op, "Result is not a vector type");
+      return failure();
 
     // Only handle broadcasts to vectors with XeGPU layout attribute
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
     if (!layout || !layout.getSgLayout())
-      return rewriter.notifyMatchFailure(
-          op, "Result does not have a valid layout attribute for subgroup distribution");
+      return failure();
 
     // Extract sgShape from layout
     SmallVector<int64_t> sgShape;
@@ -347,15 +346,17 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
       }
     }
 
-    VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
     SmallVector<Value> newBroadcasts;
 
-    // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
-    for (Value unused : adaptor.getOperands().front()) {
-      // All subgroups get the same broadcasted value
+    // The operand is always a scalar or lower-rank vector, so just broadcast
+    // for each subgroup
+    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
-          op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
-      xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
+        op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                 layout.dropSgLayoutAndData());
       newBroadcasts.push_back(newBroadcast.getResult());
     }
 
@@ -371,8 +372,7 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               WgToSgVectorBroadcastOp>(
-      patterns.getContext());
+               WgToSgVectorBroadcastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -420,13 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
-  target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
-    auto resultType = dyn_cast<VectorType>(op.getResult().getType());
-    if (!resultType)
-      return true;
-    auto layout = xegpu::getLayoutAttr(op.getResult());
-    return isLegal(layout);
-  });
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        auto resultType = dyn_cast<VectorType>(op.getResult().getType());
+        if (!resultType) // Non-vector results are always legal.
+          return true;
+        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+        return isLegal(layout);
+      });
 
   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
 

>From 803a565d17c2c13880d0316c39017dacd9cfb5da Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 15:47:42 +0000
Subject: [PATCH 3/4] Clean up

---
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index c3e851cc9840c..0065ee6cc424c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -341,19 +341,15 @@ struct WgToSgVectorBroadcastOp
     if (!resultType)
       return failure();
 
-    // Only handle broadcasts to vectors with XeGPU layout attribute
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
     if (!layout || !layout.getSgLayout())
       return failure();
 
-    // Extract sgShape from layout
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
     SmallVector<Value> newBroadcasts;
 
-    // The operand is always a scalar or lower-rank vector, so just broadcast
-    // for each subgroup
     for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
           op.getLoc(), newResultType, adaptor.getOperands().front()[i]);

>From 2c97ee7a0e7b5f01608b5fb4d4e3d5bd20cff355 Mon Sep 17 00:00:00 2001
From: nbpatel <nishant.b.patel at intel.com>
Date: Mon, 16 Jun 2025 19:20:15 +0000
Subject: [PATCH 4/4] Add CHECKS

---
 .../XeGPU/Transforms/XeGPUWgToSgDistribute.cpp     | 14 ++++----------
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir     |  4 ++++
 mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir        |  5 +++--
 3 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 0065ee6cc424c..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -338,8 +338,6 @@ struct WgToSgVectorBroadcastOp
                   ConversionPatternRewriter &rewriter) const override {
     VectorType resultType = op.getResult().getType();
     ArrayRef<int64_t> wgShape = resultType.getShape();
-    if (!resultType)
-      return failure();
 
     xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
     if (!layout || !layout.getSgLayout())
@@ -348,17 +346,17 @@ struct WgToSgVectorBroadcastOp
     SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
     VectorType newResultType =
         VectorType::get(sgShape, resultType.getElementType());
-    SmallVector<Value> newBroadcasts;
 
+    SmallVector<Value> newBroadcastOps;
     for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
       auto newBroadcast = rewriter.create<vector::BroadcastOp>(
           op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
       xegpu::setLayoutAttr(newBroadcast->getResult(0),
                            layout.dropSgLayoutAndData());
-      newBroadcasts.push_back(newBroadcast.getResult());
+      newBroadcastOps.push_back(newBroadcast.getResult());
     }
 
-    rewriter.replaceOpWithMultiple(op, {newBroadcasts});
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
     return success();
   }
 };
@@ -556,11 +554,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
 
   target.addDynamicallyLegalOp<vector::BroadcastOp>(
       [=](vector::BroadcastOp op) -> bool {
-        auto resultType = dyn_cast<VectorType>(op.getResult().getType());
-        if (!resultType)
-          return true;
-        xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
-        return isLegal(layout);
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
       });
 
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 2ff12c5968b8c..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -111,6 +111,10 @@ gpu.module @test_round_robin_assignment {
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
       -> vector<24x1xf32>
+    // One broadcast per distributed fragment of %load (2x1 -> 2x4 each).
+    // NOTE(review): confirm 3 tiles is right for sg_data = [2, 4] on 24x8.
+    // CHECK-COUNT-3: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} : vector<2x1xf32> to vector<2x4xf32>
+    // CHECK-NOT: vector.broadcast
     %broadcast = vector.broadcast %load 
       {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
       : vector<24x1xf32> to vector<24x8xf32>
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 7925095ab90f5..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,8 +170,7 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
-
-// CHECK-LABEL: test_broadcast
+  // CHECK-LABEL: test_broadcast
   // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
   gpu.func @test_broadcast(%src: memref<24x1xf32>) {
     %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
@@ -179,6 +178,8 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     %load =  xegpu.load_nd %tdesc
       : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
       -> vector<24x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
     %broadcast = vector.broadcast %load 
       {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
       : vector<24x1xf32> to vector<24x8xf32>



More information about the Mlir-commits mailing list