[Mlir-commits] [mlir] [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (PR #144417)

Chao Chen llvmlistbot at llvm.org
Mon Jul 21 10:54:03 PDT 2025


================
@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    // TODO: Currently only supports cases where the source and result ranks
+    // are the same.
+    auto srcType =
+        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+    if (!srcType || srcType.getRank() != resultType.getRank())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // Check if the output layout is distributable
+    SmallVector<int64_t> sgLayout;
+    if (auto sgLayoutAttr = layout.getSgLayout())
+      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    else
+      return failure();
+
+    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+      return failure();
+
+    // Check if the srcShape has unit dim in dimensions being broadcasted,
+    // and the other dimensions are the same as the destination type
+    // TODO: Generalize it
+    auto srcShape = srcType.getShape();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
----------------
chencha3 wrote:

It seems this check duplicates the check in the broadcast verifier, unless there are cases where the source vector, e.g., vector<32x1x1xf32>, can be distributed to a vector, e.g., <8x2x1>.

https://github.com/llvm/llvm-project/pull/144417


More information about the Mlir-commits mailing list