[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (PR #153432)
Chao Chen
llvmlistbot at llvm.org
Thu Aug 14 15:04:03 PDT 2025
================
@@ -296,6 +296,192 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};
+template <typename OpTy, typename AdaptorTy, typename CreateFn>
+LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
+                                       ConversionPatternRewriter &rewriter,
+                                       CreateFn &&createOp) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+
+  SmallVector<int64_t> sgLayout;
+  if (auto sgLayoutAttr = layout.getSgLayout())
+    sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+  else
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  SmallVector<int64_t> sgShape;
+  int count;
+  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+  // Get the subgroup ID
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+  if (sgIdRangeSpecified) {
+    int64_t sgCount = endOfRange - startOfRange;
+    if (computeProduct(sgLayout) != sgCount)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+
+  auto maybeTdescOffsets =
+      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(maybeTdescOffsets))
+    return failure();
+
+  SmallVector<OpFoldResult> oldOffsets;
+  if (auto constOffsets = op.getConstOffsetsAttr()) {
+    for (auto attr : constOffsets.asArrayRef())
+      oldOffsets.push_back(rewriter.getIndexAttr(attr));
+  }
+  for (auto v : op.getOffsets())
+    oldOffsets.push_back(v);
+
+  return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
+                  rewriter, op);
+}
+
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::LoadNdOp op,
+      typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::LoadNdOp &op) -> LogicalResult {
+          SmallVector<Value> newLoadOps;
+          for (auto [tdescOffsets, tdesc] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
+              size_t idx = oldOffsets.size() - rank + i;
+              Value add = rewriter.createOrFold<index::AddOp>(
+                  loc, tdescOffsets[i],
+                  getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                  oldOffsets[idx]));
+              newOffsets.push_back(add);
+            }
+            VectorType newResTy = VectorType::get(
+                sgShape, dyn_cast<xegpu::TensorDescType>(tdesc.getType())
+                             .getElementType());
+            auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+                loc, newResTy, tdesc, newOffsets,
+                /*packed=*/nullptr,
+                /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+                op.getL3HintAttr());
+            newLoadOps.push_back(newLoadOp);
+          }
+          rewriter.replaceOpWithMultiple(op, {newLoadOps});
+          return success();
+        });
+  }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::StoreNdOp op,
+      typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::StoreNdOp &op) -> LogicalResult {
+          for (auto [tdescOffsets, tdesc, value] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
+                         adaptor.getValue())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
----------------
chencha3 wrote:
Seems this loop can be refactored out into a utility.
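
For illustration, a minimal sketch of what such a utility might look like, assuming the helpers already used in this patch (index::AddOp, getValueOrCreateConstantIndexOp) are in scope. The name computeDistributedOffsets and the exact signature are hypothetical, not part of the PR:

// Hypothetical helper (not in the PR): adds each per-subgroup descriptor
// offset to the matching trailing workgroup-level offset of the original op.
static SmallVector<OpFoldResult>
computeDistributedOffsets(ConversionPatternRewriter &rewriter, Location loc,
                          ArrayRef<Value> tdescOffsets,
                          ArrayRef<OpFoldResult> oldOffsets) {
  SmallVector<OpFoldResult> newOffsets;
  size_t rank = tdescOffsets.size();
  for (size_t i = 0; i < rank; i++) {
    // Align the i-th subgroup offset with the corresponding trailing
    // workgroup offset.
    size_t idx = oldOffsets.size() - rank + i;
    Value add = rewriter.createOrFold<index::AddOp>(
        loc, tdescOffsets[i],
        getValueOrCreateConstantIndexOp(rewriter, loc, oldOffsets[idx]));
    newOffsets.push_back(add);
  }
  return newOffsets;
}

Both the load and store lambdas could then replace their inner loops with a
single call such as
  SmallVector<OpFoldResult> newOffsets =
      computeDistributedOffsets(rewriter, loc, tdescOffsets, oldOffsets);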
https://github.com/llvm/llvm-project/pull/153432