[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)

Artem Kroviakov llvmlistbot at llvm.org
Fri Aug 29 02:54:13 PDT 2025


================
@@ -807,6 +807,210 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
   }
 };
 
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both the offset and mask vectors must be 1D and have #subgroup_size
+/// elements. The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to single-element
+/// vectors; **it is assumed that their producers will also be
+/// distributed**. The payload vector also has a fixed distribution:
+///   no chunk size -> vector of one element.
+///   chunk size    -> vector of the innermost dimension of the SG payload.
+/// Example 1 (no chunk size):
+///    %mask = producer_op : vector<16xi1>
+///    %offset = producer_op : vector<16xindex>
+///    xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+///     memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///    %mask = producer_op : vector<1xi1>
+///    %offset = producer_op : vector<1xindex>
+///    xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+///     memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+///    xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+///     vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+  using gpu::WarpDistributionPattern::WarpDistributionPattern;
+  LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override {
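+    // Only match when the scattered store is the op immediately preceding
+    // the terminator of the warp region.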
+    Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+    auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+    if (!storeScatterOp)
+      return failure();
+    auto offsets = storeScatterOp.getOffsets();
+    if (!offsets || !isa<VectorType>(offsets.getType()))
+      return rewriter.notifyMatchFailure(
+          storeScatterOp, "Store op must have a vector of offsets argument");
+    VectorType offsetsTy = cast<VectorType>(offsets.getType());
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+      return rewriter.notifyMatchFailure(storeScatterOp,
+                                         "Expected 1D offsets and mask vector");
+    VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+    assert(storeVecTy.getRank() <= 2 &&
+           "Expected at most 2D result at SG level");
+    VectorType distStoreVecTy;
+    if (storeVecTy.getRank() == 2)
+      distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+    else // rank 1
+      distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+    // Assume offset and mask producers will be distributed as well.
+    VectorType distOffsetsTy =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    VectorType distMaskTy = VectorType::get(
+        {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
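+    // Resolve the per-operand layout attribute names from the operand
+    // positions (payload, offsets, mask).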
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
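+    // Try to derive the distributed vector types from the explicit lane
+    // layouts, if present.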
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      storeScatterOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
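+    // Operands whose layout is missing fall back to the implicit defaults
+    // computed above.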
+    distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
----------------
akroviakov wrote:

Removed

https://github.com/llvm/llvm-project/pull/154949

