[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (PR #153432)

Chao Chen llvmlistbot at llvm.org
Thu Aug 14 15:04:03 PDT 2025


================
@@ -296,6 +296,192 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+template <typename OpTy, typename AdaptorTy, typename CreateFn>
+LogicalResult distributeNdOpWithOffset(OpTy op, AdaptorTy adaptor,
+                                       ConversionPatternRewriter &rewriter,
+                                       CreateFn &&createOp) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0 && !op.getConstOffsetsAttr())
+    return failure();
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+
+  SmallVector<int64_t> sgLayout;
+  if (auto sgLayoutAttr = layout.getSgLayout())
+    sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+  else
+    return rewriter.notifyMatchFailure(
+        op, "sgLayout attribute is required in layout");
+
+  ArrayRef<int64_t> wgShape = tdescTy.getShape();
+  SmallVector<int64_t> sgShape;
+  int count;
+  std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+  // Get the linear subgroup ID.
+  Value linearSgId =
+      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
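+  // If an sg_id_range is specified, rebase the linear subgroup ID to the
+  // start of that range.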
+  int64_t startOfRange = -1, endOfRange = -1;
+  bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+  if (sgIdRangeSpecified) {
+    int64_t sgCount = endOfRange - startOfRange;
+    if (computeProduct(sgLayout) != sgCount)
+      return rewriter.notifyMatchFailure(
+          op, "sg_layout size must match the sg_id_range");
+    Value startOfRangeVal =
+        rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+    linearSgId =
+        rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+  }
+
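+  // Compute each subgroup's tile offsets within the workgroup shape.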
+  auto maybeTdescOffsets =
+      layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+  if (failed(maybeTdescOffsets))
+    return failure();
+
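+  // Collect the op's original offsets: constant offsets first, then
+  // dynamic values.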
+  SmallVector<OpFoldResult> oldOffsets;
+  if (auto constOffsets = op.getConstOffsetsAttr()) {
+    for (auto attr : constOffsets.asArrayRef())
+      oldOffsets.push_back(rewriter.getIndexAttr(attr));
+  }
+  for (auto v : op.getOffsets())
+    oldOffsets.push_back(v);
+
+  return createOp(loc, sgShape, *maybeTdescOffsets, oldOffsets, adaptor,
+                  rewriter, op);
+}
+
+// This pattern transforms a LoadNdOp with explicit offsets so that each
+// subgroup loads its own slice of the data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::LoadNdOp op,
+      typename OpConversionPattern<xegpu::LoadNdOp>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::LoadNdOp &op) -> LogicalResult {
+          SmallVector<Value> newLoadOps;
+          for (auto [tdescOffsets, tdesc] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
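+              // tdescOffsets covers only the innermost dimensions, so
+              // index the original offsets from the back.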
+              size_t idx = oldOffsets.size() - rank + i;
+              Value add = rewriter.createOrFold<index::AddOp>(
+                  loc, tdescOffsets[i],
+                  getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                  oldOffsets[idx]));
+              newOffsets.push_back(add);
+            }
+            VectorType newResTy = VectorType::get(
+                sgShape, cast<xegpu::TensorDescType>(tdesc.getType())
+                             .getElementType());
+            auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+                loc, newResTy, tdesc, newOffsets,
+                /*packed=*/nullptr,
+                /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+                op.getL3HintAttr());
+            newLoadOps.push_back(newLoadOp);
+          }
+          rewriter.replaceOpWithMultiple(op, {newLoadOps});
+          return success();
+        });
+  }
+};
+
+// This pattern transforms a StoreNdOp with explicit offsets so that each
+// subgroup stores its own slice of the data.
+struct WgToSgStoreNdOpWithOffset
+    : public OpConversionPattern<xegpu::StoreNdOp> {
+  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      xegpu::StoreNdOp op,
+      typename OpConversionPattern<xegpu::StoreNdOp>::OneToNOpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    return distributeNdOpWithOffset(
+        op, adaptor, rewriter,
+        [](Location loc, SmallVector<int64_t> &sgShape,
+           ArrayRef<SmallVector<Value>> tdescOffsetsList,
+           SmallVector<OpFoldResult> &oldOffsets, OneToNOpAdaptor &adaptor,
+           ConversionPatternRewriter &rewriter,
+           xegpu::StoreNdOp &op) -> LogicalResult {
+          for (auto [tdescOffsets, tdesc, value] :
+               llvm::zip(tdescOffsetsList, adaptor.getTensorDesc(),
+                         adaptor.getValue())) {
+            SmallVector<OpFoldResult> newOffsets;
+            size_t rank = tdescOffsets.size();
+            for (size_t i = 0; i < rank; i++) {
----------------
chencha3 wrote:

Seems this loop could be refactored out into a utility.
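
For illustration, here is a minimal sketch of what that utility could look
like, hoisting the duplicated offset-combining loop out of the load and
store lambdas (the helper name computeSgOffsets is hypothetical, not part
of the PR):

    // Adds each subgroup-relative tile offset to the op's original offset
    // along the innermost dimensions and returns the combined offsets.
    static SmallVector<OpFoldResult>
    computeSgOffsets(ConversionPatternRewriter &rewriter, Location loc,
                     ArrayRef<Value> tdescOffsets,
                     ArrayRef<OpFoldResult> oldOffsets) {
      SmallVector<OpFoldResult> newOffsets;
      size_t rank = tdescOffsets.size();
      for (size_t i = 0; i < rank; i++) {
        // The subgroup offsets span only the innermost `rank` dimensions,
        // so index the original offsets from the back.
        size_t idx = oldOffsets.size() - rank + i;
        Value add = rewriter.createOrFold<index::AddOp>(
            loc, tdescOffsets[i],
            getValueOrCreateConstantIndexOp(rewriter, loc, oldOffsets[idx]));
        newOffsets.push_back(add);
      }
      return newOffsets;
    }

Both patterns could then replace the loop with a single call:

    SmallVector<OpFoldResult> newOffsets =
        computeSgOffsets(rewriter, loc, tdescOffsets, oldOffsets);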

https://github.com/llvm/llvm-project/pull/153432

