[Mlir-commits] [mlir] [MLIR][XeGPU] Distribute load_nd/store_nd/prefetch_nd with offsets from Wg to Sg (PR #153432)

Adam Siemieniuk llvmlistbot at llvm.org
Mon Aug 18 09:16:48 PDT 2025


================
@@ -296,6 +296,208 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
   }
 };
 
+// Computes global offsets for subgroup ops: for each list in sgOffsetsList,
+// adds its offsets to the trailing `rank` entries of wgOffsets, elementwise.
+// NOTE(review): idx is size_t; assumes wgOffsets.size() >= rank -- confirm.
+static SmallVector<SmallVector<OpFoldResult>>
+computeGlobalOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+                     ArrayRef<OpFoldResult> wgOffsets,
+                     ConversionPatternRewriter &rewriter) {
+  SmallVector<SmallVector<OpFoldResult>> globalOffsets;
+  Location loc = op->getLoc(); // location used for every op created below
+  for (const auto &sgOffsets : sgOffsetsList) {
+    SmallVector<OpFoldResult> newOffsets;
+    size_t rank = sgOffsets.size(); // number of offsets in this entry
+    for (size_t i = 0; i < rank; i++) {
+      size_t idx = wgOffsets.size() - rank + i; // index into wgOffsets tail
+      Value add = rewriter.createOrFold<index::AddOp>(
+          loc, sgOffsets[i],
+          getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+      newOffsets.push_back(add);
+    }
+    globalOffsets.push_back(std::move(newOffsets));
+  }
+  return globalOffsets; // one offset vector per sgOffsetsList entry
+}
+
+// Utility function to get sgShape, sgOffsetList for a given
+// op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
+                           ConversionPatternRewriter &rewriter,
+                           SmallVector<int64_t> &sgShape,
+                           SmallVector<SmallVector<Value>> &sgOffsetList) {
+  int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+  if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+    return failure();
+
+  Location loc = op.getLoc();
+  Value tdesc = op.getTensorDesc();
+  auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+  if (!tdescTy)
+    return failure();
+  auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+  if (!layout)
+    return failure();
+
+  SmallVector<int64_t> sgLayout;
+  if (auto sgLayoutAttr = layout.getSgLayout())
+    sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+  else
+    return rewriter.notifyMatchFailure(
----------------
adam-smnk wrote:

nit: I'd convert it into a guard clause: `if (!attr) return;`

https://github.com/llvm/llvm-project/pull/153432


More information about the Mlir-commits mailing list