[Mlir-commits] [mlir] [MLIR][XeGPU] Matrix load/store subgroup distribution (PR #165008)
Artem Kroviakov
llvmlistbot at llvm.org
Tue Oct 28 10:30:51 PDT 2025
https://github.com/akroviakov updated https://github.com/llvm/llvm-project/pull/165008
>From 887f9781ea3b62cd990d9df7066f28ec049f603b Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Fri, 24 Oct 2025 16:08:00 +0000
Subject: [PATCH 1/5] [MLIR][XeGPU] Matrix load/store subgroup distribution
---
.../Transforms/XeGPUSubgroupDistribute.cpp | 124 ++++++++++++++++--
.../Dialect/XeGPU/subgroup-distribute.mlir | 15 +++
2 files changed, 131 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index d09dc196c0bf7..fe059bb86eba2 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -906,6 +906,110 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
+template <class MatrixOp>
+struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+ constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
+ int operandIdx{-1};
+
+ VectorType payloadTy;
+ VectorType warpResultTy;
+ if constexpr (isLoad) {
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ operandIdx = producedByLastLoad->getOperandNumber();
+ payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+ warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
+ } else {
+ payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ }
+ if (!payloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the load op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp,
+ "The matrix op payload has no layouts, using defaults instead.");
+
+ SmallVector<Value> operands;
+ if constexpr (isLoad)
+ operands = {matrixOp.getMemDesc()};
+ else
+ operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ if constexpr (!isLoad)
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ rewriter.setInsertionPointAfter(newWarpOp);
+ unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
+ newOperands[operandIdxToModify] = arith::AddIOp::create(
+ rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
+ newWarpOp.getLaneid());
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ if constexpr (isLoad) {
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], newOffsets, newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ } else {
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
+ xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
+ }
+ return success();
+ }
+};
+
/// Distribute a scattered load op. The logic and requirements are the same as
/// for the scattered store distribution. The warpOp's payload vector is
/// expected to be distributed by the load's result consumer.
@@ -1433,14 +1537,16 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
- patterns.add<CreateNdDescDistribution, StoreNdDistribution,
- LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
- GpuBarrierDistribution, VectorMultiReductionDistribution,
- LoadDistribution, StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution,
- MemrefExtractAlignedPointerAsIndexDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/regularPatternBenefit);
+ patterns
+ .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
+ DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
+ VectorMultiReductionDistribution, LoadDistribution,
+ StoreDistribution, VectorTransposeDistribution,
+ VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
+ MatrixOpDistribution<xegpu::StoreMatrixOp>,
+ MemrefExtractAlignedPointerAsIndexDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/regularPatternBenefit);
patterns.add<VectorShapeCastDistribution>(
patterns.getContext(),
/*pattern benefit=*/highPatternBenefit);
@@ -1462,6 +1568,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
// Layouts are needed for vector type only.
if (!isa<VectorType>(operand.get().getType()))
continue;
+ if (isa<xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>(op))
+ continue;
auto layout = xegpu::getDistributeLayoutAttr(operand.get());
if (!layout) {
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 27a3dc373c739..3fcc747217c9d 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -265,3 +265,18 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
+
+ xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+
+ gpu.return
+ }
+}
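
For reference, the per-lane payload type in the CHECK lines above follows from dividing each payload dimension by the corresponding lane_layout entry. A minimal standalone sketch of that arithmetic (not part of the patch; the actual logic lives in getDistVecTypeBasedOnLaneLayout, and the helper name here is illustrative only):

  #include <cassert>
  #include <cstdio>
  #include <vector>

  // Minimal sketch: per-lane dim = payload dim / lane_layout dim.
  static std::vector<long> perLaneShape(std::vector<long> payload,
                                        const std::vector<long> &laneLayout) {
    assert(payload.size() == laneLayout.size());
    for (size_t i = 0; i < payload.size(); ++i) {
      assert(payload[i] % laneLayout[i] == 0 && "payload must divide evenly");
      payload[i] /= laneLayout[i];
    }
    return payload;
  }

  int main() {
    // vector<2x8xf32> with lane_layout = [2, 8] -> vector<1x1xf32> per lane,
    // as in the test above.
    std::vector<long> s = perLaneShape({2, 8}, {2, 8});
    std::printf("%ldx%ld\n", s[0], s[1]); // prints 1x1
  }
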
>From f80ee32a523ddda05eaf789358ef300efd1208d3 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Sat, 25 Oct 2025 10:51:52 +0000
Subject: [PATCH 2/5] Add offset calculation
---
mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h | 4 +-
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 47 ++++--
.../mlir/Dialect/XeGPU/Transforms/Passes.td | 2 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 154 ++++++++++--------
.../Transforms/XeGPUSubgroupDistribute.cpp | 44 +++--
.../Transforms/XeGPUWgToSgDistribute.cpp | 10 +-
.../Dialect/XeGPU/subgroup-distribute.mlir | 15 +-
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 3 +-
8 files changed, 166 insertions(+), 113 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 1481859e94a92..0c059967bb898 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -30,9 +30,11 @@ class SliceAttr;
} // namespace xegpu
} // namespace mlir
+// clang-format off
+#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
-#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
+// clang-format on
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 19a52317956d2..1b515b11658c0 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -167,6 +167,16 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
let cppNamespace = "::mlir::xegpu";
}
+def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
+def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
+def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
+def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
+ "The enumeration for the scope of fence operation.",
+ [XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::xegpu";
+}
+
def XeGPU_FenceScopeAttr:
EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
let summary = [{Describes the scope of fence.
@@ -223,18 +233,18 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Derive a new layout by dropping InstData",
"xegpu::DistributeLayoutAttr",
"dropInstData">,
- InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
- indices based on the effective subgroup layout.}],
+ InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
+ indices based on the effective `level` layout.}],
"FailureOr<SmallVector<Value>>",
- "delinearizeSubgroupId",
- (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
- InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
- assigned to a subgroup identified by linearId. The shape parameter
- represents the workgroup-level problem size. Each subgroup may access
+ "delinearizeId",
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
+ InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
+ assigned to a `level` identified by linearId. The shape parameter
+ represents the higher-level problem size. Each `level` may access
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
- "getOffsets",
- (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
+ "computeDistributedCoords",
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
to some other layout according to given permutation of (0...n-1).}],
/*retTy=*/"bool",
@@ -476,17 +486,17 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return {};
}
- /// Delinearizes a linear subgroup ID into its multidimensional indices
- /// based on the effective subgroup layout.
+ /// Delinearizes a linear ID into its multidimensional indices
+ /// based on the effective `level` layout.
FailureOr<SmallVector<Value>>
- delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+ delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
- /// Generates instructions to compute multidimensional offsets for blocks
- /// assigned to a subgroup identified by linearId. The shape parameter
- /// represents the workgroup-level problem size. Each subgroup may access
+ /// Generates instructions to compute multidimensional offsets for dist units
+ /// assigned to a `level` identified by linearId. The shape parameter
+ /// represents the higher-level problem size. Each `level` may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
- getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+ computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -643,14 +653,15 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
FailureOr<SmallVector<Value>>
- delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
+ delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
/// Generates instructions to compute multidimensional offsets for blocks
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// multiple blocks according to round-robin distribution rules.
+
FailureOr<SmallVector<SmallVector<Value>>>
- getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
+ computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
index 564d9c4d5422b..5f803233041ab 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td
@@ -26,7 +26,7 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
The pass distributes subgroup level (SIMD) XeGPU ops to work items.
}];
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
- "vector::VectorDialect"];
+ "vector::VectorDialect", "index::IndexDialect"];
}
def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e909548fe0b..cbe459bfcbb48 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -38,47 +38,47 @@ void XeGPUDialect::initialize() {
>();
}
-/// Generates instructions to compute offsets for a subgroup identified by
-/// its multidimensional indices (sgId), using the specified subgroup layout
-/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
-/// dimensions (sizePerWg).
+// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
+// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
+// within each distribution unit.
static SmallVector<SmallVector<Value>>
-genOffsetsComputingInsts(OpBuilder &builder, Location loc,
- SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
- ArrayRef<int64_t> sizePerSg,
- ArrayRef<int64_t> sizePerWg) {
-
+genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
+ ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
+ ArrayRef<int64_t> srcShape) {
SmallVector<SmallVector<Value>> offsets;
- // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
- SmallVector<Value> localOffsets = llvm::map_to_vector(
- llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ // A distribution unit must be less than or equal to `srcShape`
+ SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
+ llvm::zip_equal(srcShape,
+ computeElementwiseMul(subShapesLayout, subShape)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+ // Get the offset of `subShape` within a distribution unit.
+ SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
+ llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::MulOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});
- // distUnit[i] is the minimum value between sizePerWg[i] and
- // sgLayout[i] * sizePerSg[i]
- SmallVector<int64_t> distUnit = llvm::map_to_vector(
- llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
- [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
-
+ // For each dist unit
for (SmallVector<int64_t> unitOffs :
- StaticTileOffsetRange(sizePerWg, distUnit)) {
+ StaticTileOffsetRange(srcShape, distUnitShape)) {
+ // Get dist unit offset within `srcShape`.
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return arith::ConstantIndexOp::create(builder, loc, d);
});
-
- SmallVector<Value> adds = llvm::map_to_vector(
- llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
- return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
- std::get<1>(t));
- });
-
+ // Calculate `subShape` offset within `srcShape`.
+ SmallVector<Value> adds =
+ llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
+ [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(
+ loc, std::get<0>(t), std::get<1>(t));
+ });
+ // Do not go beyond `srcShape` bounds.
SmallVector<Value> mods = llvm::map_to_vector(
- llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
@@ -268,12 +268,8 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
- // delinearizeSubgroupId is only available for
- // workgroup-level layout attribute
- if (!isForWorkgroup())
- return failure();
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
+ xegpu::DistributionLevel idLevel) {
// TODO: handle order attribute
auto hasDefaultOrder = [&]() {
@@ -283,41 +279,53 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
};
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");
-
- auto dims =
- llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
- return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
- });
+ SmallVector<int64_t> layout;
+ if (idLevel == xegpu::DistributionLevel::SG) {
+ layout = getEffectiveSgLayoutAsInt();
+ } else if (idLevel == xegpu::DistributionLevel::WI) {
+ layout = getEffectiveLaneLayoutAsInt();
+ } else {
+ return failure();
+ }
+ auto dims = llvm::map_to_vector(layout, [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
return affine::delinearizeIndex(builder, loc, linearId, dims);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
+/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
- if (!isForWorkgroup())
+LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape,
+ xegpu::DistributionLevel targetLevel) {
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (targetLevel == DistributionLevel::SG) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (targetLevel == DistributionLevel::WI) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
return failure();
-
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ }
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
if (failed(maybeIds))
return failure();
- SmallVector<Value> sgIds = *maybeIds;
+ SmallVector<Value> ids = *maybeIds;
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genOffsets(builder, loc, ids, layout, subShape, shape);
}
//===----------------------------------------------------------------------===//
@@ -371,34 +379,45 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
- Value linearId) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
+ xegpu::DistributionLevel level) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeSubgroupId(builder, loc, linearId);
+ return parent.delinearizeId(builder, loc, linearId, level);
}
-/// Implements DistributeLayoutAttr::getOffsets to generate
-/// instructions for computing multi-dimensional offsets when distributed by
-/// SliceAttr.
+// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// instructions for computing multi-dimensional offsets when distributed by
+// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
- ArrayRef<int64_t> shape) {
+SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape,
+ xegpu::DistributionLevel targetLevel) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
- SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
- SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
- if (sgShape.empty()) {
- if (auto derivedShape = computeShapeRatio(shape, sgLayout))
- sgShape = derivedShape.value();
+ SmallVector<int64_t> layout;
+ SmallVector<int64_t> subShape;
+ if (targetLevel == DistributionLevel::SG) {
+ layout = getEffectiveSgLayoutAsInt();
+ subShape = getEffectiveSgDataAsInt();
+ } else if (targetLevel == DistributionLevel::WI) {
+ layout = getEffectiveLaneLayoutAsInt();
+ subShape = getEffectiveLaneDataAsInt();
+ } else {
+ return failure();
+ }
+
+ if (subShape.empty()) {
+ if (auto derivedShape = computeShapeRatio(shape, layout))
+ subShape = derivedShape.value();
else
return failure();
}
// delinearize Ids
- auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
if (failed(maybeIds))
return failure();
@@ -408,8 +427,7 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
SmallVector<Value> sgIds =
XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
- return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
- shape);
+ return genOffsets(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
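
As a reading aid, the round-robin behaviour of genOffsets can be reproduced with plain arithmetic. A standalone sketch, assuming for illustration WG data of 128x256 with sg_data 16x32 in a 4x2 sg_layout (the example later added to the genOffsets comment in patch 4); names are illustrative, not the dialect code:

  #include <algorithm>
  #include <cstdio>

  int main() {
    // Distribution unit is min(shape, sg_layout * sg_data) = 64x64, so there
    // are 2x4 units; every subgroup touches one 16x32 tile per unit.
    const long shape[2] = {128, 256}, sgLayout[2] = {4, 2}, sgData[2] = {16, 32};
    const long sgId = 5; // linear subgroup id, delinearized row-major below
    const long id[2] = {sgId / sgLayout[1], sgId % sgLayout[1]};
    long unit[2], local[2];
    for (int d = 0; d < 2; ++d) {
      unit[d] = std::min(shape[d], sgLayout[d] * sgData[d]); // 64, 64
      local[d] = id[d] * sgData[d];                          // offset inside a unit
    }
    for (long by = 0; by < shape[0]; by += unit[0])
      for (long bx = 0; bx < shape[1]; bx += unit[1])
        std::printf("[%ld, %ld]\n", (by + local[0]) % shape[0],
                    (bx + local[1]) % shape[1]);
  }

For sgId = 5 this prints the eight tile origins [32, 32], [32, 96], [32, 160], [32, 224], [96, 32], [96, 96], [96, 160], [96, 224].
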
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index fe059bb86eba2..b02290d4b251b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/DistributionUtils.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
@@ -919,7 +920,7 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
int operandIdx{-1};
- VectorType payloadTy;
+ VectorType sgPayloadTy;
VectorType warpResultTy;
if constexpr (isLoad) {
OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
@@ -929,12 +930,12 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
return rewriter.notifyMatchFailure(
warpOp, "The last op is not xegpu::LoadMatrixOp");
operandIdx = producedByLastLoad->getOperandNumber();
- payloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
+ sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
} else {
- payloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
}
- if (!payloadTy)
+ if (!sgPayloadTy)
return rewriter.notifyMatchFailure(
matrixOp, "the matrix op payload must be a vector type");
@@ -952,7 +953,7 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
matrixOp, "the matrix operation lacks layout attribute");
FailureOr<VectorType> distPayloadByWarpOpOrFailure =
- getDistVecTypeBasedOnLaneLayout(layout, payloadTy);
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
if (failed(distPayloadByWarpOpOrFailure))
return rewriter.notifyMatchFailure(
matrixOp,
@@ -977,23 +978,36 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
SmallVector<Value> newOperands = llvm::map_to_vector(
newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
- rewriter.setInsertionPointAfter(newWarpOp);
- unsigned operandIdxToModify = offsetsStartIdx + offsetsAsValues.size() - 1;
- newOperands[operandIdxToModify] = arith::AddIOp::create(
- rewriter, loc, rewriter.getIndexType(), newOperands[operandIdxToModify],
- newWarpOp.getLaneid());
-
SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
std::fill(newConstOffsets.begin(), newConstOffsets.end(),
ShapedType::kDynamic);
DenseI64ArrayAttr newConstOffsetsAttr =
rewriter.getDenseI64ArrayAttr(newConstOffsets);
- ValueRange newOffsets = ValueRange(newOperands).drop_front(offsetsStartIdx);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ rewriter.setInsertionPointAfter(newWarpOp);
+ SmallVector<Value> newOffsets = currentOffsets;
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ auto maybeDescOffsets = layout.computeDistributedCoords(
+ rewriter, loc, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ xegpu::DistributionLevel::WI);
+ if (failed(maybeDescOffsets))
+ return failure();
+ assert(maybeDescOffsets.value().size() == 1 &&
+ "Expected same number of offset sets as number of accessed "
+ "sub-tensors or sub-memory descriptors.");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+ offsets);
+ newOffsets = llvm::to_vector(llvm::map_range(
+ ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ }
if constexpr (isLoad) {
xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
- newOperands[0], newOffsets, newConstOffsetsAttr,
+ newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
// Resolve the output type and replace all uses.
rewriter.replaceAllUsesWith(
@@ -1002,8 +1016,8 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
} else {
xegpu::StoreMatrixOp::create(
rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
- newOffsets, newConstOffsetsAttr, matrixOp.getSubgroupBlockIoAttr(),
- xegpu::DistributeLayoutAttr{});
+ ValueRange(newOffsets), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
rewriter.eraseOp(matrixOp);
}
return success();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9fc5ad9af5c7b..93e23cea9c7dd 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeDescOffsets = layout.computeDistributedCoords(
+ rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(maybeDescOffsets))
return failure();
@@ -830,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
-
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets = layout.computeDistributedCoords(
+ rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(sgOffsets))
return failure();
@@ -1052,7 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
+ auto sgOffsets = layout.computeDistributedCoords(
+ rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(sgOffsets))
return failure();
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index 3fcc747217c9d..b69c661f8cfd5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -268,15 +268,20 @@ gpu.module @xevm_module{
// -----
// CHECK-LABEL: gpu.func @load_store_matrix_1({{.*}}) {
-// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[{{.*}}] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
-// CHECK: xegpu.store_matrix %[[MAT]], %arg0[{{.*}}] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
+// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<1x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
gpu.func @load_store_matrix_1(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
%1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x8xf32>
-
xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>}> : vector<2x8xf32>, !xegpu.mem_desc<32x32xf32>, index, index
-
- gpu.return
+ gpu.return
}
}
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 76d461108b296..4408e827a97fc 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
+ auto maybeOffsets = sliceAttr.computeDistributedCoords(
+ rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(maybeOffsets))
return failure();
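
For the non-subgroup_block_io path, the offsets emitted above amount to: delinearize the lane id over lane_layout, scale by lane_data, wrap with remu, then add the original subgroup-level offsets (right-aligned). A standalone sketch under the assumptions of the load_store_matrix_1 test above (illustrative names, not the pass code):

  #include <algorithm>
  #include <cstdio>

  int main() {
    // Assumed example matching the test: payload vector<2x8xf32>,
    // lane_layout = [2, 8], lane_data = [1, 1], original offsets = [0, 0].
    const long shape[2] = {2, 8}, laneLayout[2] = {2, 8}, laneData[2] = {1, 1};
    const long origOffsets[2] = {0, 0};
    for (long lane = 0; lane < laneLayout[0] * laneLayout[1]; ++lane) {
      // Delinearize the lane id over lane_layout (row-major, as delinearizeId does).
      const long id[2] = {lane / laneLayout[1], lane % laneLayout[1]};
      long off[2];
      for (int d = 0; d < 2; ++d) {
        // Single distribution unit per dimension here, so the remu operand is
        // min(shape, lane_layout * lane_data), i.e. the constants in the CHECKs.
        const long unit = std::min(shape[d], laneLayout[d] * laneData[d]);
        off[d] = origOffsets[d] + (id[d] * laneData[d]) % unit;
      }
      std::printf("lane %2ld -> [%ld, %ld]\n", lane, off[0], off[1]);
    }
  }

Lane 11, for example, maps to [1, 3], matching the index.remu results checked above.
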
>From b4f5a4d325a3069ad658362c701fe6c8ad9b8a81 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Mon, 27 Oct 2025 16:56:11 +0000
Subject: [PATCH 3/5] Relax `subgroup_block_io` dimensionality restriction
---
.../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 2 +
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 5 +--
mlir/test/Dialect/XeGPU/invalid.mlir | 25 +----------
.../Dialect/XeGPU/subgroup-distribute.mlir | 43 ++++++++++++++++++-
4 files changed, 45 insertions(+), 30 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..53b8c4f0bbd59 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -562,6 +562,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
if (!valOrResVecTy)
valOrResVecTy = VectorType::get(1, data.getType());
+ if (valOrResVecTy.getShape().size() != 1)
+ return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
int64_t elemBitWidth =
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..68f49d648e738 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -181,7 +181,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
if (!dataTy) {
if (subgroup_block_io)
return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
+ "are only allowed when result is a VectorType.";
else
return success();
}
@@ -193,9 +193,6 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
if (dataShape.size() == 2) {
- if (subgroup_block_io)
- return emitError() << "subgroup_block_io "
- "are only allowed when result is a 1D VectorType.";
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitError() << "data shape must not exceed mem_desc shape.";
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce0ec0d0..0b0ef27e39233 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -451,7 +451,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error at +1 {{Mask should match value except the chunk size dim}}
- xegpu.store %val, %src[%offsets], %mask
+ xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
@@ -870,14 +870,6 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
return
}
-// -----
-func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
- // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
- return
-}
-
-
// -----
func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
// expected-error at +1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -898,18 +890,3 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
xegpu.store_matrix %arg1, %arg0[32] : vector<32xf16>, !xegpu.mem_desc<64xf16>
return
}
-
-// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
- return
-}
-
-// -----
-func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
- // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
- xegpu.store_matrix %data, %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
- return
-}
-
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index b69c661f8cfd5..fe129428dc189 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -271,8 +271,8 @@ gpu.module @xevm_module{
// CHECK: %[[LAYOUT_X:.*]] = arith.constant 8 : index
// CHECK: %[[LAYOUT_Y:.*]] = arith.constant 2 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
-// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%0]
-// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%0]
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_Y]], %[[LAYOUT_Y]]
// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[LAYOUT_X]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<1x1xf32>
@@ -285,3 +285,42 @@ gpu.module @xevm_module{
gpu.return
}
}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_2({{.*}}) {
+// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
+// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
+// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
+// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
+// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
+// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ gpu.return
+ }
+}
+
+// -----
+// CHECK-LABEL: gpu.func @load_store_matrix_3({{.*}}) {
+// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index -> vector<2x1xf32>
+// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%{{.*}}, %{{.*}}] <{subgroup_block_io}>:
+// CHECK-SAME: vector<2x1xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<block = [16, 16], stride = [1, 32]>>, index, index
+gpu.module @xevm_module{
+ gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
+ %c0 = arith.constant 0 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+ !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+ vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+ gpu.return
+ }
+}
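
With the 1D-only restriction gone, the only shape constraint the verifier keeps for 2D subgroup_block_io accesses is containment in the mem_desc. A rough standalone model of that remaining check (illustrative helper name, not the actual IsValidMatrixOpParams):

  #include <cstdio>

  // Rough model: with subgroup_block_io, a 2D payload is accepted as long as
  // each dimension fits within the corresponding mem_desc dimension.
  static bool dataFitsMemDesc(const long (&data)[2], const long (&mdesc)[2]) {
    return data[0] <= mdesc[0] && data[1] <= mdesc[1];
  }

  int main() {
    const long data[2] = {2, 16}, mdesc[2] = {32, 32}; // as in load_store_matrix_3
    std::printf("%s\n", dataFitsMemDesc(data, mdesc)
                            ? "ok"
                            : "data shape must not exceed mem_desc shape");
  }
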
>From 3c4a5aa8e0a7bf66da85f22552e86615bdc8d1d9 Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 28 Oct 2025 16:11:34 +0000
Subject: [PATCH 4/5] Address feedback
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 8 +-
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 21 +-
.../Transforms/XeGPUSubgroupDistribute.cpp | 203 ++++++++++++------
.../Transforms/XeGPUWgToSgDistribute.cpp | 6 +-
.../Dialect/XeGPU/subgroup-distribute.mlir | 14 +-
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 2 +-
6 files changed, 163 insertions(+), 91 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 1b515b11658c0..794a84c839548 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -171,7 +171,7 @@ def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
- "The enumeration for the scope of fence operation.",
+ "Specify target level for offsets distribution utility.",
[XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::xegpu";
@@ -243,7 +243,7 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
represents the higher-level problem size. Each `level` may access
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
- "computeDistributedCoords",
+ "computeDistributedOffsets",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
to some other layout according to given permutation of (0...n-1).}],
@@ -496,7 +496,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
/// represents the higher-level problem size. Each `level` may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
- computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+ computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -661,7 +661,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
- computeDistributedCoords(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+ computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index cbe459bfcbb48..e335efefb608f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -41,6 +41,11 @@ void XeGPUDialect::initialize() {
// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
// within each distribution unit.
+// Example:
+// WG data is 128x256 and SG data is 16x32 in a 4x2 layout. This gives a
+// distribution unit of shape 64x64, and there are 2x4 such distribution
+// units. `delinearizedId` identifies a subgroup's 16x32 block within each
+// distribution unit.
static SmallVector<SmallVector<Value>>
genOffsets(OpBuilder &builder, Location loc, SmallVector<Value> delinearizedId,
ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
@@ -294,13 +299,13 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
return affine::delinearizeIndex(builder, loc, linearId, dims);
}
-/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+/// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
- Value linearId, ArrayRef<int64_t> shape,
- xegpu::DistributionLevel targetLevel) {
+LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape,
+ xegpu::DistributionLevel targetLevel) {
SmallVector<int64_t> layout;
SmallVector<int64_t> subShape;
if (targetLevel == DistributionLevel::SG) {
@@ -386,13 +391,13 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
return parent.delinearizeId(builder, loc, linearId, level);
}
-// Implements DistributeLayoutAttr::computeDistributedCoords to generate
+// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
// instructions for computing multi-dimensional offsets when distributed by
// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
-SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
- Value linearId, ArrayRef<int64_t> shape,
- xegpu::DistributionLevel targetLevel) {
+SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
+ Value linearId, ArrayRef<int64_t> shape,
+ xegpu::DistributionLevel targetLevel) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index b02290d4b251b..c576172683f68 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -907,34 +907,48 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
}
};
-template <class MatrixOp>
-struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
+static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
+ PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
+ Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
+ SmallVector<Value> newOffsets;
+ auto maybeDescOffsets = layout.computeDistributedOffsets(
+ rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
+ if (failed(maybeDescOffsets))
+ return {};
+ assert(maybeDescOffsets.value().size() == 1 &&
+ "Expected one set of distributed offsets");
+ SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
+ rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
+ getAsOpFoldResult(origOffsets));
+ newOffsets = llvm::to_vector(llvm::map_range(
+ ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ return newOffsets;
+}
+
+/// Pattern for distributing xegpu::LoadMatrixOp.
+struct LoadMatrixDistribution final : public gpu::WarpDistributionPattern {
using gpu::WarpDistributionPattern::WarpDistributionPattern;
LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
PatternRewriter &rewriter) const override {
gpu::YieldOp yield = warpOp.getTerminator();
Operation *lastNode = yield->getPrevNode();
- auto matrixOp = dyn_cast_or_null<MatrixOp>(lastNode);
+ auto matrixOp = dyn_cast_or_null<xegpu::LoadMatrixOp>(lastNode);
if (!matrixOp)
return failure();
- constexpr bool isLoad{std::is_same_v<MatrixOp, xegpu::LoadMatrixOp>};
- int operandIdx{-1};
-
- VectorType sgPayloadTy;
- VectorType warpResultTy;
- if constexpr (isLoad) {
- OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
- return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
- });
- if (!producedByLastLoad)
- return rewriter.notifyMatchFailure(
- warpOp, "The last op is not xegpu::LoadMatrixOp");
- operandIdx = producedByLastLoad->getOperandNumber();
- sgPayloadTy = dyn_cast<VectorType>(matrixOp.getResult().getType());
- warpResultTy = cast<VectorType>(warpOp.getResult(operandIdx).getType());
- } else {
- sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
- }
+
+ OpOperand *producedByLastLoad = getWarpResult(warpOp, [&](Operation *op) {
+ return isa<xegpu::LoadMatrixOp>(op) && matrixOp == op;
+ });
+ if (!producedByLastLoad)
+ return rewriter.notifyMatchFailure(
+ warpOp, "The last op is not xegpu::LoadMatrixOp");
+ const int operandIdx = producedByLastLoad->getOperandNumber();
+
+ VectorType sgPayloadTy =
+ dyn_cast<VectorType>(matrixOp.getResult().getType());
+ VectorType warpResultTy =
+ cast<VectorType>(warpOp.getResult(operandIdx).getType());
if (!sgPayloadTy)
return rewriter.notifyMatchFailure(
matrixOp, "the matrix op payload must be a vector type");
@@ -956,21 +970,14 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
if (failed(distPayloadByWarpOpOrFailure))
return rewriter.notifyMatchFailure(
- matrixOp,
- "The matrix op payload has no layouts, using defaults instead.");
-
- SmallVector<Value> operands;
- if constexpr (isLoad)
- operands = {matrixOp.getMemDesc()};
- else
- operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ matrixOp, "The matrix op payload has no layout.");
+
+ SmallVector<Value> operands = {matrixOp.getMemDesc()};
const unsigned offsetsStartIdx = operands.size();
operands.append(offsetsAsValues);
SmallVector<Type> operandTypes = llvm::to_vector(
llvm::map_range(operands, [](Value v) { return v.getType(); }));
- if constexpr (!isLoad)
- operandTypes[0] = *distPayloadByWarpOpOrFailure;
SmallVector<size_t> newRetIndices;
gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -986,40 +993,97 @@ struct MatrixOpDistribution final : public gpu::WarpDistributionPattern {
ValueRange currentOffsets =
ValueRange(newOperands).drop_front(offsetsStartIdx);
- rewriter.setInsertionPointAfter(newWarpOp);
SmallVector<Value> newOffsets = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
if (!matrixOp.getSubgroupBlockIoAttr()) {
- auto maybeDescOffsets = layout.computeDistributedCoords(
- rewriter, loc, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
- xegpu::DistributionLevel::WI);
- if (failed(maybeDescOffsets))
- return failure();
- assert(maybeDescOffsets.value().size() == 1 &&
- "Expected same number of offset sets as number of accessed "
- "sub-tensors or sub-memory descriptors.");
- SmallVector<OpFoldResult> ofrVec = xegpu::addWithRightAligned(
- rewriter, loc, getAsOpFoldResult(maybeDescOffsets.value()[0]),
- offsets);
- newOffsets = llvm::to_vector(llvm::map_range(
- ofrVec, [&](OpFoldResult ofr) -> Value { return cast<Value>(ofr); }));
+ newOffsets = computeDistributedOffsetsForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
}
+ xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
+ rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
+ newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ // Resolve the output type and replace all uses.
+ rewriter.replaceAllUsesWith(
+ newWarpOp.getResult(operandIdx),
+ resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
+ return success();
+ }
+};
- if constexpr (isLoad) {
- xegpu::LoadMatrixOp newOp = xegpu::LoadMatrixOp::create(
- rewriter, newWarpOp.getLoc(), *distPayloadByWarpOpOrFailure,
- newOperands[0], ValueRange(newOffsets), newConstOffsetsAttr,
- matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
- // Resolve the output type and replace all uses.
- rewriter.replaceAllUsesWith(
- newWarpOp.getResult(operandIdx),
- resolveDistributedTy(newOp.getResult(), warpResultTy, rewriter));
- } else {
- xegpu::StoreMatrixOp::create(
- rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
- ValueRange(newOffsets), newConstOffsetsAttr,
- matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
- rewriter.eraseOp(matrixOp);
+/// Pattern for distributing xegpu::StoreMatrixOp.
+struct StoreMatrixDistribution final : public gpu::WarpDistributionPattern {
+ using gpu::WarpDistributionPattern::WarpDistributionPattern;
+ LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp,
+ PatternRewriter &rewriter) const override {
+ gpu::YieldOp yield = warpOp.getTerminator();
+ Operation *lastNode = yield->getPrevNode();
+ auto matrixOp = dyn_cast_or_null<xegpu::StoreMatrixOp>(lastNode);
+ if (!matrixOp)
+ return failure();
+
+ VectorType sgPayloadTy = dyn_cast<VectorType>(matrixOp.getData().getType());
+ if (!sgPayloadTy)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix op payload must be a vector type");
+
+ auto loc = matrixOp.getLoc();
+ auto offsets = matrixOp.getMixedOffsets();
+ if (offsets.empty())
+ return rewriter.notifyMatchFailure(matrixOp,
+ "the store op must have offsets");
+ SmallVector<Value> offsetsAsValues =
+ vector::getAsValues(rewriter, matrixOp.getLoc(), offsets);
+
+ auto layout = matrixOp.getLayoutAttr();
+ if (!layout)
+ return rewriter.notifyMatchFailure(
+ matrixOp, "the matrix operation lacks layout attribute");
+
+ FailureOr<VectorType> distPayloadByWarpOpOrFailure =
+ getDistVecTypeBasedOnLaneLayout(layout, sgPayloadTy);
+ if (failed(distPayloadByWarpOpOrFailure))
+ return rewriter.notifyMatchFailure(
+ matrixOp, "The matrix op payload has no layout.");
+
+ SmallVector<Value> operands = {matrixOp.getData(), matrixOp.getMemDesc()};
+ const unsigned offsetsStartIdx = operands.size();
+ operands.append(offsetsAsValues);
+
+ SmallVector<Type> operandTypes = llvm::to_vector(
+ llvm::map_range(operands, [](Value v) { return v.getType(); }));
+ operandTypes[0] = *distPayloadByWarpOpOrFailure;
+
+ SmallVector<size_t> newRetIndices;
+ gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
+ rewriter, warpOp, operands, operandTypes, newRetIndices);
+ SmallVector<Value> newOperands = llvm::map_to_vector(
+ newRetIndices, [&](size_t idx) { return newWarpOp.getResult(idx); });
+
+ SmallVector<int64_t> newConstOffsets{matrixOp.getConstOffsets()};
+ std::fill(newConstOffsets.begin(), newConstOffsets.end(),
+ ShapedType::kDynamic);
+ DenseI64ArrayAttr newConstOffsetsAttr =
+ rewriter.getDenseI64ArrayAttr(newConstOffsets);
+ ValueRange currentOffsets =
+ ValueRange(newOperands).drop_front(offsetsStartIdx);
+
+ SmallVector<Value> newOffsets = currentOffsets;
+ rewriter.setInsertionPointAfter(newWarpOp);
+
+ if (!matrixOp.getSubgroupBlockIoAttr()) {
+ newOffsets = computeDistributedOffsetsForMatrixOp(
+ rewriter, loc, layout, newWarpOp.getLaneid(), sgPayloadTy.getShape(),
+ currentOffsets);
}
+
+ xegpu::StoreMatrixOp::create(
+ rewriter, loc, TypeRange{}, newOperands[0], newOperands[1],
+ ValueRange(newOffsets), newConstOffsetsAttr,
+ matrixOp.getSubgroupBlockIoAttr(), xegpu::DistributeLayoutAttr{});
+ rewriter.eraseOp(matrixOp);
return success();
}
};
@@ -1551,16 +1615,15 @@ struct XeGPUSubgroupDistributePass final
void xegpu::populateXeGPUSubgroupDistributePatterns(
RewritePatternSet &patterns) {
- patterns
- .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
- DpasDistribution, PrefetchNdDistribution, GpuBarrierDistribution,
- VectorMultiReductionDistribution, LoadDistribution,
- StoreDistribution, VectorTransposeDistribution,
- VectorBitcastDistribution, MatrixOpDistribution<xegpu::LoadMatrixOp>,
- MatrixOpDistribution<xegpu::StoreMatrixOp>,
- MemrefExtractAlignedPointerAsIndexDistribution>(
- patterns.getContext(),
- /*pattern benefit=*/regularPatternBenefit);
+ patterns.add<CreateNdDescDistribution, StoreNdDistribution,
+ LoadNdDistribution, DpasDistribution, PrefetchNdDistribution,
+ GpuBarrierDistribution, VectorMultiReductionDistribution,
+ LoadDistribution, StoreDistribution, VectorTransposeDistribution,
+ VectorBitcastDistribution, LoadMatrixDistribution,
+ StoreMatrixDistribution,
+ MemrefExtractAlignedPointerAsIndexDistribution>(
+ patterns.getContext(),
+ /*pattern benefit=*/regularPatternBenefit);
patterns.add<VectorShapeCastDistribution>(
patterns.getContext(),
/*pattern benefit=*/highPatternBenefit);
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 93e23cea9c7dd..35072f0529072 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,7 +114,7 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.computeDistributedCoords(
+ auto maybeDescOffsets = layout.computeDistributedOffsets(
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(maybeDescOffsets))
return failure();
@@ -831,7 +831,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.computeDistributedCoords(
+ auto sgOffsets = layout.computeDistributedOffsets(
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(sgOffsets))
return failure();
@@ -1053,7 +1053,7 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.computeDistributedCoords(
+ auto sgOffsets = layout.computeDistributedOffsets(
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(sgOffsets))
return failure();
diff --git a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
index fe129428dc189..da4151024edb5 100644
--- a/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
+++ b/mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
@@ -291,19 +291,22 @@ gpu.module @xevm_module{
// CHECK: %[[DIST_UNIT_HEIGHT_X:.*]] = arith.constant 4 : index
// CHECK: %[[DIST_UNIT_HEIGHT_Y:.*]] = arith.constant 8 : index
// CHECK: %[[LANE_DATA_Y:.*]] = arith.constant 2 : index
+// CHECK: %[[USER_OFFSET_X:.*]] = arith.constant 1 : index
// CHECK: %[[LANE_ID:.*]] = gpu.lane_id
// CHECK: %[[DELINEARIZED_LANE_Y:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[DELINEARIZED_LANE_X:.*]] = affine.apply #{{.*}}()[%[[LANE_ID]]]
// CHECK: %[[LANE_Y_OFFSET_1:.*]] = index.mul %[[DELINEARIZED_LANE_Y]], %[[LANE_DATA_Y]]
// CHECK: %[[LANE_Y_OFFSET:.*]] = index.remu %[[LANE_Y_OFFSET_1]], %[[DIST_UNIT_HEIGHT_Y]]
-// CHECK: %[[LANE_X_OFFSET:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[LANE_X_OFFSET_1:.*]] = index.remu %[[DELINEARIZED_LANE_X]], %[[DIST_UNIT_HEIGHT_X]]
+// CHECK: %[[LANE_X_OFFSET:.*]] = index.add %[[LANE_X_OFFSET_1]], %[[USER_OFFSET_X]]
// CHECK: %[[MAT:.*]] = xegpu.load_matrix %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<2x1xf32>
// CHECK: xegpu.store_matrix %[[MAT]], %arg0[%[[LANE_Y_OFFSET]], %[[LANE_X_OFFSET]]] : vector<2x1xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.module @xevm_module{
gpu.func @load_store_matrix_2(%arg0: !xegpu.mem_desc<32x32xf32>) {
%c0 = arith.constant 0 : index
- %1 = xegpu.load_matrix %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
- xegpu.store_matrix %1, %arg0[%c0, %c0] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x4xf32>
+ xegpu.store_matrix %1, %arg0[%c0, %c1] <{layout = #xegpu.layout<lane_layout = [4, 4], lane_data = [2, 1]>}> : vector<8x4xf32>, !xegpu.mem_desc<32x32xf32>, index, index
gpu.return
}
}
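
To connect the CHECK lines above with the code, here is the offset arithmetic for one concrete lane, assuming the usual row-major delinearization of the lane id over the [4, 4] lane layout (lane 6 is an arbitrary choice):

  lane 6   ->  (lane_y, lane_x) = (6 floordiv 4, 6 mod 4) = (1, 2)
  y offset = (lane_y * lane_data_y) remu dist_unit_height_y = (1 * 2) remu 8 = 2
             (the %c0 user offset in y folds away)
  x offset = (lane_x remu dist_unit_height_x) + user offset x = (2 remu 4) + 1 = 3

so lane 6 loads and stores its vector<2x1xf32> fragment at [2, 3] of the 32x32 matrix.
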
@@ -317,9 +320,10 @@ gpu.module @xevm_module{
gpu.module @xevm_module{
gpu.func @load_store_matrix_3(%arg0: !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>) {
%c0 = arith.constant 0 : index
- %1 = xegpu.load_matrix %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+ %c1 = arith.constant 1 : index
+ %1 = xegpu.load_matrix %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
!xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<2x16xf32>
- xegpu.store_matrix %1, %arg0[%c0, %c0] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
+ xegpu.store_matrix %1, %arg0[%c0, %c1] {subgroup_block_io, layout = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} :
vector<2x16xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
gpu.return
}
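
When subgroup_block_io is present, the lanes perform a cooperative block access, so StoreMatrixDistribution above deliberately skips the per-lane offset computation: only the payload type is narrowed and the subgroup-level offsets are kept (the load side is handled analogously). Roughly, the distributed store for this test would look like (the value name %frag is illustrative):

  xegpu.store_matrix %frag, %arg0[%c0, %c1] {subgroup_block_io} :
      vector<2x1xf32>, !xegpu.mem_desc<32x32xf32, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
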
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 4408e827a97fc..61ebdce5d7995 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,7 +200,7 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto maybeOffsets = sliceAttr.computeDistributedCoords(
+ auto maybeOffsets = sliceAttr.computeDistributedOffsets(
rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
if (failed(maybeOffsets))
return failure();
>From 5965b543738799f504a171624384b4c388fc0deb Mon Sep 17 00:00:00 2001
From: Artem Kroviakov <artem.kroviakov at intel.com>
Date: Tue, 28 Oct 2025 17:30:33 +0000
Subject: [PATCH 5/5] Remove DistributionLevel enum
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 32 +++++++------------
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 30 ++++++++---------
.../Transforms/XeGPUSubgroupDistribute.cpp | 5 ++-
.../Transforms/XeGPUWgToSgDistribute.cpp | 12 +++----
.../lib/Dialect/XeGPU/TestXeGPUTransforms.cpp | 4 +--
5 files changed, 34 insertions(+), 49 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 794a84c839548..699a7c7e0cf98 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -167,16 +167,6 @@ def XeGPU_FenceScope: I32EnumAttr<"FenceScope",
let cppNamespace = "::mlir::xegpu";
}
-def XeGPU_WGLevel: I32EnumAttrCase<"WG", 0, "wg">;
-def XeGPU_SGLevel: I32EnumAttrCase<"SG", 1, "sg">;
-def XeGPU_WILevel: I32EnumAttrCase<"WI", 2, "wi">;
-def XeGPU_DistributionLevel: I32EnumAttr<"DistributionLevel",
- "Specify target level for offsets distribution utility.",
- [XeGPU_WGLevel, XeGPU_SGLevel, XeGPU_WILevel]> {
- let genSpecializedAttr = 0;
- let cppNamespace = "::mlir::xegpu";
-}
-
def XeGPU_FenceScopeAttr:
EnumAttr<XeGPU_Dialect, XeGPU_FenceScope, "fence_scope"> {
let summary = [{Describes the scope of fence.
@@ -234,17 +224,17 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
"xegpu::DistributeLayoutAttr",
"dropInstData">,
InterfaceMethod<[{Delinearizes a linear ID into its multidimensional
- indices based on the effective `level` layout.}],
+ indices based on the effective subgroup or lane layout.}],
"FailureOr<SmallVector<Value>>",
"delinearizeId",
- (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "xegpu::DistributionLevel": $level)>,
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for dist units
- assigned to a `level` identified by linearId. The shape parameter
- represents the higher-level problem size. Each `level` may access
+ assigned to the dist unit identified by linearId. The shape parameter
+ represents the higher-level problem size. Each dist unit may access
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
"computeDistributedOffsets",
- (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape, "xegpu::DistributionLevel": $level)>,
+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
to some other layout according to given permutation of (0...n-1).}],
/*retTy=*/"bool",
@@ -487,16 +477,16 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
}
/// Delinearizes a linear ID into its multidimensional indices
- /// based on the effective `level` layout.
+ /// based on the effective subgroup or lane layout.
FailureOr<SmallVector<Value>>
- delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
+ delinearizeId(OpBuilder &builder, Location loc, Value linearId);
/// Generates instructions to compute multidimensional offsets for dist units
- /// assigned to a `level` identified by linearId. The shape parameter
+ /// assigned to the dist unit identified by linearId. The shape parameter
/// represents the higher-level problem size. Each `level` may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
- computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+ computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
@@ -653,7 +643,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
FailureOr<SmallVector<Value>>
- delinearizeId(OpBuilder &builder, Location loc, Value linearId, xegpu::DistributionLevel level);
+ delinearizeId(OpBuilder &builder, Location loc, Value linearId);
/// Generates instructions to compute multidimensional offsets for blocks
/// assigned to a subgroup identified by linearId. The shape parameter
@@ -661,7 +651,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
- computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape, xegpu::DistributionLevel level);
+ computeDistributedOffsets(OpBuilder &builder, Location loc, Value linearId,ArrayRef<int64_t> shape);
/// Check if this is slice of some other layout.
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
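
With the explicit DistributionLevel parameter removed, the level is now inferred from the layout attribute itself, as the dialect changes below show: a layout carrying subgroup information is distributed at workgroup level (isForWorkgroup), while one carrying only lane information is distributed at subgroup level (isForSubgroup). As an illustration (the attribute shapes are made up):

  #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>      // isForWorkgroup(): delinearize a subgroup id
  #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>   // isForSubgroup(): delinearize a lane id
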
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index e335efefb608f..d162d36bef504 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -273,8 +273,7 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
}
FailureOr<SmallVector<Value>>
-LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
- xegpu::DistributionLevel idLevel) {
+LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
// TODO: handle order attribute
auto hasDefaultOrder = [&]() {
@@ -285,9 +284,9 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
if (!hasDefaultOrder())
return mlir::emitError(loc, "order attribute is currently not supported.");
SmallVector<int64_t> layout;
- if (idLevel == xegpu::DistributionLevel::SG) {
+ if (isForWorkgroup()) {
layout = getEffectiveSgLayoutAsInt();
- } else if (idLevel == xegpu::DistributionLevel::WI) {
+ } else if (isForSubgroup()) {
layout = getEffectiveLaneLayoutAsInt();
} else {
return failure();
@@ -304,14 +303,13 @@ LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
- Value linearId, ArrayRef<int64_t> shape,
- xegpu::DistributionLevel targetLevel) {
+ Value linearId, ArrayRef<int64_t> shape) {
SmallVector<int64_t> layout;
SmallVector<int64_t> subShape;
- if (targetLevel == DistributionLevel::SG) {
+ if (isForWorkgroup()) {
layout = getEffectiveSgLayoutAsInt();
subShape = getEffectiveSgDataAsInt();
- } else if (targetLevel == DistributionLevel::WI) {
+ } else if (isForSubgroup()) {
layout = getEffectiveLaneLayoutAsInt();
subShape = getEffectiveLaneDataAsInt();
} else {
@@ -325,7 +323,7 @@ LayoutAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
}
// delinearize Ids
- auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
SmallVector<Value> ids = *maybeIds;
@@ -384,11 +382,10 @@ SliceAttr SliceAttr::flatten() const {
}
FailureOr<SmallVector<Value>>
-SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
- xegpu::DistributionLevel level) {
+SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
- return parent.delinearizeId(builder, loc, linearId, level);
+ return parent.delinearizeId(builder, loc, linearId);
}
// Implements DistributeLayoutAttr::computeDistributedOffsets to generate
@@ -396,18 +393,17 @@ SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId,
// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
- Value linearId, ArrayRef<int64_t> shape,
- xegpu::DistributionLevel targetLevel) {
+ Value linearId, ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isForWorkgroup())
return failure();
SmallVector<int64_t> layout;
SmallVector<int64_t> subShape;
- if (targetLevel == DistributionLevel::SG) {
+ if (isForWorkgroup()) {
layout = getEffectiveSgLayoutAsInt();
subShape = getEffectiveSgDataAsInt();
- } else if (targetLevel == DistributionLevel::WI) {
+ } else if (isForSubgroup()) {
layout = getEffectiveLaneLayoutAsInt();
subShape = getEffectiveLaneDataAsInt();
} else {
@@ -422,7 +418,7 @@ SliceAttr::computeDistributedOffsets(OpBuilder &builder, Location loc,
}
// delinearize Ids
- auto maybeIds = delinearizeId(builder, loc, linearId, targetLevel);
+ auto maybeIds = delinearizeId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
index c576172683f68..b9ec19f15b65c 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp
@@ -911,9 +911,8 @@ static SmallVector<Value> computeDistributedOffsetsForMatrixOp(
PatternRewriter &rewriter, Location loc, xegpu::DistributeLayoutAttr layout,
Value laneId, ArrayRef<int64_t> payloadShape, ValueRange origOffsets) {
SmallVector<Value> newOffsets;
- ;
- auto maybeDescOffsets = layout.computeDistributedOffsets(
- rewriter, loc, laneId, payloadShape, xegpu::DistributionLevel::WI);
+ auto maybeDescOffsets =
+ layout.computeDistributedOffsets(rewriter, loc, laneId, payloadShape);
if (failed(maybeDescOffsets))
return {};
assert(maybeDescOffsets.value().size() == 1 &&
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 35072f0529072..5f8627bc75d4d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -114,8 +114,8 @@ genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
// Compute the list of subgroup-relative offsets for sub-tensors or sub-memory
// descriptors to be accessed, based on the layout information.
ArrayRef<int64_t> wgShape = op.getDataShape();
- auto maybeDescOffsets = layout.computeDistributedOffsets(
- rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+ auto maybeDescOffsets =
+ layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
if (failed(maybeDescOffsets))
return failure();
@@ -831,8 +831,8 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.computeDistributedOffsets(
- rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+ auto sgOffsets =
+ layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
@@ -1053,8 +1053,8 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto sgOffsets = layout.computeDistributedOffsets(
- rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+ auto sgOffsets =
+ layout.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();
diff --git a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
index 61ebdce5d7995..ba5591a996eec 100644
--- a/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
+++ b/mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp
@@ -200,8 +200,8 @@ class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- auto maybeOffsets = sliceAttr.computeDistributedOffsets(
- rewriter, loc, sgId, wgShape, xegpu::DistributionLevel::SG);
+ auto maybeOffsets =
+ sliceAttr.computeDistributedOffsets(rewriter, loc, sgId, wgShape);
if (failed(maybeOffsets))
return failure();