[Mlir-commits] [mlir] [mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (PR #154403)
Chao Chen
llvmlistbot at llvm.org
Wed Aug 20 08:00:28 PDT 2025
================
@@ -137,71 +202,35 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
- auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
- if (!layout)
- return failure();
- Type elemTy = tdescTy.getElementType();
ArrayRef<int64_t> wgShape = tdescTy.getShape();
- // sgLayout must be present for workgroup-level distribution.
- SmallVector<int64_t> sgLayout;
- if (auto sgLayoutAttr = layout.getSgLayout())
- sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
- else
- return rewriter.notifyMatchFailure(
- op, "sgLayout attribute is required in layout");
-
- // Get the subgroup ID
- Value linearSgId =
- gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- int64_t startOfRange = -1, endOfRange = -1;
- bool sgIdRangeSpecified =
- isSgIdRangeSpecified(op, startOfRange, endOfRange);
-
- if (sgIdRangeSpecified) {
- int64_t sgCount = endOfRange - startOfRange;
- if (computeProduct(sgLayout) != sgCount)
- return rewriter.notifyMatchFailure(
- op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get
- // the adjusted sg id
- Value startOfRangeVal =
- arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- linearSgId =
- rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
- }
-
- auto maybeTdescOffsets =
- layout.getOffsets(rewriter, loc, linearSgId, wgShape);
- if (failed(maybeTdescOffsets))
- return failure();
-
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
- xegpu::TensorDescType newTdescTy =
- xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
- layout.dropSgLayoutAndData());
+ Type elemTy = tdescTy.getElementType();
- SmallVector<Value> newCreateNdOps;
- SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
-
- for (auto tdescOffsets : *maybeTdescOffsets) {
- SmallVector<OpFoldResult> sgOffsets;
- size_t rank = tdescOffsets.size();
- for (size_t i = 0; i < rank; i++) {
- size_t idx = origOffsets.size() - rank + i;
- Value add = rewriter.createOrFold<index::AddOp>(
- loc, tdescOffsets[i],
- getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
- sgOffsets.push_back(add);
+ // The callback function for creating new CreateNdOps;
+ // baseOffsets are the original offsets of the op, and
+ // descOffsets are the relative offsets to the mem_desc accessed
+ // by each subgroup op.
+ auto callback = [&](ArrayRef<OpFoldResult> baseOffsets,
----------------
chencha3 wrote:
Where are the scatter op patterns?
https://github.com/llvm/llvm-project/pull/154403
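For reference, a minimal sketch of how such a callback might combine the op's original (workgroup-level) offsets with each subgroup's relative offsets, mirroring the removed loop above. The lambda name and its use of the surrounding pattern's `rewriter` and `loc` are assumptions for illustration, not the PR's exact code:

    // Sketch only: add each subgroup-relative offset to the matching
    // workgroup base offset, aligned on the trailing dimensions.
    auto makeSgOffsets = [&](ArrayRef<OpFoldResult> baseOffsets,
                             ArrayRef<OpFoldResult> descOffsets) {
      SmallVector<OpFoldResult> sgOffsets;
      size_t rank = descOffsets.size();
      for (size_t i = 0; i < rank; ++i) {
        // Align the subgroup-relative offsets with the innermost dims of
        // the original offsets (base may have higher rank).
        size_t idx = baseOffsets.size() - rank + i;
        Value add = rewriter.createOrFold<index::AddOp>(
            loc,
            getValueOrCreateConstantIndexOp(rewriter, loc, descOffsets[i]),
            getValueOrCreateConstantIndexOp(rewriter, loc, baseOffsets[idx]));
        sgOffsets.push_back(add);
      }
      return sgOffsets;
    };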
More information about the Mlir-commits mailing list