[Mlir-commits] [mlir] [mlir][xegpu] Add definition of SliceAttr (PR #150146)
Nishant Patel
llvmlistbot at llvm.org
Wed Aug 6 10:56:12 PDT 2025
================
@@ -177,74 +144,56 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-
- // TODO : Handle order attribute
// Get the subgroup ID
- auto linearSgId =
+ Value linearSgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- // Create constants for layout dimensions
- SmallVector<Value> sgLayoutDim(sgLayout.size());
- SmallVector<Value> sgDataDim(sgShape.size());
-
- for (size_t i = 0; i < sgLayout.size(); i++) {
- sgLayoutDim[i] =
- arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
- sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
- }
-
int64_t startOfRange = -1, endOfRange = -1;
bool sgIdRangeSpecified =
isSgIdRangeSpecified(op, startOfRange, endOfRange);
- Value adjustedSgId = linearSgId;
if (sgIdRangeSpecified) {
int64_t sgCount = endOfRange - startOfRange;
if (computeProduct(sgLayout) != sgCount)
return rewriter.notifyMatchFailure(
op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get the adjusted
- // sg id
+ // Subtract startOfRange from the original subgroup id to get
+ // the adjusted sg id
Value startOfRangeVal =
- arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- adjustedSgId =
+ rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ linearSgId =
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
}
- auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
- if (failed(deLinearizeSgId))
+ auto maybeTdescOffsets =
+ layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+ if (failed(maybeTdescOffsets))
return failure();
- SmallVector<Value> sgIds = *deLinearizeSgId;
-
- // Calculate distribution unit shape and local offsets for subgroup
- SmallVector<int64_t> distUnitShape(sgLayout.size());
- SmallVector<Value> localOffset(sgLayout.size());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
- localOffset[i] =
- rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
- }
-
- SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
xegpu::TensorDescType newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
layout.dropSgLayoutAndData());
+
SmallVector<Value> newCreateNdOps;
- for (SmallVector<int64_t> distUnitBaseAddr :
- StaticTileOffsetRange(wgShape, distUnitShape)) {
- SmallVector<OpFoldResult> globalOffsets =
- calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
- distUnitBaseAddr, distUnitShape);
-
- auto newCreateNdOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
+ SmallVector<OpFoldResult> wgTileOffsets = op.getMixedOffsets();
----------------
nbpatel wrote:
nit: it may be better not to use "tile" in these names and just use `wgOffsets` and `sgOffsets` as the variable names.
https://github.com/llvm/llvm-project/pull/150146
More information about the Mlir-commits
mailing list