[Mlir-commits] [mlir] [MLIR][XeGPU] Add distribution pattern for xegpu.load & store for sg to wi pass (PR #181917)
Charitha Saumya
llvmlistbot at llvm.org
Fri Feb 20 13:25:10 PST 2026
================
@@ -395,6 +395,78 @@ struct SgToWiPrefetchNd : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// Distributes a subgroup-level LoadGather (xegpu.load) op to workitem-level.
+struct SgToWiLoadGather : public OpConversionPattern<xegpu::LoadGatherOp> {
+ using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::LoadGatherOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ xegpu::DistributeLayoutAttr layout = op.getAnchorLayout();
+ if (!layout)
+ return failure();
+
+ VectorType resultTy = op.getValueType();
+ if (!resultTy)
+ return failure();
+
+ // Check that leading dimensions are unit.
+ int chunkSize = op.getChunkSize().value_or(1);
+ int effectiveVecRank = (chunkSize == 1) ? 1 : 2;
+ for (int i = 0; i < resultTy.getRank() - effectiveVecRank; i++) {
+ if (resultTy.getShape()[i] != 1)
+ return rewriter.notifyMatchFailure(
+ op, "Only unit dimensions allowed for the leading "
+ "dimensions of the load vector!");
+ }
+
+ auto expectedWiResultTyOrFailure =
+ xegpu::getDistVecTypeBasedOnLaneLayout(layout, resultTy);
+ if (failed(expectedWiResultTyOrFailure))
+ return rewriter.notifyMatchFailure(
+ op,
+ "unable to compute expected workitem vector type from lane layout");
+
+ VectorType expectedWiResultTy = expectedWiResultTyOrFailure.value();
+ VectorType supportedWiResultTy =
+ VectorType::get({expectedWiResultTy.getNumElements()},
+ expectedWiResultTy.getElementType());
+
+ // Flatten offsets and mask to 1D to match the 1D result type.
+ Value offsets = adaptor.getOffsets();
+ if (auto offsetsTy = dyn_cast<VectorType>(offsets.getType())) {
+ VectorType offsetsTy1D = VectorType::get({offsetsTy.getNumElements()},
+ offsetsTy.getElementType());
+ if (offsetsTy != offsetsTy1D)
+ offsets = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+ offsetsTy1D, offsets)
+ .getResult();
+ }
+ Value mask = adaptor.getMask();
+ if (auto maskTy = dyn_cast<VectorType>(mask.getType())) {
+ VectorType maskTy1D =
+ VectorType::get({maskTy.getNumElements()}, maskTy.getElementType());
+ if (maskTy != maskTy1D)
+ mask =
+ vector::ShapeCastOp::create(rewriter, op.getLoc(), maskTy1D, mask)
+ .getResult();
+ }
----------------
charithaintc wrote:
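Consider using a lambda to factor out the flatten-to-1D shape-cast logic, which is duplicated above for `offsets` and `mask`.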
https://github.com/llvm/llvm-project/pull/181917
More information about the Mlir-commits
mailing list