[Mlir-commits] [mlir] [mlir][xegpu] Add definition of SliceAttr (PR #150146)

Chao Chen llvmlistbot at llvm.org
Fri Aug 8 09:05:21 PDT 2025


================
@@ -211,6 +264,146 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
   return success();
 }
 
+/// Splits `linearId` into one subgroup index per dimension of sg_layout.
+/// Returns failure() for layouts that are not workgroup-level, and emits an
+/// error when a non-default `order` attribute is present (not yet handled).
+FailureOr<SmallVector<Value>>
+LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+                                  Value linearId) {
+  // Subgroup ids only make sense for a workgroup-level layout.
+  if (!isWgLayout())
+    return failure();
+
+  // TODO: handle order attribute
+  // The only accepted `order` is one whose reverse is the identity
+  // permutation; an absent `order` is treated the same way.
+  if (DenseI32ArrayAttr orderAttr = getOrder()) {
+    auto reversedOrder =
+        llvm::to_vector_of<int64_t>(llvm::reverse(orderAttr.asArrayRef()));
+    if (!isIdentityPermutation(reversedOrder))
+      return mlir::emitError(loc,
+                             "order attribute is currently not supported.");
+  }
+
+  // Bind the layout to a named local so the range-for below does not iterate
+  // into a destroyed optional temporary.
+  SmallVector<int64_t> sgLayoutSizes = *getSgLayoutAsInt();
+  SmallVector<Value> basis;
+  basis.reserve(sgLayoutSizes.size());
+  for (int64_t extent : sgLayoutSizes)
+    basis.push_back(builder.createOrFold<arith::ConstantIndexOp>(loc, extent));
+
+  return affine::delinearizeIndex(builder, loc, linearId, basis);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by LayoutAttr.
+/// Returns failure() when the layout is not workgroup-level, when sg_data is
+/// absent and computeShapeRatio(shape, sg_layout) yields no ratio, or when
+/// subgroup-id delinearization fails.
+FailureOr<SmallVector<SmallVector<Value>>>
+LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+                       ArrayRef<int64_t> shape) {
+  if (!isWgLayout())
+    return failure();
+
+  SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+
+  // Per-subgroup data shape: the explicit sg_data when given, otherwise
+  // derived from the full shape divided by sg_layout.
+  SmallVector<int64_t> sgShape;
+  if (auto explicitData = getSgDataAsInt()) {
+    sgShape = explicitData.value();
+  } else {
+    auto derivedData = computeShapeRatio(shape, sgLayout);
+    if (!derivedData)
+      return failure();
+    sgShape = derivedData.value();
+  }
+
+  // Turn the linear subgroup id into per-dimension ids.
+  FailureOr<SmallVector<Value>> subgroupIds =
+      delinearizeSubgroupId(builder, loc, linearId);
+  if (failed(subgroupIds))
+    return failure();
+
+  return genOffsetsComputations(builder, loc, *subgroupIds, sgLayout, sgShape,
+                                shape);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+/// Verifies a SliceAttr. Both `parent` and `dims` must be present, and every
+/// entry of `dims` must lie in [0, parent rank) and appear only once.
+LogicalResult
+SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+                  xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+  if (!(parent && dims))
+    return emitError() << "expected parent layout and dims attribute";
+
+  const int64_t parentRank = parent.getRank();
+  llvm::SmallDenseSet<int64_t> visited;
+  for (int64_t d : dims.asArrayRef()) {
+    // Range is checked before uniqueness, so an out-of-range dim is always
+    // reported as invalid rather than repeated.
+    const bool inRange = d >= 0 && d < parentRank;
+    if (!inRange)
+      return emitError() << "invalid dim (" << d << ") in slice attribute.";
+    const bool firstSeen = visited.insert(d).second;
+    if (!firstSeen)
+      return emitError() << "repeated dim (" << d << ") in slice attribute.";
+  }
+  return success();
+}
+
+/// Collapses a chain of nested SliceAttrs into a single SliceAttr. The result
+/// has the root LayoutAttr as its parent, and its dims are numbered in that
+/// root layout's dimension space.
+SliceAttr SliceAttr::flatten() const {
+  xegpu::LayoutTrait parent = getParent();
+  // Sliced dims, starting with this slice and moving toward the root layout.
+  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
+
+  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
+    parent = sliceAttr.getParent();
+    slicedDims.push_back(sliceAttr.getDims());
+  }
+
+  // Once every SliceAttr is peeled off, the root is expected to be a
+  // LayoutAttr. cast<> asserts that invariant; the previous dyn_cast<> result
+  // was dereferenced without a null check.
+  auto layoutAttr = cast<xegpu::LayoutAttr>(parent);
+  SmallVector<int64_t> indices =
+      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
+
+  // Get the remaining root dims by applying every slice, starting with the one
+  // closest to the root layout.
+  SmallVector<int64_t> remainingDims(indices);
+  for (auto dim : llvm::reverse(slicedDims))
+    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
+                                        dim.asArrayRef());
+
+  // Get the flattened sliced dims by slicing the remaining dims out of the
+  // full index list. indices[i] == i, so positions and dim ids coincide.
+  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
+      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
+
+  return xegpu::SliceAttr::get(
+      getContext(), layoutAttr,
+      DenseI64ArrayAttr::get(getContext(), flattenedDims));
+}
+
+/// Delinearizes `linearId` with the root LayoutAttr of the flattened slice,
+/// producing one id per dimension of that root layout's sg_layout (sliced
+/// dimensions included).
+FailureOr<SmallVector<Value>>
+SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+                                 Value linearId) {
+  // flatten() always builds its result with a LayoutAttr parent. A checked
+  // cast<> states (and asserts) that invariant; the previous code used a
+  // dyn_cast<> result without a null check.
+  auto parent = cast<LayoutAttr>(flatten().getParent());
+  return parent.delinearizeSubgroupId(builder, loc, linearId);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by SliceAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+                      ArrayRef<int64_t> shape) {
+  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+  if (!isWgLayout())
+    return failure();
+
+  SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+  SmallVector<int64_t> sgShape;
+  if (auto maybeSgShape = getSgDataAsInt())
+    sgShape = maybeSgShape.value();
+  else if (auto ratio = computeShapeRatio(shape, sgLayout))
+    sgShape = ratio.value();
----------------
chencha3 wrote:

fixed. 

https://github.com/llvm/llvm-project/pull/150146


More information about the Mlir-commits mailing list