[Mlir-commits] [mlir] 68c4c83 - [MLIR][XeGPU] Matrix load/store subgroup distribution (#165008)
llvmlistbot at llvm.org
Mon Nov 3 12:48:31 PST 2025
Author: Artem Kroviakov
Date: 2025-11-03T21:48:27+01:00
New Revision: 68c4c83bcbf9612a02074b946fe6bb73054183ef
URL: https://github.com/llvm/llvm-project/commit/68c4c83bcbf9612a02074b946fe6bb73054183ef
DIFF: https://github.com/llvm/llvm-project/commit/68c4c83bcbf9612a02074b946fe6bb73054183ef.diff
LOG: [MLIR][XeGPU] Matrix load/store subgroup distribution (#165008)
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
mlir/test/Dialect/XeGPU/invalid.mlir
mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 1481859e94a92..0c059967bb898 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -30,9 +30,11 @@ class SliceAttr;
} // namespace xegpu
} // namespace mlir
+// clang-format off
+#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+// clang-format on
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 40352b44b6441..9c35c07a7e587 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -223,17 +223,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout by dropping InstData",
"xegpu::DistributeLayoutAttr",
"dropInstData">,
- InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
- indices based on the effective subgroup layout.}],
+ InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
+ indices based on the effective layout level.}],
"FailureOr<SmallVector<Value>>",
- "delinearizeSubgroupId",
+ "delinearizeId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
- InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
- assigned to a subgroup identified by linearId. The shape parameter
- represents the workgroup-level problem size. Each subgroup may access
+ InterfaceMethod<[{Generates instructions to compute multidimensional coordinates for dist units
+ assigned to a level identified by linearId. The shape parameter
+ represents the higher-level problem size. Each level may access
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
- "getOffsets",
+ "computeDistributedCoords",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
to some other layout according to given permutation of (0...n-1).}],
@@ -476,17 +476,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return {};
}
- /// Delinearizes a linear subgroup ID into its multidimensional indices
- /// based on the effective subgroup layout.
+ /// Delinearizes a linear ID into its multidimensional indices
+ /// based on the effective level of the layout.
FailureOr<SmallVector<Value>>
- delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+ delinearizeId(OpBuilder &builder, Location loc, Value linearId);
- /// Generates instructions to compute multidimensional offsets for blocks
- /// assigned to a subgroup identified by linearId. The shape parameter
- /// represents the workgroup-level problem size. Each subgroup may access
+ /// Generates instructions to compute multidimensional coordinates for dist units
+ /// assigned to a level identified by linearId. The shape parameter
+ /// represents the higher-level problem size. Each `level` may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
- getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+ computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +643,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
FailureOr<SmallVector<Value>>
- delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+ delinearizeId(OpBuilder &builder, Location loc, Value linearId);
- /// Generates instructions to compute multidimensional offsets for blocks
+ /// Generates instructions to compute multidimensional coordinates for blocks
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// multiple blocks according to round-robin distribution rules.
+
FailureOr<SmallVector<SmallVector<Value>>>
- getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+ computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index b7af5413669c9..eb05628d4772b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
The pass distributes subgroup level (SIMD) XeGPU ops to work items.
}];
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
- "vector::VectorDialect"];
+ "vector::VectorDialect", "index::IndexDialect"];
}
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..de552ce22ef62 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());
+ if (valOrResVecTy.getShape().size() != 1)
+ return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 83406c8c75dcf..397107b786c9e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,55 +37,61 @@ void XeGPUDialect::initialize() {
>();
}
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each of shape
+// `subShapesLayout` x `subShape`. A `delinearizedId` identifies a particular
+// `subShape` within each distribution unit.
+// Example:
+// WG data is 128x256. SG data is 16x32 in a 4x2 layout, which gives a
+// distribution unit of shape 64x64, so there are 2x4 such distribution units.
+// `delinearizedId` identifies the 16x32 block of a subgroup in each
+// distribution unit.
static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
- SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
- ArrayRef<int64_t> sizePerSg,
- ArrayRef<int64_t> sizePerWg) {
-
- SmallVector<SmallVector<Value>> offsets;
+genCoordinates(OpBuilder &builder, Location loc,
+ SmallVector<Value> delinearizedId,
+ ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+ ArrayRef<int64_t> srcShape) {
+ SmallVector<SmallVector<Value>> coordinates;
+
+  // A distribution unit must not exceed `srcShape` in any dimension.
+ SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+ llvm::zip_equal(srcShape,
+ computeElementwiseMul(subShapesLayout, subShape)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
- SmallVector<Value> localOffsets = llvm::map_to_vector(
- llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ // Get the offset of `subShape` within a distribution unit.
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+ llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::MulOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
- // distUnit[i] is the minimum value between sizePerWg[i] and
- // sgLayout[i] * sizePerSg[i]
- SmallVector<int64_t> distUnit = llvm::map_to_vector(
- llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
- [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+ // For each dist unit
for (SmallVector<int64_t> unitOffs :
- StaticTileOffsetRange(sizePerWg, distUnit)) {
+ StaticTileOffsetRange(srcShape, distUnitShape)) {
+ // Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});
-
- SmallVector<Value> adds = llvm::map_to_vector(
- llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
- return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
- std::get<1>(t));
- });
-
+ // Calculate `subShape` offset within `srcShape`.
+ SmallVector<Value> adds =
+ llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+ [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(
+ loc, std::get<0>(t), std::get<1>(t));
+ });
+ // Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
- llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
});
- offsets.push_back(mods);
+ coordinates.push_back(mods);
}
- return offsets;
+ return coordinates;
}
// Checks if the given shape can be evenly distributed based on the layout
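For reference, the distribution-unit arithmetic documented above can be reproduced with a small standalone C++ sketch (illustrative only, not part of this patch; it treats dimensions independently, whereas genCoordinates enumerates all n-D unit offsets via StaticTileOffsetRange):

// Standalone sketch of the genCoordinates arithmetic for the example above:
// WG data 128x256, sg_data 16x32, sg_layout 4x2. Values are assumptions for
// illustration, not taken from the patch.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> srcShape = {128, 256}; // workgroup-level data shape
  std::vector<int64_t> layout   = {4, 2};     // sg_layout
  std::vector<int64_t> subShape = {16, 32};   // sg_data
  std::vector<int64_t> id       = {1, 1};     // delinearized subgroup id

  for (size_t i = 0; i < srcShape.size(); ++i) {
    // distUnit[i] = min(srcShape[i], layout[i] * subShape[i])
    int64_t distUnit = std::min(srcShape[i], layout[i] * subShape[i]);
    int64_t numUnits = srcShape[i] / distUnit;
    // Offset of this subgroup's subShape inside one distribution unit.
    int64_t localOff = id[i] * subShape[i];
    std::cout << "dim " << i << ": dist unit " << distUnit << ", units "
              << numUnits << ", coords";
    // One coordinate per distribution unit, wrapped back into srcShape.
    for (int64_t base = 0; base < srcShape[i]; base += distUnit)
      std::cout << " " << (base + localOff) % srcShape[i];
    std::cout << "\n";
  }
  return 0;
}
// Prints:
//   dim 0: dist unit 64, units 2, coords 16 80
//   dim 1: dist unit 64, units 4, coords 32 96 160 224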
@@ -272,12 +278,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
- // delinearizeSubgroupId is only available for
- // workgroup-level layout attribute
- if (!isForWorkgroup())
- return failure();
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
// TODO: handle order attribute
auto hasDefaultOrder = [&]() {
@@ -287,41 +288,52 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
};
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");
-
- auto dims =
- llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
+ SmallVector<int64_t> layout;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ } else {
+ return failure();
+ }
+ auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
return affine::delinearizeIndex(builder, loc, linearId, dims);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
- if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
return failure();
-
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ }
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
- SmallVector<Value> sgIds = *maybeIds;
+ SmallVector<Value> ids = *maybeIds;
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
//===----------------------------------------------------------------------===//
@@ -375,34 +387,43 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeSubgroupId(builder, loc, linearId);
+ return parent.delinearizeId(builder, loc, linearId);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// instructions for computing multi-dimensional offsets when distributed by
+// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (isForWorkgroup()) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (isForSubgroup()) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
@@ -412,8 +433,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..7b6c4b6c2c813 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -175,13 +175,13 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
LogicalResult
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
- UnitAttr subgroup_block_io,
+ UnitAttr subgroup_block_io, DistributeLayoutAttr layout,
function_ref<InFlightDiagnostic()> emitError) {
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
+ "are only allowed when result is a VectorType.";
else
return success();
}
@@ -192,15 +192,37 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
ArrayRef<int64_t> dataShape = dataTy.getShape();
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
+ ArrayAttr strideAttr = mdescTy.getStrideAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+ if (subgroup_block_io && layout) {
+ auto laneData = layout.getEffectiveLaneDataAsInt();
+ auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
+ if (!laneData.empty()) {
+ bool isLaneDataContiguous =
+ std::all_of(laneData.begin(), std::prev(laneData.end()),
+ [](int x) { return x == 1; });
+ if (!isLaneDataContiguous)
+ return emitError() << "With subgroup_block_io, accessed data must be "
+ "contiguous and coalesced.";
+ for (size_t i = 0; i < laneData.size(); ++i) {
+ if (laneLayout[i] != blockShape[i])
+ return emitError() << "With subgroup_block_io, the block shape must "
+ "match the lane layout.";
+ if (laneLayout[i] != 1 && strides[i] != 1)
+ return emitError() << "With subgroup_block_io, the distributed "
+ "dimensions must be contiguous.";
+ }
+ }
+ }
if (dataShape.size() == 2) {
- if (subgroup_block_io)
- return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
} else {
- SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
// if the subgroup_block_io attribute is set, mdescTy must have block
// attribute
if (subgroup_block_io && !blockShape.size())
@@ -1105,7 +1127,7 @@ LogicalResult LoadMatrixOp::verify() {
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
//===----------------------------------------------------------------------===//
@@ -1129,7 +1151,7 @@ LogicalResult StoreMatrixOp::verify() {
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
MemDescType mdescTy = getMemDesc().getType();
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
- [&]() { return emitError(); });
+ getLayoutAttr(), [&]() { return emitError(); });
}
namespace mlir {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index 5a3b27ec6108e..bbd7733e89c29 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -912,6 +913,186 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
+static SmallVector<Value> computeDistributedCoordinatesForMatrixOp(
+ PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+ Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+  SmallVector<Value> newCoords;
+ auto maybeCoords =
+ layout.computeDistributedCoords(rewriter, loc, laneId, payloadShape);
+ if (failed(maybeCoords))
+ return {};
+ assert(maybeCoords.value().size() == 1 &&
+ "Expected one set of distributed offsets");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeCoords.value()[0]),
+ getAsOpFoldResult(origOffsets));
+  newCoords = llvm::to_vector(llvm::map_range(
+      ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+  return newCoords;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ const int operandIdx = producedByLastLoad->getOperandNumber();
+
+ VectorType sgPayloadTy =
+ dyn_cast<VectorType>(matrixOp.getResult().getType());
+ VectorType warpResultTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the load op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ return success();
+ }
+};
+
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the store op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "Failed to distribute matrix op payload based on layout.");
+
+ SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newCoords = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newCoords = computeDistributedCoordinatesForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
+ }
+
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ ValueRange(newCoords), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
+ return success();
+ }
+};
+
/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
@@ -1443,7 +1624,8 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
GpuBarrierDistribution, VectorMultiReductionDistribution,
LoadDistribution, StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution,
+ VectorBitcastDistribution, LoadMatrixDistribution,
+ StoreMatrixDistribution,
MemrefExtractAlignedPointerAsIndexDistribution>(
patterns.getContext(),
/*pattern benefit=*/regularPatternBenefit);
@@ -1468,6 +1650,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+ continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..79eea55c8b78a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeDescOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeDescOffsets))
return failure();
@@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets =
+ layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce0ec0d0..92f353717ac59 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error at +1 {{Mask should match value except the chunk size dim}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
@@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
return
}
-// -----
-func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
- return
-}
-
-
// -----
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
// expected-error at +1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -900,16 +892,25 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
}
// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+func.func @simt_store_matrix_vector_nonlinear(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>, %arg1: vector<2x16xf32>) {
+ // expected-error at +1 {{With subgroup_block_io, accessed data must be contiguous and coalesced}}
+ xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+ vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1]>>
return
}
// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>, %arg1: vector<16x2xf32>) {
+ // expected-error at +1 {{With subgroup_block_io, the distributed dimensions must be contiguous}}
+ xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 2]>} :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [1, 16]>>
return
}
+// -----
+func.func @simt_store_matrix_vector_noncoalesced(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>, %arg1: vector<16x2xf32>) {
+ // expected-error at +1 {{With subgroup_block_io, the block shape must match the lane layout}}
+ xegpu.store_matrix %arg1, %arg0[0, 0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [32, 1], block = [1, 17]>>
+ return
+}
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..8946d14e80b72 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,66 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
+// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
+// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
+// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
+// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
+// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index -> vector<1x2xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<1x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 1], stride = [1, 32]>>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+ !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index -> vector<16x2xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+ vector<16x2xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 1]>>, index, index
+ gpu.return
+ }
+}
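
The per-lane offsets checked in load_store_matrix_2 above follow the same recipe at lane level. A standalone sketch (illustrative values only, not part of this patch) reproduces the expected offsets for one lane:

// Lane-level coordinate arithmetic for the load_store_matrix_2 case:
// lane_layout = [4, 4], lane_data = [2, 1], payload 8x4, user offsets (0, 1).
// The lane id below is an arbitrary assumption for illustration.
#include <cstdint>
#include <iostream>

int main() {
  int64_t laneId = 5;                 // any lane id in [0, 16)
  int64_t laneLayout[2] = {4, 4};
  int64_t laneData[2] = {2, 1};
  int64_t payload[2] = {8, 4};        // subgroup-level vector shape
  int64_t userOffsets[2] = {0, 1};    // offsets written in the source IR

  // Delinearize the lane id over lane_layout (row-major).
  int64_t delin[2] = {laneId / laneLayout[1], laneId % laneLayout[1]};
  for (int i = 0; i < 2; ++i) {
    // Scale by lane_data, wrap into the payload shape (equal to the lane-level
    // distribution unit here), then add the user offset, which the pattern
    // applies after the warp op via addWithRightAligned.
    int64_t coord = (delin[i] * laneData[i]) % payload[i] + userOffsets[i];
    std::cout << "dim " << i << " offset: " << coord << "\n";
  }
  // For laneId = 5 this prints 2 and 2: the lane loads/stores a 2x1 slice of
  // the 8x4 payload at (2, 2), matching the CHECK lines above.
  return 0;
}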
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 76d461108b296..93d51441f5b81 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeOffsets =
+ sliceAttr.computeDistributedCoords(rewriter, loc, sgId, wgShape);
if (failed(maybeOffsets))
return failure();