[Mlir-commits] [mlir] [MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (PR #144417)
Chao Chen
llvmlistbot at llvm.org
Mon Jul 21 10:54:03 PDT 2025
================
@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ // TODO: Currently only supports cases where the source and result ranks
+ // are the same.
+ auto srcType =
+ dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+ if (!srcType || srcType.getRank() != resultType.getRank())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ // Check if the output layout is distributable
+ SmallVector<int64_t> sgLayout;
+ if (auto sgLayoutAttr = layout.getSgLayout())
+ sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+ else
+ return failure();
+
+ if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+ return failure();
+
+ // Check if the srcShape has unit dim in dimensions being broadcasted,
+ // and the other dimensions are the same as the destination type
+ // TODO: Generalize it
+ auto srcShape = srcType.getShape();
+ for (size_t i = 0; i < srcShape.size(); ++i) {
----------------
chencha3 wrote:
It seems this check duplicates the one in the vector.broadcast verifier, unless there are cases where a source vector such as vector<32x1x1xf32> can be distributed to a vector such as <8x2x1>.
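The reviewer's argument can be checked with a small standalone model. This is a sketch only. `distributeShape` and `isSameRankBroadcastCompatible` are hypothetical helpers, not XeGPU or MLIR APIs. The model assumes that source and result share the same `sg_layout` and that each workgroup dimension is simply divided by the matching layout entry, ignoring `sg_data` and round-robin distribution. Under those assumptions a unit dimension can only stay 1, because 1 is not evenly divisible by any layout factor greater than 1. A per-subgroup source like <8x2x1> therefore cannot come out of vector<32x1x1xf32>, and a compatible workgroup-level broadcast remains compatible after distribution.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical, simplified distribution model: each workgroup dimension is
// split evenly across the subgroups laid out along that dimension. sg_data
// and round-robin distribution are deliberately ignored. Returns
// std::nullopt when some dimension is not evenly distributable.
std::optional<Shape> distributeShape(const Shape &wgShape,
                                     const Shape &sgLayout) {
  assert(wgShape.size() == sgLayout.size() && "rank mismatch");
  Shape sgShape;
  for (std::size_t i = 0; i < wgShape.size(); ++i) {
    if (sgLayout[i] <= 0 || wgShape[i] % sgLayout[i] != 0)
      return std::nullopt;
    sgShape.push_back(wgShape[i] / sgLayout[i]);
  }
  return sgShape;
}

// Mirrors the same-rank case of the vector.broadcast rule: every source
// dimension must be 1 or equal to the matching result dimension.
bool isSameRankBroadcastCompatible(const Shape &src, const Shape &dst) {
  if (src.size() != dst.size())
    return false;
  for (std::size_t i = 0; i < src.size(); ++i)
    if (src[i] != 1 && src[i] != dst[i])
      return false;
  return true;
}
```

For example, with a layout of {4, 1, 1}, a source of {32, 1, 1} distributes to {8, 1, 1} and a result of {32, 32, 8} distributes to {8, 32, 8}, which are still broadcast-compatible. With {4, 2, 1}, the unit dimension of the source is not evenly distributable at all, so it never becomes 2.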
https://github.com/llvm/llvm-project/pull/144417