[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)

Charitha Saumya llvmlistbot at llvm.org
Fri Aug 22 11:31:52 PDT 2025


================
@@ -811,6 +811,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
+    auto yield = cast<gpu::YieldOp>(
+        warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+    Operation *lastNode = yield->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    else if (!storeScatterOp.getOffsets())
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Store op must have offsets argument");
+    else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+                 .getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets vector");
+
+    VectorType storeVecTy =
+        cast<VectorType>(storeScatterOp.getValue().getType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
+    if (storeVecTy.getRank() == 2)
+      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+    else // rank 1
+      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands =
+        llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+    SmallVector<Type> operandTypes =
+        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+    operandTypes[0] = distStoreVecTy;
+
+    gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+        rewriter, warpOp, operands, operandTypes, newRetIndices);
+    SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+        newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+    Value offsetsVec = newStoreScatterOpOperands[2];
+    Value maskVec = newStoreScatterOpOperands[3];
+
+    auto loc = newWarpOp.getLoc();
+    Value laneId = warpOp.getLaneid();
+    rewriter.setInsertionPointAfter(newWarpOp);
+    Value laneOffset =
+        vector::ExtractOp::create(rewriter, loc, offsetsVec, laneId);
+    laneOffset = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneOffset.getType()), laneOffset);
+    Value laneMask = vector::ExtractOp::create(rewriter, loc, maskVec, laneId);
+    laneMask = vector::BroadcastOp::create(
+        rewriter, loc, VectorType::get({1}, laneMask.getType()), laneMask);
----------------
charithaintc wrote:

I am not sure about this code sequence.

In my understanding, the offsets and masks are also distributed:
```
vector<16xindex> -> vector<1xindex>
vector<16xi1> -> vector<1xi1>
```
Then each lane extracts the 0th element from them. There is no need to broadcast or to use the lane id. At the SIMT level each thread needs a scalar offset and mask (or vector<1x*> if you think that is better).
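
To make the intent concrete, here is a rough sketch of the sequence I would expect (illustrative only; the warp-op shape, SSA names, and element types are assumptions, not the exact PR code):
```mlir
// Assumption: the warp op already yields the distributed value, offsets,
// and mask (vector<16x*> at SG level -> vector<1x*> per lane).
%r:3 = gpu.warp_execute_on_lane_0(%laneid)[16]
    -> (vector<1xf32>, vector<1xindex>, vector<1xi1>) {
  ...
  gpu.yield %val, %offsets, %mask
      : vector<16xf32>, vector<16xindex>, vector<16xi1>
}
// Each lane takes element 0 of its own distributed fragment; no
// vector.extract at %laneid and no vector.broadcast are needed.
%off = vector.extract %r#1[0] : index from vector<1xindex>
%m   = vector.extract %r#2[0] : i1 from vector<1xi1>
```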

https://github.com/llvm/llvm-project/pull/154949

