[Mlir-commits] [mlir] [MLIR][XeGPU] Scattered ops sg-to-wi distribution (PR #154949)
Artem Kroviakov
llvmlistbot at llvm.org
Fri Aug 29 02:54:18 PDT 2025
================
@@ -807,6 +807,210 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern {
}
};
+/// Distribute a scattered store op. The offsets argument is required.
+/// Both offset and mask vectors must be 1D and have #subgroup_size elements.
+/// The layouts are fixed and implicit: one offset/mask per lane.
+/// The pass changes the offset/mask vector shapes to
+/// single-element vectors; **it is assumed that their producers will also be
+/// distributed**. The payload vector also has a fixed distribution:
+/// no chunk size -> a single-element vector.
+/// chunk size -> a vector of the innermost dimension of the SG payload.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<16xf16>,
+/// memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// xegpu.store %payload, %src[%offset], %mask : vector<1xf16>,
+/// memref<256xf16>, vector<1xindex>, vector<1xi1>
+/// Example 2 (chunk size, same mask and offsets):
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
+/// To
+/// xegpu.store %payload, %src[%offset], %mask <{chunk_size=8}> :
+/// vector<8xf16>, memref<256xf16>, vector<1xindex>, vector<1xi1>
+struct StoreDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ Operation *lastNode = warpOp.getTerminator()->getPrevNode();
+ auto storeScatterOp = dyn_cast_or_null<xegpu::StoreScatterOp>(lastNode);
+ if (!storeScatterOp)
+ return failure();
+ auto offsets = storeScatterOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()))
+ return rewriter.notifyMatchFailure(
+ storeScatterOp, "Store op must have a vector of offsets argument");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+ if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+ return rewriter.notifyMatchFailure(storeScatterOp,
+ "Expected 1D offsets and mask vector");
+ VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
+ assert(storeVecTy.getRank() <= 2 &&
+ "Expected at most 2D result at SG level");
+ VectorType distStoreVecTy;
+ if (storeVecTy.getRank() == 2)
+ distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
+ else // rank 1
+ distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
+ // Assume offset and mask producers will be distributed as well.
+ VectorType distOffsetsTy =
+ VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+ VectorType distMaskTy = VectorType::get(
+ {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
+ std::string layoutPayloadName =
+ xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+ std::string layoutOffsetsName =
+ xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+ std::string layoutMaskName =
+ xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+ xegpu::LayoutAttr layoutPayload =
+ storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+ xegpu::LayoutAttr layoutOffsets =
+ storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+ xegpu::LayoutAttr layoutMask =
+ storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+ FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+ FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+ FailureOr<VectorType> distMaskByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+ if (failed(distStoreVecByWarpOpOrFailure) ||
+ failed(distOffsetsByWarpOpOrFailure) ||
+ failed(distMaskByWarpOpOrFailure)) {
+ storeScatterOp.emitWarning(
+ "Some vector operands have no layouts, using defaults instead.");
+ }
+ distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
+ distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+ distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
+
+ SmallVector<size_t> newRetIndices;
+ SmallVector<Value> operands = storeScatterOp->getOperands();
+ SmallVector<Type> operandTypesToYield = {
+ distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};
+
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
+ SmallVector<Value> newStoreScatterOpOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ rewriter.setInsertionPointAfter(newWarpOp);
+ xegpu::StoreScatterOp newOp = xegpu::StoreScatterOp::create(
+ rewriter, newWarpOp.getLoc(), TypeRange{}, newStoreScatterOpOperands,
+ storeScatterOp->getAttrs());
+ xegpu::removeLayoutAttrs(newOp);
+ rewriter.eraseOp(storeScatterOp);
+ return success();
+ }
+};
+
+/// Distribute a scattered load op. The logic and requirements are the same as
+/// for the scattered store distribution. The warpOp's payload vector is
+/// expected to be distributed by the load's result consumer.
+/// Example 1 (no chunk size):
+/// %mask = producer_op : vector<16xi1>
+/// %offset = producer_op : vector<16xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<16xindex>, vector<16xi1> -> vector<16xf16>
+/// To
+/// %mask = producer_op : vector<1xi1>
+/// %offset = producer_op : vector<1xindex>
+/// %0 = xegpu.load %src[%offset], %mask : memref<256xf16>,
+/// vector<1xindex>, vector<1xi1> -> vector<1xf16>
+/// Example 2 (chunk size, same mask and offsets):
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+/// To
+/// %0 = xegpu.load %src[%offset], %mask <{chunk_size=8}> :
+/// memref<256xf16>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+struct LoadDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ // Check that the yield operand was produced by the *last* scattered
+ // load op, to avoid sinking it before barriers (maintain memory order).
+ return isa<xegpu::LoadGatherOp>(op) &&
+ warpOp.getTerminator()->getPrevNode() == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadGatherOp");
+
+ auto loadGatherOp =
+ producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
+ auto offsets = loadGatherOp.getOffsets();
+ if (!offsets || !isa<VectorType>(offsets.getType()) ||
+ !isa<VectorType>(loadGatherOp.getMask().getType()))
+ return rewriter.notifyMatchFailure(
+ loadGatherOp,
+ "Load op must have a vector arguments for offsets and mask");
+ VectorType offsetsTy = cast<VectorType>(offsets.getType());
+ VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+ if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
+ return rewriter.notifyMatchFailure(loadGatherOp,
+ "Expected 1D offsets and mask vector");
+ // Assume offset and mask producers will be distributed as well.
+ VectorType distOffsetsTy =
+ VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+ VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
+
+ std::string layoutOffsetsName =
+ xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+ std::string layoutMaskName =
+ xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+ xegpu::LayoutAttr layoutOffsets =
+ loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+ xegpu::LayoutAttr layoutMask =
+ loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+ FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+ FailureOr<VectorType> distMaskByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+ if (failed(distOffsetsByWarpOpOrFailure) ||
+ failed(distMaskByWarpOpOrFailure)) {
+ loadGatherOp.emitWarning(
+ "Some vector operands have no layouts, using defaults instead.");
+ }
----------------
akroviakov wrote:
Now it is a matching failure
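
A minimal sketch of what that updated check might look like, assuming the warning is simply replaced with a rewriter.notifyMatchFailure in the same style as the other early exits in this pattern (the exact message in the PR may differ, and the analogous warning in StoreDistribution would presumably change the same way):

    if (failed(distOffsetsByWarpOpOrFailure) ||
        failed(distMaskByWarpOpOrFailure))
      // Bail out of the pattern instead of silently falling back to defaults.
      return rewriter.notifyMatchFailure(
          loadGatherOp, "Some vector operands have no layouts.");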
https://github.com/llvm/llvm-project/pull/154949