[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_gather/store_scatter op from Wg To Sg (PR #154420)

Chao Chen llvmlistbot at llvm.org
Fri Aug 22 10:08:17 PDT 2025


================
@@ -763,6 +763,88 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
   }
 };
 
+// This pattern rewrites a LoadGatherOp with explicit offsets into one subgroup-shaped
+// LoadGatherOp per (offsets, mask) pair from the adaptor, similar to WgToSgLoadNdOpWithOffset.
+struct WgToSgLoadGatherOpWithOffset
+    : public OpConversionPattern<xegpu::LoadGatherOp> {
+  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets()) // Only the form with explicit offsets is handled.
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape(); // Shape of the original op's result.
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout()) // Needs a layout that carries an sg_layout.
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; // First element of the returned pair; presumably the per-subgroup shape.
+
+    SmallVector<Value> newLoadOps;
+    auto chunkSizeAttr =
+        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); // Chunk size falls back to 1 when unset.
+    VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); // Result type of each new load.
+    for (auto [offsets, mask] :
+         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { // Walks the adaptor's offsets and mask values in lockstep.
+      auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>(
+          loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, // NOTE(review): source comes from the original op, not the adaptor.
+          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); // Cache hints are copied from the original op.
+      xegpu::setLayoutAttr(newLoadOp->getResult(0),
+                           layout.dropSgLayoutAndData()); // Going by the name, the layout minus sg_layout/sg_data.
+      newLoadOps.push_back(newLoadOp);
+    }
+    rewriter.replaceOpWithMultiple(op, {newLoadOps}); // The original result is replaced by all new load results.
+    return success();
+  }
+};
+
+// This pattern transforms the StoreScatterOp with explicit offsets to store
+// subgroup data, similar to WgToSgStoreNdOpWithOffset.
+struct WgToSgStoreScatterOpWithOffset
+    : public OpConversionPattern<xegpu::StoreScatterOp> {
+  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    if (!op.getOffsets())
+      return failure();
+
+    Location loc = op.getLoc();
+    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
+    if (!valueType)
+      return failure();
+
+    ArrayRef<int64_t> wgShape = valueType.getShape();
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getValue());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    auto chunkSizeOpt = op.getChunkSize();
+    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
+    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
+    for (auto [val, offs, mask] : llvm::zip(
+             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
+      rewriter.create<xegpu::StoreScatterOp>(
+          loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(),
+          op.getL2HintAttr(), op.getL3HintAttr());
+      // Update the layout_result_0 attribute to drop sg_layout and sg_data.
+      if (auto layoutAttr =
----------------
chencha3 wrote:

It would be better to use the `getLayoutAttr` and `setLayoutAttr` interfaces here, as the load-gather pattern above already does with `xegpu::setLayoutAttr`.

https://github.com/llvm/llvm-project/pull/154420


More information about the Mlir-commits mailing list