[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)

Charitha Saumya llvmlistbot at llvm.org
Fri Aug 22 11:31:52 PDT 2025


================
@@ -811,6 +811,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    else if (!storeScatterOp.getOffsets())
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Store op must have offsets argument");
+    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+                 .getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets vector");
+
+    VectorType storeVecTy =
+        cast<VectorType>(storeScatterOp.getValue().getType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
----------------
charithaintc wrote:

Strongly suggest using `getDistVecTypeBasedOnLaneLayout` to get the distributed type. It also checks that the SG vector type is consistent with the lane layout (i.e., that it is distributable to lanes), which is not done here.
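
To illustrate the kind of consistency check meant here, the following is a minimal standalone sketch, not the actual MLIR helper. The function name `distributeShapeByLaneLayout` and its plain `std::vector` interface are hypothetical. It assumes the lane layout applies to the trailing dimensions of the SG-level shape and that every such dimension must be evenly divisible by the corresponding lane count.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a lane-layout-based distribution check.
// This is NOT the MLIR helper; it only sketches the idea in plain C++.
// Returns the per-lane shape, or std::nullopt if the SG-level shape
// cannot be evenly distributed across the lanes.
std::optional<std::vector<int64_t>>
distributeShapeByLaneLayout(const std::vector<int64_t> &sgShape,
                            const std::vector<int64_t> &laneLayout) {
  assert(laneLayout.size() <= sgShape.size() &&
         "layout rank must not exceed vector rank");
  std::vector<int64_t> distShape = sgShape;
  // Assumption: the layout describes the trailing dimensions of the shape.
  const size_t start = sgShape.size() - laneLayout.size();
  for (size_t i = 0; i < laneLayout.size(); ++i) {
    // Reject shapes that do not divide evenly among the lanes.
    if (laneLayout[i] <= 0 || sgShape[start + i] % laneLayout[i] != 0)
      return std::nullopt;
    distShape[start + i] /= laneLayout[i];
  }
  return distShape;
}
```

With a check like this, a pattern could bail out via `notifyMatchFailure` when the result is empty, instead of hand-computing the distributed type and relying on an assertion.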

https://github.com/llvm/llvm-project/pull/154949

