[Mlir-commits] [mlir] c962234 - [mlir][xegpu] Add definition of SliceAttr (#150146)
llvmlistbot at llvm.org
Fri Aug 8 09:27:20 PDT 2025
Author: Chao Chen
Date: 2025-08-08T11:27:17-05:00
New Revision: c96223434c64d32c3f397d20a8ed1d9749aae441
URL: https://github.com/llvm/llvm-project/commit/c96223434c64d32c3f397d20a8ed1d9749aae441
DIFF: https://github.com/llvm/llvm-project/commit/c96223434c64d32c3f397d20a8ed1d9749aae441.diff
LOG: [mlir][xegpu] Add definition of SliceAttr (#150146)
---------
Co-authored-by: Charitha Saumya <136391709+charithaintc@users.noreply.github.com>
Added:
mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
mlir/test/Dialect/XeGPU/invalid.mlir
mlir/test/Dialect/XeGPU/layout.mlir
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
index 3f8cac4dc07c3..728f1aa859061 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
+mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
+mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
+add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
+add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 8e2784f40ad39..3592da4c46364 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
@@ -22,17 +23,19 @@
namespace mlir {
namespace xegpu {
class TensorDescType;
+class LayoutAttr;
+class SliceAttr;
} // namespace xegpu
} // namespace mlir
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
+#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
-
-#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 64eb21cbc3c4c..1f420c13ebae0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,38 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}
-def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
+def LayoutTrait: AttrInterface<"LayoutTrait"> {
+ let cppNamespace = "::mlir::xegpu";
+ let description = [{
+ Common trait for all XeGPU layouts.
+ }];
+
+ let methods = [
+    InterfaceMethod<"Get the rank of the attribute",
+                    "int64_t",
+                    "getRank">,
+    InterfaceMethod<"Get the SgLayout field of the attribute as an integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getSgLayoutAsInt">,
+    InterfaceMethod<"Get the SgData field of the attribute as an integer array",
+                    "std::optional<SmallVector<int64_t>>",
+                    "getSgDataAsInt">,
+ InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
+ indices based on the effective subgroup layout.}],
+ "FailureOr<SmallVector<Value>>",
+ "delinearizeSubgroupId",
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
+ InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
+ assigned to a subgroup identified by linearId. The shape parameter
+ represents the workgroup-level problem size. Each subgroup may access
+ multiple blocks according to round-robin distribution rules.}],
+ "FailureOr<SmallVector<SmallVector<Value>>>",
+ "getOffsets",
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
+ ];
+}
+
+def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
@@ -330,12 +361,143 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
getLaneLayout(), getLaneData(), getOrder());
}
+
+ std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
+ if (DenseI32ArrayAttr layout = getSgLayout())
+ return llvm::to_vector_of<int64_t>(layout.asArrayRef());
+ return std::nullopt;
+ }
+
+ std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
+ if (DenseI32ArrayAttr data = getSgData())
+ return llvm::to_vector_of<int64_t>(data.asArrayRef());
+ return std::nullopt;
+ }
+
+ /// Delinearizes a linear subgroup ID into its multidimensional indices
+ /// based on the effective subgroup layout.
+ FailureOr<SmallVector<Value>>
+ delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+
+ /// Generates instructions to compute multidimensional offsets for blocks
+ /// assigned to a subgroup identified by linearId. The shape parameter
+ /// represents the workgroup-level problem size. Each subgroup may access
+ /// multiple blocks according to round-robin distribution rules.
+ FailureOr<SmallVector<SmallVector<Value>>>
+ getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+
}];
let assemblyFormat = "`<` struct(params) `>`";
let genVerifyDecl = 1;
}
+
+def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
+ let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
+
+ let description = [{
+    Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
+    However, whereas LayoutAttr requires the data to have the same rank as the attribute,
+    SliceAttr permits the data to have a lower rank. In this case, compute units along the
+    sliced dimensions (given by `$dims`) share the data, provided that the rank of the
+    remaining dimensions matches the data rank. SliceAttr is commonly used by operations
+    such as vector.multi_reduction and vector.broadcast.
+
+ Example:
+ ```
+ #l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
+    #r = #xegpu.slice<#l, dims = [0]>
+
+ %exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
+ %red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
+ %bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
+ ```
+ In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to
+ a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a
+ single reduction result of type vector<32xf32>.
+
+ }];
+
+ let parameters = (ins
+ "xegpu::LayoutTrait": $parent,
+ "DenseI64ArrayAttr": $dims
+ );
+
+ let extraClassDeclaration = [{
+
+ int64_t getRank() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.getRank() - attr.getDims().size();
+ }
+
+ DenseI32ArrayAttr getOrder() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.getOrder();
+ }
+
+ bool isWgLayout() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.isWgLayout();
+ }
+
+ bool isSgLayout() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.isSgLayout();
+ }
+
+ /// Returns the SgLayout of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ if (auto layout = parent.getSgLayoutAsInt()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
+ }
+ return std::nullopt;
+ }
+
+ /// Returns the SgData of the attribute, computed by applying
+ /// the slice dimensions to the underlying LayoutAttr.
+ std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ if (auto data = parent.getSgDataAsInt()) {
+ ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
+ return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
+ }
+ return std::nullopt;
+ }
+
+    /// Flattens a nested SliceAttr; e.g., for the 2-level nested SliceAttr
+    /// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
+    /// it coalesces the two slice operations and returns the simplified SliceAttr
+    /// #xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0, 1]>
+ SliceAttr flatten() const;
+
+ /// Delinearizes a linear subgroup ID into its multidimensional indices
+ /// based on the effective subgroup layout.
+ FailureOr<SmallVector<Value>>
+ delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+
+ /// Generates instructions to compute multidimensional offsets for blocks
+ /// assigned to a subgroup identified by linearId. The shape parameter
+ /// represents the workgroup-level problem size. Each subgroup may access
+ /// multiple blocks according to round-robin distribution rules.
+ FailureOr<SmallVector<SmallVector<Value>>>
+ getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+
+ }];
+
+ let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
+ let genVerifyDecl = 1;
+}
+
def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
let summary = [{Specifies a half-open range}];
let description = [{
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index 549018b61d6fb..76d58e5ea2424 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
/// Checks if the given shape can be evenly distributed based on the layout
/// and data factors provided by the LayoutAttr.
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
+
+    /// Drops the entries of `shape` at the specified `dims` and returns the rest;
+    /// e.g., for shape = [32, 64, 8] and dims = [0, 2], it returns [64].
+ template<typename T, typename U>
+ static llvm::SmallVector<T> slice(llvm::ArrayRef<T> shape, llvm::ArrayRef<U> dims) {
+ llvm::SmallVector<T> result;
+ for (auto [i, v]: llvm::enumerate(shape)) {
+ if (!llvm::is_contained(dims, i))
+ result.push_back(v);
+ }
+ return result;
+ }
}];
}
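As a point of reference, the new `slice` helper and `SliceAttr::flatten` compose as follows; this is a minimal standalone sketch (plain C++, no MLIR dependencies; the `main` driver and names are illustrative only):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors XeGPUDialect::slice: drop the entries of `shape` whose position
// appears in `dims`, keeping the rest in their original order.
static std::vector<int64_t> slice(const std::vector<int64_t> &shape,
                                  const std::vector<int64_t> &dims) {
  std::vector<int64_t> result;
  for (int64_t i = 0, e = shape.size(); i < e; ++i)
    if (std::find(dims.begin(), dims.end(), i) == dims.end())
      result.push_back(shape[i]);
  return result;
}

int main() {
  // The doc-comment example: shape = [32, 64, 8], dims = [0, 2] -> [64].
  assert((slice({32, 64, 8}, {0, 2}) == std::vector<int64_t>{64}));

  // Coalescing nested slices the way SliceAttr::flatten does, for
  // slice<slice<layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>:
  // apply the slices innermost-first to the identity index list [0, 1, 2],
  std::vector<int64_t> remaining = {0, 1, 2};
  remaining = slice(remaining, {0}); // inner slice -> [1, 2]
  remaining = slice(remaining, {0}); // outer slice -> [2]
  // then the flattened dims are the positions sliced away: [0, 1].
  assert((slice({0, 1, 2}, remaining) == std::vector<int64_t>{0, 1}));
  return 0;
}
```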
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97ccfdf6d..7c6a4f37db9af 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,11 +7,14 @@ add_mlir_dialect_library(MLIRXeGPUDialect
DEPENDS
MLIRXeGPUIncGen
+ MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRIndexDialect
+ MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
MLIRIR
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3c0ca114a62d4..d997296a22c20 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,12 +6,16 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
using std::optional;
@@ -33,6 +37,57 @@ void XeGPUDialect::initialize() {
>();
}
+/// Generates instructions to compute offsets for a subgroup identified by
+/// its multidimensional indices (sgId), using the specified subgroup layout
+/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
+/// dimensions (sizePerWg).
+static SmallVector<SmallVector<Value>>
+genOffsetsComputingInsts(OpBuilder &builder, Location loc,
+ SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
+ ArrayRef<int64_t> sizePerSg,
+ ArrayRef<int64_t> sizePerWg) {
+
+ SmallVector<SmallVector<Value>> offsets;
+
+  // n-D local offsets: localOffset[i] = sgId[i] * sizePerSg[i]
+ SmallVector<Value> localOffsets = llvm::map_to_vector(
+ llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::MulOp>(
+ loc, std::get<0>(t),
+ builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+ });
+
+  // distUnit[i] is the minimum of sizePerWg[i] and sgLayout[i] * sizePerSg[i].
+ SmallVector<int64_t> distUnit = llvm::map_to_vector(
+ llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+ for (SmallVector<int64_t> unitOffs :
+ StaticTileOffsetRange(sizePerWg, distUnit)) {
+ SmallVector<Value> base =
+ llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
+ return builder.create<arith::ConstantIndexOp>(loc, d);
+ });
+
+ SmallVector<Value> adds = llvm::map_to_vector(
+ llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
+ std::get<1>(t));
+ });
+
+ SmallVector<Value> mods = llvm::map_to_vector(
+ llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::RemUOp>(
+ loc, std::get<0>(t),
+ builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+ });
+
+ offsets.push_back(mods);
+ }
+ return offsets;
+}
+
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
@@ -211,6 +266,148 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return success();
}
+FailureOr<SmallVector<Value>>
+LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+  // delinearizeSubgroupId is only available for workgroup-level layout
+  // attributes.
+ if (!isWgLayout())
+ return failure();
+
+ // TODO: handle order attribute
+ auto hasDefaultOrder = [&]() {
+ DenseI32ArrayAttr order = getOrder();
+ return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
+ llvm::reverse(order.asArrayRef())));
+ };
+ if (!hasDefaultOrder())
+ return mlir::emitError(loc, "order attribute is currently not supported.");
+
+ auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
+
+ return affine::delinearizeIndex(builder, loc, linearId, dims);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by LayoutAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ if (!isWgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+ SmallVector<int64_t> sgShape;
+ if (auto maybeSgShape = getSgDataAsInt())
+ sgShape = maybeSgShape.value();
+ else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+ SmallVector<Value> sgIds = *maybeIds;
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+ if (!parent || !dims)
+ return emitError() << "expected parent layout and dims attribute";
+
+ int64_t rank = parent.getRank();
+
+ // check every element in dims is unique and smaller than rank
+ llvm::SmallDenseSet<int64_t> seen;
+ for (int64_t dim : dims.asArrayRef()) {
+ if (dim < 0 || dim >= rank)
+ return emitError() << "invalid dim (" << dim << ") in slice attribute.";
+ if (!seen.insert(dim).second)
+ return emitError() << "repeated dim (" << dim << ") in slice attribute.";
+ }
+ return success();
+}
+
+SliceAttr SliceAttr::flatten() const {
+ xegpu::LayoutTrait parent = getParent();
+ SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
+
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
+ parent = sliceAttr.getParent();
+ slicedDims.push_back(sliceAttr.getDims());
+ }
+
+ auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
+ SmallVector<int64_t> indices =
+ llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
+
+  // Get the remaining (flattened) dims by applying the collected slices,
+  // innermost first.
+ SmallVector<int64_t> remainingDims(indices);
+ for (auto dim : llvm::reverse(slicedDims))
+ remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
+ dim.asArrayRef());
+
+  // Get the flattened sliced dims by slicing away the remaining dims from
+  // the identity index list.
+  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
+      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
+
+  return xegpu::SliceAttr::get(
+      getContext(), layoutAttr,
+      DenseI64ArrayAttr::get(getContext(), flattenedDims));
+}
+
+FailureOr<SmallVector<Value>>
+SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.delinearizeSubgroupId(builder, loc, linearId);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by SliceAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+ if (!isWgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+ SmallVector<int64_t> sgShape;
+ if (auto maybeSgShape = getSgDataAsInt())
+ sgShape = maybeSgShape.value();
+ else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+
+  // The effective sgIds for offset computation correspond to the dims
+  // that are not sliced.
+ ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
+ SmallVector<Value> sgIds =
+ XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
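To make the round-robin offset arithmetic concrete, here is an integer-only analogue of `genOffsetsComputingInsts` (the real function emits `index`/`arith` ops; this hedged sketch evaluates the same formula for a fixed 2-D case, with illustrative names and shapes):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Integer analogue of genOffsetsComputingInsts: for the subgroup with
// multidimensional id `sgId`, enumerate the offsets of every block it
// accesses under round-robin distribution.
static std::vector<std::vector<int64_t>>
computeOffsets(const std::vector<int64_t> &sgId,
               const std::vector<int64_t> &sgLayout,
               const std::vector<int64_t> &sizePerSg,
               const std::vector<int64_t> &sizePerWg) {
  size_t rank = sgId.size();
  // localOffset[i] = sgId[i] * sizePerSg[i]
  // distUnit[i] = min(sizePerWg[i], sgLayout[i] * sizePerSg[i])
  std::vector<int64_t> localOffset(rank), distUnit(rank);
  for (size_t i = 0; i < rank; ++i) {
    localOffset[i] = sgId[i] * sizePerSg[i];
    distUnit[i] = std::min(sizePerWg[i], sgLayout[i] * sizePerSg[i]);
  }
  // Walk the workgroup shape in steps of distUnit (what
  // StaticTileOffsetRange does), add the local offset, and wrap with modulo.
  // For brevity this handles only the 2-D case.
  std::vector<std::vector<int64_t>> offsets;
  for (int64_t y = 0; y < sizePerWg[0]; y += distUnit[0])
    for (int64_t x = 0; x < sizePerWg[1]; x += distUnit[1])
      offsets.push_back({(y + localOffset[0]) % sizePerWg[0],
                         (x + localOffset[1]) % sizePerWg[1]});
  return offsets;
}

int main() {
  // sg_layout = [2, 2], sg_data = [16, 16], workgroup shape = [32, 64]:
  // distUnit = [32, 32], so each subgroup owns two blocks (round-robin in x).
  for (auto &offs : computeOffsets({1, 0}, {2, 2}, {16, 16}, {32, 64}))
    std::printf("[%lld, %lld]\n", (long long)offs[0], (long long)offs[1]);
  // Prints [16, 0] and [16, 32].
  return 0;
}
```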
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..fc11fa810a35c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -931,6 +931,9 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
} // namespace xegpu
} // namespace mlir
+namespace mlir {
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
+} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70cca288f..4a5525c8abb30 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -125,39 +125,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
- // Calculate offset for each subgroup
- static SmallVector<OpFoldResult>
- calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
- const SmallVector<OpFoldResult> &originalOffsets,
- const SmallVector<Value> &localOffset,
- const SmallVector<int64_t> &distUnitBaseAddr,
- const SmallVector<int64_t> &distUnitShape) {
- assert(localOffset.size() == distUnitBaseAddr.size() &&
- "localOffset and distUnitBaseAddr must have the same rank");
-
- SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
- originalOffsets.end());
- size_t rank = localOffset.size();
- for (size_t i = 0; i < rank; ++i) {
- size_t dimIdx = originalOffsets.size() - rank + i;
- Value constOffset =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]);
- Value offset =
- rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
- Value modValue =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]);
- Value offsetMod =
- rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
- Value origOffset = getValueOrCreateConstantIndexOp(
- rewriter, loc, originalOffsets[dimIdx]);
- Value globalOffset =
- rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
- globalOffsets[dimIdx] = globalOffset;
- }
-
- return globalOffsets;
- }
-
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -177,74 +144,56 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-
- // TODO : Handle order attribute
// Get the subgroup ID
- auto linearSgId =
+ Value linearSgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- // Create constants for layout dimensions
- SmallVector<Value> sgLayoutDim(sgLayout.size());
- SmallVector<Value> sgDataDim(sgShape.size());
-
- for (size_t i = 0; i < sgLayout.size(); i++) {
- sgLayoutDim[i] =
- arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
- sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
- }
-
int64_t startOfRange = -1, endOfRange = -1;
bool sgIdRangeSpecified =
isSgIdRangeSpecified(op, startOfRange, endOfRange);
- Value adjustedSgId = linearSgId;
if (sgIdRangeSpecified) {
int64_t sgCount = endOfRange - startOfRange;
if (computeProduct(sgLayout) != sgCount)
return rewriter.notifyMatchFailure(
op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get the adjusted
- // sg id
+ // Subtract startOfRange from the original subgroup id to get
+ // the adjusted sg id
Value startOfRangeVal =
- arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- adjustedSgId =
+ rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ linearSgId =
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
}
- auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
- if (failed(deLinearizeSgId))
+ auto maybeTdescOffsets =
+ layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+ if (failed(maybeTdescOffsets))
return failure();
- SmallVector<Value> sgIds = *deLinearizeSgId;
-
- // Calculate distribution unit shape and local offsets for subgroup
- SmallVector<int64_t> distUnitShape(sgLayout.size());
- SmallVector<Value> localOffset(sgLayout.size());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
- localOffset[i] =
- rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
- }
-
- SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
xegpu::TensorDescType newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
layout.dropSgLayoutAndData());
+
SmallVector<Value> newCreateNdOps;
- for (SmallVector<int64_t> distUnitBaseAddr :
- StaticTileOffsetRange(wgShape, distUnitShape)) {
- SmallVector<OpFoldResult> globalOffsets =
- calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
- distUnitBaseAddr, distUnitShape);
-
- auto newCreateNdOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
+ SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets();
+
+ for (auto tdescOffsets : *maybeTdescOffsets) {
+ SmallVector<OpFoldResult> sgOffsets;
+ size_t rank = tdescOffsets.size();
+ for (size_t i = 0; i < rank; i++) {
+ size_t idx = wgOffsets.size() - rank + i;
+ Value add = rewriter.createOrFold<index::AddOp>(
+ loc, tdescOffsets[i],
+ getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx]));
+ sgOffsets.push_back(add);
+ }
+
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
op.getMixedSizes(), op.getMixedStrides());
- newCreateNdOps.push_back(newCreateNdOp);
+ newCreateNdOps.push_back(newOp);
}
-
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
return success();
}
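The net effect of the rewrite above is that per-subgroup tile offsets now come from the layout attribute itself, and the pattern merely folds them into the op's original workgroup-level offsets on the trailing dimensions. A scalar sketch of that final combination step (illustrative names; plain integers stand in for index values):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar analogue of the sgOffsets loop above: `wgOffsets` stands in for
// op.getMixedOffsets(), `tdescOffsets` for one entry of *maybeTdescOffsets.
static std::vector<int64_t>
combineOffsets(const std::vector<int64_t> &wgOffsets,
               const std::vector<int64_t> &tdescOffsets) {
  std::vector<int64_t> sgOffsets;
  size_t rank = tdescOffsets.size();
  for (size_t i = 0; i < rank; ++i) {
    // Align the trailing dims of wgOffsets with tdescOffsets.
    size_t idx = wgOffsets.size() - rank + i;
    sgOffsets.push_back(tdescOffsets[i] + wgOffsets[idx]);
  }
  return sgOffsets;
}
```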
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..948f136d78709 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -743,3 +743,22 @@ func.func @tensor_desc_invalid_sg_data(%src: ui64, %offsets: vector<16xindex>) {
#xegpu.layout<lane_layout = [8, 1], lane_data = [1, 2], order = [0, 1, 2]>>
return
}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{repeated dim (2) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [2, 2]>
+func.func @slice_attr_repeat_dim() {
+  %offsets = arith.constant {layout_result_0 = #s} dense<8> : vector<16x8xindex>
+ return
+}
+
+// -----
+#l = #xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>
+// expected-error@+1 {{invalid dim (3) in slice attribute}}
+#s = #xegpu.slice<#l, dims = [3]>
+func.func @slice_attr_invalid_dim() {
+  %offsets = arith.constant {layout_result_0 = #s} dense<8> : vector<16x8xindex>
+ return
+}
+
diff --git a/mlir/test/Dialect/XeGPU/layout.mlir b/mlir/test/Dialect/XeGPU/layout.mlir
index 017dacc8d629a..e4b4e22e5cf97 100644
--- a/mlir/test/Dialect/XeGPU/layout.mlir
+++ b/mlir/test/Dialect/XeGPU/layout.mlir
@@ -50,4 +50,27 @@ gpu.func @convert_layout_wg(%a: vector<32x64xf16>) {
gpu.return
}
+gpu.func @slice_attr() {
+ //CHECK: arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
+ gpu.return
+}
+
+gpu.func @nested_slice_attr() {
+ //CHECK: arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>, dims = [1]>} dense<8> : vector<16xindex>
+ %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>, dims = [1]>} dense<8> : vector<16xindex>
+ gpu.return
+}
+
+gpu.func @softmax_dim_0(%arg0: vector<256x128xf32>) -> vector<256x128xf32> {
+ %cst = arith.constant dense<0.000000e+00> : vector<128xf32>
+ %0 = math.exp %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xf32>
+ //CHECK: vector.multi_reduction <add>, {{.*}} {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
+ %1 = vector.multi_reduction <add>, %0, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0] : vector<256x128xf32> to vector<128xf32>
+ //CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<128xf32> to vector<256x128xf32>
+ %2 = vector.broadcast %1 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<128xf32> to vector<256x128xf32>
+ %3 = arith.divf %0, %2 {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xf32>
+ gpu.return %3 : vector<256x128xf32>
+}
+
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
new file mode 100644
index 0000000000000..547c7355e00c6
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/xegpu-attr-interface.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
+
+//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 8)>
+gpu.module @test {
+ gpu.func @slice_attr() -> vector<128xindex> {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>, dims = [1]>}: vector<128xindex>
+ gpu.return %step : vector<128xindex>
+ }
+
+ gpu.func @nested_slice_attr() -> vector<128xindex> {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
+ //CHECK: [[c32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
+ //CHECK: [[c0:%.+]] = arith.constant 0 : index
+ //CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
+ //CHECK: [[c128:%.+]] = arith.constant 128 : index
+ //CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
+ //CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
+ //CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
+ //CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
+ %0 = vector.step {layout_result_0 = #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 1], sg_data = [32, 32, 1]>, dims = [2]>, dims = [1]>} : vector<128xindex>
+ gpu.return %0 : vector<128xindex>
+ }
+
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 628a4857d1253..e5cc65e6bd3d7 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -1,5 +1,8 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
+#map = affine_map<()[s0] -> (s0 floordiv 4)>
+#map1 = affine_map<()[s0] -> (s0 mod 4)>
+
gpu.module @test_round_robin_assignment {
// CHECK-LABEL: create_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -12,6 +15,30 @@ gpu.module @test_round_robin_assignment {
gpu.return
}
+ // CHECK-LABEL: create_nd_tdesc_with_shared_data
+ // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
+ gpu.func @create_nd_tdesc_with_shared_data(%src: memref<256x128xf32>) {
+ //CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[IdY:%.+]] = affine.apply #map()[[[sgId]]]
+ //CHECK: [[IdX:%.+]] = affine.apply #map1()[[[sgId]]]
+ //CHECK: [[C16:%.+]] = arith.constant 16 : index
+ //CHECK: [[LY:%.+]] = index.mul [[IdY]], [[C16]]
+ //CHECK: [[C64:%.+]] = arith.constant 64 : index
+ //CHECK: [[LX:%.+]] = index.mul [[IdX]], [[C64]]
+ //CHECK: [[C0:%.+]] = arith.constant 0 : index
+ //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[ADDY:%.+]] = arith.addi [[LY]], [[C0]] : index
+ //CHECK: [[ADDX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+ //CHECK: [[C128:%.+]] = arith.constant 128 : index
+ //CHECK: [[offY:%.+]] = index.remu [[ADDY]], [[C128]]
+ //CHECK: [[C64_2:%.+]] = arith.constant 64 : index
+ //CHECK: [[offX:%.+]] = index.remu [[ADDX]], [[C64_2]]
+ //CHECK: xegpu.create_nd_tdesc [[ARG_0]][[[offY]], [[offX]]] : memref<256x128xf32> -> !xegpu.tensor_desc<16x64xf32>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64]>>
+ gpu.return
+ }
+
// CHECK-LABEL: load_nd_tdesc
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index d4b00372bc193..180ba8a162c9f 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,34 +4,26 @@
//CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_1_1_assignment {
// CHECK-LABEL: create_nd_tdesc
- // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+ // CHECK-SAME: [[ARG_0:%.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
- // CHECK: %[[SGID:.*]] = gpu.subgroup_id
- // CHECK: %[[C8:.*]] = arith.constant 8 : index
- // CHECK: %[[C32:.*]] = arith.constant 32 : index
- // CHECK: %[[C4:.*]] = arith.constant 4 : index
- // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
- // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
- // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
- // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
- // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
- // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[C256:.*]] = arith.constant 256 : index
- // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
- // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
- // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
- // CHECK: %[[C128:.*]] = arith.constant 128 : index
- // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
- // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
- // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
- // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
- // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
- // CHECK: gpu.return
- %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
- -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
- gpu.return
+ //CHECK: [[SGID:%.+]] = gpu.subgroup_id : index
+ //CHECK: [[SGIDY:%.+]] = affine.apply #map()[[[SGID]]]
+ //CHECK: [[SGIDX:%.+]] = affine.apply #map1()[[[SGID]]]
+ //CHECK: [[C32:%.+]] = arith.constant 32 : index
+ //CHECK: [[LY:%.+]] = index.mul [[SGIDY]], [[C32]]
+ //CHECK: [[LX:%.+]] = index.mul [[SGIDX]], [[C32]]
+ //CHECK: [[C0:%.+]] = arith.constant 0 : index
+ //CHECK: [[C0_1:%.+]] = arith.constant 0 : index
+ //CHECK: [[UY:%.+]] = arith.addi [[LY]], [[C0]] : index
+ //CHECK: [[UX:%.+]] = arith.addi [[LX]], [[C0_1]] : index
+ //CHECK: [[C256:%.+]] = arith.constant 256 : index
+ //CHECK: [[Y:%.+]] = index.remu [[UY]], [[C256]]
+ //CHECK: [[C128:%.+]] = arith.constant 128 : index
+ //CHECK: [[X:%.+]] = index.remu [[UX]], [[C128]]
+ //CHECK: [[TDESC:%.+]] = xegpu.create_nd_tdesc [[ARG_0]][[[Y]], [[X]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+ -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
+ gpu.return
}
// CHECK-LABEL: load_nd_tdesc
@@ -347,7 +339,7 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
// CHECK-LABEL: @subgroup_id_range_nested_if
gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
%sg_id = gpu.subgroup_id : index
- %c1 = arith.constant 1 : i1
+ %c1 = arith.constant 1 : i1
%c3 = arith.constant 3 : index
%c32 = arith.constant 32 : index
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index c6245b637c2a7..3bea8efcdb0ae 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -7,11 +7,14 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
@@ -147,12 +150,118 @@ struct TestXeGPUUnrollingPatterns
}
};
+#undef DEBUG_TYPE
+#define DEBUG_TYPE "test-xegpu-layout-interface"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
+
+// Test pattern that distributes vector::StepOp from workgroup to subgroup.
+// It validates that the LayoutTrait interface abstracts offset computation
+// uniformly over LayoutAttr and SliceAttr.
+class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
+ using OpConversionPattern<vector::StepOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ auto layoutName = xegpu::getLayoutName(op->getResult(0));
+ auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+ if (!sliceAttr || sliceAttr.getRank() != 1)
+ return failure();
+
+ std::optional<SmallVector<int64_t>> sgShape = sliceAttr.getSgDataAsInt();
+ if (!sgShape)
+ return failure();
+
+ Location loc = op.getLoc();
+ VectorType type = op.getResult().getType();
+ auto wgShape = type.getShape();
+
+ Value sgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+ auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+ if (failed(maybeOffsets))
+ return failure();
+
+ VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
+ Value base = vector::StepOp::create(rewriter, loc, newTy);
+ SmallVector<Value> newOps;
+ for (auto offsets : *maybeOffsets) {
+ Value bcast =
+ vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
+ Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
+ newOps.push_back(add);
+ }
+ rewriter.replaceOpWithMultiple(op, {newOps});
+ return success();
+ }
+};
+
+struct TestXeGPULayoutInterface
+ : public PassWrapper<TestXeGPULayoutInterface,
+ OperationPass<gpu::GPUModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPULayoutInterface)
+
+ StringRef getArgument() const final { return "test-xegpu-layout-interface"; }
+
+ StringRef getDescription() const final {
+ return "Test the implementation of XeGPU Layout interfaces";
+ }
+
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect>();
+ registry.insert<memref::MemRefDialect>();
+ registry.insert<xegpu::XeGPUDialect>();
+ registry.insert<vector::VectorDialect>();
+ registry.insert<index::IndexDialect>();
+ }
+
+ TestXeGPULayoutInterface() = default;
+ TestXeGPULayoutInterface(const TestXeGPULayoutInterface &pass)
+ : PassWrapper(pass) {}
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+
+ TypeConverter typeConverter;
+ auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
+ mlir::ValueRange inputs,
+ mlir::Location loc) -> mlir::Value {
+ return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+ .getResult(0);
+ };
+ typeConverter.addSourceMaterialization(materializeCast);
+ typeConverter.addTargetMaterialization(materializeCast);
+
+ RewritePatternSet patterns(ctx);
+ patterns.add<TestStepOpPattern>(typeConverter, ctx);
+
+ ConversionTarget target(*ctx);
+ auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
+ return !layout || !layout.isWgLayout();
+ };
+
+ target.addDynamicallyLegalOp<vector::StepOp>(
+ [&](vector::StepOp op) -> bool {
+ auto layoutName = xegpu::getLayoutName(op->getResult(0));
+ auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
+ return isLegal(sliceAttr);
+ });
+
+ target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
+
+ (void)applyPartialConversion(getOperation(), target, std::move(patterns));
+ }
+};
+
} // namespace
namespace mlir {
namespace test {
void registerTestXeGPULowerings() {
PassRegistration<TestXeGPUUnrollingPatterns>();
+ PassRegistration<TestXeGPULayoutInterface>();
}
} // namespace test
} // namespace mlir
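A closing note on the interface itself: with both attributes behind LayoutTrait, distribution code no longer needs to know which concrete layout attribute it holds. A hedged sketch of a generic consumer (the helper name `distributeBy` is hypothetical; `builder`, `loc`, `sgId`, and `wgShape` are assumed to be in scope as in the patterns above):

```cpp
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

using namespace mlir;

// Sketch only: LayoutAttr and SliceAttr both implement LayoutTrait, so a
// consumer can compute per-subgroup offsets without caring which one it has.
static FailureOr<SmallVector<SmallVector<Value>>>
distributeBy(Attribute attr, OpBuilder &builder, Location loc, Value sgId,
             ArrayRef<int64_t> wgShape) {
  auto layout = dyn_cast<xegpu::LayoutTrait>(attr);
  if (!layout)
    return failure();
  // getOffsets itself returns failure for non-workgroup layouts.
  return layout.getOffsets(builder, loc, sgId, wgShape);
}
```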