[Mlir-commits] [mlir] [MLIR][XeGPU] Add wg-to-sg distribution for dpasmx, bitcast, interleave, and deinterleave (PR #194985)

Jianhui Li llvmlistbot at llvm.org
Wed May 6 10:20:40 PDT 2026


================
@@ -1403,19 +1462,116 @@ struct WgToSgVectorMaskOp : public OpConversionPattern<MaskOpType> {
 
 using WgToSgVectorConstantMaskOp = WgToSgVectorMaskOp<vector::ConstantMaskOp>;
 using WgToSgVectorCreateMaskOp = WgToSgVectorMaskOp<vector::CreateMaskOp>;
+
+// This pattern transforms vector.bitcast ops to work at subgroup level.
+//
+// The rewrite is driven by the temporary layout attached to the op's result:
+// the subgroup shape is derived from the workgroup result shape and that
+// layout, and one vector.bitcast with the subgroup-shaped result type is
+// created for every value the adaptor supplies for the source. The pattern
+// does not match when the result has no layout or when the layout is not a
+// workgroup-level one.
+struct WgToSgVectorBitCastOp : public OpConversionPattern<vector::BitCastOp> {
+  using OpConversionPattern<vector::BitCastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BitCastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    // Only results that carry a workgroup-level layout are distributed here;
+    // anything else is left for other patterns by returning failure.
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    // Only the first element of the returned pair (the subgroup shape) is
+    // used; the element type of the result is carried over unchanged.
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // NOTE(review): presumably each adaptor value is one already-distributed
+    // subgroup piece of the original source, and its bit width matches
+    // newResultType -- confirm against the source operand's layout.
+    SmallVector<Value> newBitCastOps;
+    for (auto src : adaptor.getSource()) {
+      auto newBitCast =
+          vector::BitCastOp::create(rewriter, op.getLoc(), newResultType, src);
+      newBitCastOps.push_back(newBitCast.getResult());
+    }
+
+    // The single workgroup-level result is replaced by the whole list of
+    // subgroup-level results (1:N replacement).
+    rewriter.replaceOpWithMultiple(op, {newBitCastOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.interleave ops to work at subgroup level.
+//
+// The rewrite is driven by the temporary layout attached to the op's result:
+// the subgroup shape is derived from the workgroup result shape and that
+// layout, and one vector.interleave with the subgroup-shaped result type is
+// created for every (lhs, rhs) pair the adaptor supplies. The pattern does
+// not match when the result has no layout or when the layout is not a
+// workgroup-level one.
+struct WgToSgVectorInterleaveOp
+    : public OpConversionPattern<vector::InterleaveOp> {
+  using OpConversionPattern<vector::InterleaveOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+    // Only results that carry a workgroup-level layout are distributed here;
+    // anything else is left for other patterns by returning failure.
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getResult()));
+    if (!layout || !layout.isForWorkgroup())
+      return failure();
+
+    // Only the first element of the returned pair (the subgroup shape) is
+    // used; the element type of the result is carried over unchanged.
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newInterleaveOps;
+    // Interleave operates pairwise: each lhs value is interleaved with
+    // corresponding rhs value
+    // NOTE(review): llvm::zip stops at the shorter range, so a mismatch in
+    // the number of lhs and rhs values would be silently truncated; the two
+    // are presumably always equal in count here -- confirm.
+    for (auto [lhs, rhs] : llvm::zip(adaptor.getLhs(), adaptor.getRhs())) {
+      auto newInterleave = vector::InterleaveOp::create(
+          rewriter, op.getLoc(), newResultType, lhs, rhs);
+      newInterleaveOps.push_back(newInterleave.getResult());
+    }
+
+    // The single workgroup-level result is replaced by the whole list of
+    // subgroup-level results (1:N replacement).
+    rewriter.replaceOpWithMultiple(op, {newInterleaveOps});
+    return success();
+  }
+};
+
+// This pattern transforms vector.deinterleave ops to work at subgroup level.
+struct WgToSgVectorDeinterleaveOp
+    : public OpConversionPattern<vector::DeinterleaveOp> {
+  using OpConversionPattern<vector::DeinterleaveOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::DeinterleaveOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    xegpu::DistributeLayoutAttr layout =
+        xegpu::getTemporaryLayout(dyn_cast<OpResult>(op.getRes1()));
+    if (!layout || !layout.isForWorkgroup())
----------------
Jianhui-Li wrote:

removed

https://github.com/llvm/llvm-project/pull/194985


More information about the Mlir-commits mailing list