[Mlir-commits] [mlir] [MLIR][XeGPU] Add wg-to-sg distribution for dpasmx, bitcast, interleave, and deinterleave (PR #194985)
Jianhui Li
llvmlistbot at llvm.org
Wed May 6 10:20:40 PDT 2026
================
@@ -1403,19 +1462,116 @@ struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+
+// Rewrites a workgroup-level vector.bitcast into one vector.bitcast per distributed source value, each typed with the subgroup shape.
+struct WgToSgVectorBitCastOp : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
+ // Returns failure() when the result has no temporary layout or the layout is not a workgroup layout; otherwise replaces `op` and returns success().
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ // The layout is read from the op's result, not from its source operand.
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+ // New result type: same element type as the original result, with the shape returned by getSgShapeAndCount.
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+ // Emit one bitcast for every value the 1:N adaptor provides for the source.
+ SmallVector<Value> newBitCastOps;
+ for (auto src : adaptor.getSource()) {
+ auto newBitCast =
+ vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
+ newBitCastOps.push_back(newBitCast.getResult());
+ }
+ // All new results together stand in for the single original result.
+ rewriter.replaceOpWithMultiple(op, {newBitCastOps});
+ return success();
+ }
+};
+
+// Rewrites a workgroup-level vector.interleave into one vector.interleave per distributed (lhs, rhs) pair, each typed with the subgroup shape.
+struct WgToSgVectorInterleaveOp
+ : public OpConversionPattern<vector::InterleaveOp> {
+ using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
+ // Returns failure() when the result has no temporary layout or the layout is not a workgroup layout; otherwise replaces `op` and returns success().
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ // The layout is read from the op's result, not from its operands.
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+ if (!layout || !layout.isForWorkgroup())
+ return failure();
+ // New result type: same element type as the original result, with the shape returned by getSgShapeAndCount.
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+ // NOTE(review): llvm::zip pairs values only up to the shorter range; presumably lhs and rhs always distribute to the same number of values -- confirm.
+ SmallVector<Value> newInterleaveOps;
+ // Interleave operates pairwise: the i-th distributed lhs value is
+ // interleaved with the i-th distributed rhs value.
+ for (auto [lhs, rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
+ auto newInterleave = vector::InterleaveOp::create(
+ rewriter, op.getLoc(), newResultType, lhs, rhs);
+ newInterleaveOps.push_back(newInterleave.getResult());
+ }
+ // All new results together stand in for the single original result.
+ rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
+ return success();
+ }
+};
+
+// This pattern transforms vector.deinterleave ops to work at subgroup level.
+struct WgToSgVectorDeinterleaveOp
+ : public OpConversionPattern<vector::DeinterleaveOp> {
+ using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout =
+ xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
+ if (!layout || !layout.isForWorkgroup())
----------------
Jianhui-Li wrote:
removed
https://github.com/llvm/llvm-project/pull/194985
More information about the Mlir-commits
mailing list