[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)

Chao Chen llvmlistbot at llvm.org
Wed Aug 20 08:00:28 PDT 2025


================
@@ -137,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
     Location loc = op.getLoc();
     MLIRContext *ctx = op.getContext();
     xegpu::TensorDescType tdescTy = op.getType();
-    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
-    if (!layout)
-      return failure();
-    Type elemTy = tdescTy.getElementType();
     ArrayRef<int64_t> wgShape = tdescTy.getShape();
-    // sgLayout must be present for workgroup-level distribution.
-    SmallVector<int64_t> sgLayout;
-    if (auto sgLayoutAttr = layout.getSgLayout())
-      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-    else
-      return rewriter.notifyMatchFailure(
-          op, "sgLayout attribute is required in layout");
-
-    // Get the subgroup ID
-    Value linearSgId =
-        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
-    int64_t startOfRange = -1, endOfRange = -1;
-    bool sgIdRangeSpecified =
-        isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
-    if (sgIdRangeSpecified) {
-      int64_t sgCount = endOfRange - startOfRange;
-      if (computeProduct(sgLayout) != sgCount)
-        return rewriter.notifyMatchFailure(
-            op, "sg_layout size must match the sg_id_range");
-      // Subtract startOfRange from the original subgroup id to get
-      // the adjusted sg id
-      Value startOfRangeVal =
-          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
-      linearSgId =
-          rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
-    }
-
-    auto maybeTdescOffsets =
-        layout.getOffsets(rewriter, loc, linearSgId, wgShape);
-    if (failed(maybeTdescOffsets))
-      return failure();
-
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-    xegpu::TensorDescType newTdescTy =
-        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
-                                   layout.dropSgLayoutAndData());
+    Type elemTy = tdescTy.getElementType();
 
-    SmallVector<Value> newCreateNdOps;
-    SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
-    for (auto tdescOffsets : *maybeTdescOffsets) {
-      SmallVector<OpFoldResult> sgOffsets;
-      size_t rank = tdescOffsets.size();
-      for (size_t i = 0; i < rank; i++) {
-        size_t idx = origOffsets.size() - rank + i;
-        Value add = rewriter.createOrFold<index::AddOp>(
-            loc, tdescOffsets[i],
-            getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
-        sgOffsets.push_back(add);
+    // the call back function for creating new CreateNdOps,
+    // the baseOffsets is the origial offsets of the op, and
+    // descOffsets is the relative offsets to the mem_desc accessed
+    // by each subgroup op.
+    auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
----------------
chencha3 wrote:

Where are the patterns for the scatter ops?

https://github.com/llvm/llvm-project/pull/154403


More information about the Mlir-commits mailing list