[Mlir-commits] [mlir] 56b263b - [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (#144417)

llvmlistbot at llvm.org
Wed Jul 23 08:41:57 PDT 2025


Author: Nishant Patel
Date: 2025-07-23T08:41:53-07:00
New Revision: 56b263b1bdd1713dd4062bfd3b3a7fce4aad4b2c

URL: https://github.com/llvm/llvm-project/commit/56b263b1bdd1713dd4062bfd3b3a7fce4aad4b2c
DIFF: https://github.com/llvm/llvm-project/commit/56b263b1bdd1713dd4062bfd3b3a7fce4aad4b2c.diff

LOG: [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (#144417)

This PR adds a transformation pattern for the vector.broadcast op in the
xegpu-wg-to-sg-distribute pass.
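
For illustration, a minimal sketch of the rewrite (mirroring the
broadcast_dim1 test added below; %load_sg and %b_sg are illustrative names):
a workgroup-level broadcast whose result layout carries sg_layout = [2, 1]
and sg_data = [12, 8] becomes a per-subgroup broadcast on the sg_data shape,
with the sg fields dropped from the layout:

    // Workgroup level (input):
    %b = vector.broadcast %load
      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8],
                                       lane_layout = [2, 1], lane_data = [1, 1]>}
      : vector<24x1xf32> to vector<24x8xf32>

    // Subgroup level (output of xegpu-wg-to-sg-distribute):
    %b_sg = vector.broadcast %load_sg
      {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
      : vector<12x1xf32> to vector<12x8xf32>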

Added: 
    

Modified: 
    mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
    mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 80bb5e888bdc7..e1a3d21b7f609 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -331,6 +331,62 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    // TODO: Currently only supports cases where the source and result ranks
+    // are the same.
+    auto srcType =
+        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+    if (!srcType || srcType.getRank() != resultType.getRank())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // Check that the workgroup shape is evenly distributable according to
+    // the layout (sg_layout presence was already verified above).
+    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+      return failure();
+
+    // Check that the source shape has a unit dim in each dimension being
+    // broadcast and matches the subgroup result shape elsewhere.
+    // TODO: Generalize this.
+    auto srcShape = srcType.getShape();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
+        return failure();
+    }
+
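+    // Round-robin assignment may distribute one workgroup value to several
+    // subgroup values, so create one broadcast per distributed operand.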
+    SmallVector<Value> newBroadcastOps;
+    for (auto operand : adaptor.getOperands().front()) {
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, operand);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newBroadcastOps.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+    return success();
+  }
+};
+
 // This pattern transforms elementwise ops to work at subgroup level.
 struct WgToSgElementwiseOp : public ConversionPattern {
   WgToSgElementwiseOp(MLIRContext *ctx)
@@ -475,8 +531,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
-      patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+               WgToSgVectorBroadcastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -583,6 +639,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
       [=](Operation *op) -> std::optional<bool> {
         // Only handle elementwise mappable ops

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index c6124f90e0f48..8a880068aab33 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,22 @@ gpu.module @test_round_robin_assignment {
     gpu.return
   }
 
+  // CHECK-LABEL: broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    // CHECK-COUNT-3: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} : vector<2x1xf32> to vector<2x4xf32>
+    // CHECK-NOT: vector.broadcast
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c1 = arith.constant 1 : index
     %c10 = arith.constant 10 : index
@@ -197,5 +213,4 @@ gpu.module @test_round_robin_assignment {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     gpu.return
   }
-
 }

diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..8a81a286da23a 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,38 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: broadcast_dim1
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
+  // CHECK-LABEL: broadcast_dim0
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
+  gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
+      -> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
+      -> vector<1x32xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
+    %broadcast = vector.broadcast %load
+      {layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
+      : vector<1x32xf32> to vector<12x32xf32>
+    gpu.return
+  }
+
   gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
     //CHECK: [[c0:%.+]] = arith.constant 0 : index
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +327,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     gpu.return
   }
-
-
 }
+


        

