[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)
Artem Kroviakov
llvmlistbot at llvm.org
Sat Aug 23 02:34:47 PDT 2025
================
@@ -811,6 +811,136 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ auto yield = cast<gpu::YieldOp>(
+ warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+ Operation *lastNode = yield->getPrevNode();
+ auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+ if (!storeScatterOp)
+ return failure();
+ else if (!storeScatterOp.getOffsets())
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Store op must have offsets argument");
+ else if (cast<VectorType>(storeScatterOp.getOffsets().getType())
+ .getRank() != 1)
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Expected 1D offsets vector");
+
+ VectorType storeVecTy =
+ cast<VectorType>(storeScatterOp.getValue().getType());
+ assert(storeVecTy.getRank() <= 2 &&
+ "Expected at most 2D result at SG level");
+ VectorType distStoreVecTy;
+ if (storeVecTy.getRank() == 2)
+ distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+ else // rank 1
+ distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands =
+ llvm::to_vector_of<Value>(storeScatterOp->getOperands());
+ SmallVector<Type> operandTypes =
+ llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
+ operandTypes[0] = distStoreVecTy;
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ Value offsetsVec = newStoreScatterOpOperands[2];
+ Value maskVec = newStoreScatterOpOperands[3];
+
+ auto loc = newWarpOp.getLoc();
+ Value laneId = warpOp.getLaneid();
----------------
akroviakov wrote:
> this will only give the laneId that executed the code inside warpOp (lane 0 in most examples).
`warpOp`'s argument is not zero-predicated. The lowering is what is supposed to zero-predicate the lane and use it in some `if`-like branching; I do not query the argument of that `if`. Also, the user has no control over `warpOp`'s argument, because it is completely internal to the pass. The conversion tests show a simple `gpu.lane_id`.
https://github.com/llvm/llvm-project/pull/154949
More information about the Mlir-commits
mailing list