[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)
Artem Kroviakov
llvmlistbot at llvm.org
Thu Aug 28 07:47:31 PDT 2025
================
@@ -807,6 +807,156 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to
+/// single-element vectors; **it is assumed that their producers will also be
+/// distributed**. The payload vector also has a fixed distribution:
+/// no chunk size -> vector of one element.
+/// chunk size -> vector of the innermost dimension of the SG-payload.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+/// memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+/// memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+ auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+ if (!storeScatterOp)
+ return failure();
+ auto offsets = storeScatterOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()))
+ return rewriter.notifyMatchFailure(
+ storeScatterOp, "Store op must have a vector of offsets argument");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
+ if (offsetsTy.getRank() != 1)
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Expected 1D offsets vector");
+ VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+ assert(storeVecTy.getRank() <= 2 &&
+ "Expected at most 2D result at SG level");
+ VectorType distStoreVecTy;
+ if (storeVecTy.getRank() == 2)
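The fixed, layout-free rule described in the doc comment above can be sketched as follows. The sketch uses plain C++ rather than the MLIR API so that it compiles on its own. Shapes are modeled as `std::vector<int64_t>`. `Shape`, `DistributedScatterShapes`, and `distributeScatterImplicit` are hypothetical names and are not part of the XeGPU code. The mask is assumed to have the same shape as the offsets.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape model: a list of dimension sizes.
using Shape = std::vector<int64_t>;

// Hypothetical result type: per-lane shapes of the scattered op's operands.
struct DistributedScatterShapes {
  Shape offsets; // per-lane offsets shape
  Shape mask;    // per-lane mask shape (assumed identical to offsets)
  Shape payload; // per-lane payload shape
};

// Fixed, layout-free rule from the doc comment:
//  - 1D offsets/mask with subgroupSize elements -> one element per lane;
//  - payload [subgroupSize] (no chunk size)     -> [1];
//  - payload [subgroupSize, chunk] (chunk size) -> [chunk] (innermost dim).
// Returns std::nullopt when the SG-level shapes do not fit this rule.
std::optional<DistributedScatterShapes>
distributeScatterImplicit(const Shape &offsets, const Shape &payload,
                          int64_t subgroupSize) {
  assert(subgroupSize > 0 && "subgroup size must be positive");
  if (offsets.size() != 1 || offsets[0] != subgroupSize)
    return std::nullopt; // offsets must be 1D with #subgroup_size elements
  if (payload.empty() || payload.size() > 2 || payload[0] != subgroupSize)
    return std::nullopt; // payload must be [sg] or [sg, chunk]
  DistributedScatterShapes result;
  result.offsets = {1};
  result.mask = {1};
  result.payload = payload.size() == 2 ? Shape{payload[1]} : Shape{1};
  return result;
}
```

In Example 1, offsets `[16]` and payload `[16]` map to `[1]` and `[1]`. In Example 2, a `[16, 8]` payload maps to `[8]`.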
----------------
akroviakov wrote:
It is `[16,1]`. I have updated the propagation logic to account for offsets. I have also added a layout-based distribution with a fallback to the existing solution.
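A layout-based distribution with a fallback could look roughly like the sketch below. It rests on several assumptions. Each SG-level dimension is divided by the number of lanes along it. The per-lane result is flattened to 1D, so `[1, 8]` becomes `[8]`, which matches Example 2. The fixed rule is used only when no layout is attached. `[16,1]` is taken to be the lane layout of the `[16, 8]` payload. `LaneLayout`, `distributeByLayout`, and `distributePayload` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in for a lane-level layout: number of lanes per dim.
struct LaneLayout {
  Shape laneLayout; // e.g. {16, 1} for a [16, 8] SG payload
};

// Layout-based rule: divide each SG dimension by the lanes along it.
// Fails on rank mismatch or if a dimension is not evenly divisible.
std::optional<Shape> distributeByLayout(const Shape &sgShape,
                                        const LaneLayout &layout) {
  if (sgShape.size() != layout.laneLayout.size())
    return std::nullopt;
  Shape perLane;
  for (std::size_t i = 0; i < sgShape.size(); ++i) {
    int64_t lanes = layout.laneLayout[i];
    if (lanes <= 0 || sgShape[i] % lanes != 0)
      return std::nullopt;
    perLane.push_back(sgShape[i] / lanes);
  }
  return perLane;
}

// Existing fixed rule: [sg] -> [1], [sg, chunk] -> [chunk].
Shape distributePayloadImplicit(const Shape &sgShape) {
  assert((sgShape.size() == 1 || sgShape.size() == 2) &&
         "Expected at most 2D payload at SG level");
  return sgShape.size() == 2 ? Shape{sgShape[1]} : Shape{1};
}

// Prefer the layout when one is attached; otherwise fall back to the fixed
// rule. The per-lane shape is flattened to 1D (assumption: [1, 8] -> [8]).
std::optional<Shape> distributePayload(const Shape &sgShape,
                                       const std::optional<LaneLayout> &layout) {
  if (!layout)
    return distributePayloadImplicit(sgShape);
  std::optional<Shape> perLane = distributeByLayout(sgShape, *layout);
  if (!perLane)
    return std::nullopt;
  int64_t numElems = 1;
  for (int64_t d : *perLane)
    numElems *= d;
  return Shape{numElems};
}
```

When the attached layout is the default one, both paths produce the same shape. The fallback therefore changes only where the shape comes from, not the result.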
> layout assigned to offsets (by propagation logic)
The propagation does its job of pushing the default mask/offset layout to producers, but it does not preserve that layout for later passes. To use the layout during distribution, we have to specify it manually per operand anyway, so the overhead lands both in the pass and on the user.
> I understand that the layout is not useful here. But it is better to keep this logic in a single place.
It looks like we are trying to solve a trivial problem by applying a generic solution, along with all its dependencies, instead of a correspondingly trivial one. A user has to manually specify properly named attributes that are always the same, and now it is also the compiler's job to make sure the user specified the default and only possible layout?
I thought layouts exist so that a user can communicate some detail to the compiler, not the other way around.
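The following sketch shows the extra check the compiler must perform when the user writes the layout explicitly. It assumes that the only valid layout for 1D offsets/mask is one element per lane (`lane_layout = [subgroup_size]`, `lane_data = [1]`). `LaneLayoutAttr` and `isDefaultScatterOperandLayout` are hypothetical stand-ins and are not the actual `xegpu` attribute API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical stand-in for the lane-level fields of a layout attribute.
struct LaneLayoutAttr {
  Shape laneLayout; // lanes along each dimension
  Shape laneData;   // contiguous elements owned by each lane per dimension
};

// For 1D offsets/mask with subgroupSize elements, the only valid layout is
// assumed to be one element per lane: lane_layout = [subgroupSize],
// lane_data = [1]. Anything else is rejected.
bool isDefaultScatterOperandLayout(const LaneLayoutAttr &attr,
                                   int64_t subgroupSize) {
  assert(subgroupSize > 0 && "subgroup size must be positive");
  return attr.laneLayout == Shape{subgroupSize} && attr.laneData == Shape{1};
}
```

The check only confirms a layout that is already implied by the op's shapes, which is the redundancy being questioned here.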
In any case, I see the merit in asking for uniform distribution logic, and it is now implemented. However, the user experience also needs to improve.
https://github.com/llvm/llvm-project/pull/154949