[Mlir-commits] [mlir] [MLIR][XeGPU] Refactor isEvenlyDistributable() to Layout attribute interface (PR #191945)
Jianhui Li
llvmlistbot at llvm.org
Mon Apr 13 22:07:18 PDT 2026
https://github.com/Jianhui-Li created https://github.com/llvm/llvm-project/pull/191945
This PR refactors isEvenlyDistributable() into the layout attribute interface method isDistributable(), and uses it in all anchor operations to check that the shape can be distributed with the anchor layout.
>From 2bb5d45698c9dc18a5f150a79e51a4b1258cae6c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Apr 2026 02:53:36 +0000
Subject: [PATCH 1/2] refactor isEvenlyDistributable
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 60 +++++++++++--
.../mlir/Dialect/XeGPU/IR/XeGPUDialect.td | 4 -
mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 88 +------------------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 40 ++++++++-
.../Transforms/XeGPUWgToSgDistribute.cpp | 2 +-
mlir/test/Dialect/XeGPU/invalid.mlir | 8 --
.../XeGPU/propagate-layout-subgroup.mlir | 16 ++--
mlir/test/Dialect/XeGPU/transform-ops.mlir | 22 ++---
8 files changed, 114 insertions(+), 126 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f8a2beabb9b95..a7bea9881602f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -188,6 +188,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
InterfaceMethod<"Check the availability of subgroup level layouts",
"bool",
"isForSubgroup">,
+ InterfaceMethod<"Check the availability of lane level layouts",
+ "bool",
+ "isForLane">,
InterfaceMethod<"Get the rank of attribute",
"int64_t",
"getRank">,
@@ -299,26 +302,60 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
} else {
return failure();
}
- assert(
- !subShape.empty() &&
- "sgdata or lanedata cannot be empty for distributed shape computation");
+ // sgdata or lanedata cannot be empty for distributed shape computation
+ if (subShape.empty())
+ return failure();
SmallVector<int64_t> distributedShape(shape.size());
for (auto [i, dim] : llvm::enumerate(shape)) {
int64_t distriUnit = layout[i]*subShape[i];
if ((dim % distriUnit) == 0) {
// Evenly divisible case, divide the dimension by the layout factor.
distributedShape[i] = dim / layout[i];
- assert((distributedShape[i] % subShape[i] == 0) &&
- "Even distribution: sgdata or lanedata must divide the distributed dimension");
+ if (distributedShape[i] % subShape[i] != 0)
+ return failure();
} else {
// wrap around case, the dimension size must be equal to subShape value
- assert(dim == subShape[i] &&
- "Wrap-around distribution: sgdata or lanedata must be same as tensor tile shape");
+ if(dim != subShape[i])
+ return failure();
distributedShape[i] = dim;
}
}
return distributedShape;
}]>,
+ InterfaceMethod<[{Checks if the given shape can be distributed by the layout.
+
+   The shape is pushed through each distribution level the layout carries,
+   in order: subgroups (sg_layout/sg_data), instruction tiles (inst_data),
+   and lanes (lane_layout/lane_data). Returns false as soon as one level
+   cannot distribute the shape produced by the previous level, and true once
+   no distribution fields remain.}],
+   /*retTy=*/"bool",
+   /*methodName=*/"isDistributable",
+   /*args=*/(ins "::llvm::ArrayRef<int64_t>":$shape),
+   /*methodBody=*/[{
+     DistributeLayoutAttr curLayoutAttr = $_self;
+     SmallVector<int64_t> curShape(shape.begin(), shape.end());
+     // Phase 1: Distribute across subgroups (sg_layout + sg_data).
+     if (curLayoutAttr.isForWorkgroup()) {
+       auto maybeSgShape = curLayoutAttr.computeDistributedShape(curShape);
+       if (failed(maybeSgShape))
+         return false;
+       curShape = maybeSgShape.value();
+       curLayoutAttr = curLayoutAttr.dropSgLayoutAndData();
+       // No lower-level fields left to check.
+       if (!curLayoutAttr)
+         return true;
+     }
+     // Phase 2: Distribute across instruction data (inst_data).
+     if (curLayoutAttr.isForSubgroup() && !curLayoutAttr.isForLane()) {
+       SmallVector<int64_t> instData = curLayoutAttr.getEffectiveInstDataAsInt();
+       // The layout rank is not validated in this method, so guard the
+       // indexing below: never read past the end of instData, and reject a
+       // non-positive tile size instead of dividing by zero.
+       if (instData.size() < curShape.size())
+         return false;
+       for (size_t i = 0, e = curShape.size(); i < e; ++i) {
+         if (instData[i] <= 0 || curShape[i] % instData[i] != 0)
+           return false;
+       }
+       // inst_data becomes the new shape for the next phase.
+       curShape = instData;
+       curLayoutAttr = curLayoutAttr.dropInstData();
+       if (!curLayoutAttr)
+         return true;
+     }
+     // Phase 3: Distribute across lanes (lane_layout + lane_data).
+     return succeeded(curLayoutAttr.computeDistributedShape(curShape));
+   }]>,
InterfaceMethod</*desc=*/[{Check if this layout is a slice of another layout.}],
/*retTy=*/"bool",
/*methodName=*/"isSliceOf",
@@ -487,6 +524,10 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return !isForWorkgroup();
}
+ /// True when this layout is not workgroup-level and carries no inst_data.
+ bool isForLane() {
+   if (isForWorkgroup())
+     return false;
+   return !getInstData();
+ }
+
int64_t getRank() const {
if (auto attr = getSgLayout())
return attr.size();
@@ -687,6 +728,11 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
return parent.isForSubgroup();
}
+ /// Check the availability of lane level layouts by asking the parent layout.
+ bool isForLane() const {
+   // Dispatch through the parent's DistributeLayoutAttr interface rather than
+   // dyn_cast<LayoutAttr>: if the parent is itself a SliceAttr (nested slice),
+   // the cast yields a null attribute and calling isForLane() on it crashes.
+   return getParent().isForLane();
+ }
+
/// Returns the SgLayout of the attribute, computed by applying
/// the slice dimensions to the underlying LayoutAttr.
SmallVector<int64_t> getEffectiveSgLayoutAsInt() const {
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
index c173b93face98..84fd8f9e0060c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -38,10 +38,6 @@ def XeGPU_Dialect : Dialect {
let useDefaultAttributePrinterParser = true;
let extraClassDeclaration = [{
- /// Checks if the given shape can be evenly distributed based on the layout
- /// and data factors provided by the LayoutAttr.
- static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::DistributeLayoutAttr attr);
-
/// drops/slices the shape in the specified dims, and return the rest. e.g.,
/// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
template<typename T, typename U>
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index eaa43c02946d8..80a3fc91f1c4f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -121,74 +121,6 @@ static SmallVector<SmallVector<int64_t>> genStaticCoordinates(
return coordinates;
}
-// Checks if the given shape can be evenly distributed based on the layout
-// and data factors provided by the LayoutAttr.
-bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
- xegpu::DistributeLayoutAttr attr) {
- assert(attr && "Layout attribute is missing.");
-
- // Checks whether the given shape can be evenly distributed using the
- // specified layout and data attributes. If successful, it returns the work
- // size for each compute unit; otherwise, it returns `std::nullopt`. The work
- // size per compute unit is calculated as follows:
- // - If `data` is null: newShape[i] = shape[i] / layout[i]
- // - If `data` is not null: newShape[i] = data[i]
- // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
- // smaller than `layout[i] * data[i]`, allowing multiple compute units to
- // share the data.
- auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
- SmallVector<int64_t> layout,
- SmallVector<int64_t> data,
- bool rr = true) -> optional<SmallVector<int64_t>> {
- llvm::SmallVector<int64_t> newShape(shape);
- if (layout.size()) {
- if (layout.size() != shape.size())
- return std::nullopt;
- auto ratio = computeShapeRatio(shape, layout);
- if (ratio.has_value()) {
- newShape = ratio.value();
- } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
- return std::nullopt;
- }
- // Round-robin case: continue with original newShape
- }
-
- if (data.size()) {
- if (data.size() != shape.size())
- return std::nullopt;
- auto ratio = computeShapeRatio(newShape, data);
- if (!ratio.has_value() && rr)
- ratio = computeShapeRatio(data, newShape);
- if (!ratio.has_value())
- return std::nullopt;
-
- // if data is not null, we always return it for next phase.
- newShape = data;
- }
- return newShape;
- };
-
- // check the sgLayout and sgData
- auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
- attr.getEffectiveSgDataAsInt());
- if (!maybeSgShape)
- return false;
- auto sgShape = maybeSgShape.value();
-
- // check InstData, it neither have layout nor need round-robin
- auto maybeInstShape =
- tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
- if (!maybeInstShape)
- return false;
- auto instShape = maybeInstShape.value();
-
- // check LaneLayout and LaneData
- auto maybeLaneShape =
- tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
- attr.getEffectiveLaneDataAsInt());
- return maybeLaneShape.has_value();
-}
-
//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
@@ -1431,25 +1363,12 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
<< chunkAlignmentFactor;
}
}
-
- auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
- if (layoutAttr) {
+ if (auto layoutAttr =
+ mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
if (rank != (size_t)layoutAttr.getRank())
return emitError() << "expected layout rank to match tensor rank";
- auto laneData = layoutAttr.getLaneData();
- if (scatterAttr && laneData) {
- // Validate subgroup mapping rules for scattered tensors.
- // if chunkSize > 1, the last dimension of the tensor should
- // be distributed in the units divisible by chunkAlignmentFactor.
- int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
- if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
- return emitError()
- << "expected last dim of lane_data to be a multiple of: "
- << chunkAlignmentFactor;
- }
-
- if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
+ if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
std::string shapeStr;
llvm::raw_string_ostream stream(shapeStr);
llvm::interleaveComma(shape, stream);
@@ -1457,6 +1376,7 @@ TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
<< layoutAttr;
}
}
+
return success();
}
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 5697097a4c999..dae041544753c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -504,6 +504,14 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
+ if (getAnchorLayout()) {
+ auto layout = getAnchorLayout();
+ auto tdescShape = getShapeOf(tdescTy);
+ if (!layout.isDistributable(tdescShape))
+ return emitOpError(
+ "TensorDesc shape is not distributable with the layout");
+ }
+
return success();
}
@@ -628,6 +636,13 @@ LogicalResult LoadNdOp::verify() {
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
+ if (getAnchorLayout()) {
+ auto layout = getAnchorLayout();
+ if (!layout.isDistributable(tdescShape))
+ return emitOpError(
+ "TensorDesc shape is not distributable with the layout");
+ }
+
return success();
}
@@ -721,6 +736,13 @@ LogicalResult StoreNdOp::verify() {
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");
+ if (getAnchorLayout()) {
+ auto layout = getAnchorLayout();
+ if (!layout.isDistributable(tdescShape))
+ return emitOpError(
+ "TensorDesc shape is not distributable with the layout");
+ }
+
return success();
}
@@ -823,6 +845,18 @@ LogicalResult PrefetchOp::verify() {
if (getOffsetAlignByteAttr() && !srcTy.isInteger())
return emitOpError("offset_align_byte only allowed with integer source.");
+ if (getAnchorLayout()) {
+ auto layout = getAnchorLayout();
+ // get the offset operand and its shape
+ auto offsets = getOffsets();
+ auto offsetsTy = offsets.getType();
+ if (!llvm::isa<VectorType>(offsetsTy))
+ return emitOpError("Offsets should be a vector.");
+ auto offsetShape = getShapeOf(offsetsTy);
+ if (!layout.isDistributable(offsetShape))
+ return emitOpError("offset shape is not distributable with the layout");
+ }
+
return success();
}
@@ -1103,12 +1137,12 @@ LogicalResult ConvertLayoutOp::verify() {
Type srcType = getSource().getType();
if (llvm::isa<VectorType>(srcType)) {
- auto shape = llvm::cast<VectorType>(srcType).getShape();
- if (!XeGPUDialect::isEvenlyDistributable(shape, srcLayout))
+ SmallVector<int64_t> shape(llvm::cast<VectorType>(srcType).getShape());
+ if (!srcLayout.isDistributable(shape))
return emitOpError(
"invalid input layout, data cannot be evenly distributed.");
- if (!XeGPUDialect::isEvenlyDistributable(shape, resLayout))
+ if (!resLayout.isDistributable(shape))
return emitOpError(
"invalid target layout, data cannot be evenly distributed.");
}
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a095c19d66c15..d637b6828deab 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -498,7 +498,7 @@ struct WgToSgVectorBroadcastOp
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
- if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+ if (!layout.isDistributable(SmallVector<int64_t>(wgShape)))
return failure();
SmallVector<Value> newBroadcastOps;
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7390b47b3f8d9..82c7879c79d56 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -325,14 +325,6 @@ func.func @create_tdesc_layout_1(%src: ui64) {
return
}
-// -----
-func.func @create_tdesc_layout_2(%src: ui64) {
- %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
- // expected-error at +1 {{expected last dim of lane_data to be a multiple of: 2}}
- %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x4xf16, #xegpu.scatter_tdesc_attr<chunk_size = 4>, #xegpu.layout<lane_layout = [4, 1], lane_data = [1, 1]>>
- return
-}
-
// -----
func.func @load_gather_simt_1(%src: ui64) {
%0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
index 831d1e05967f8..62426d619445b 100644
--- a/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
+++ b/mlir/test/Dialect/XeGPU/propagate-layout-subgroup.mlir
@@ -27,24 +27,24 @@ gpu.module @test {
// CHECK-SAME: %[[ARG_1:.*]]: memref<128x256xf32>
func.func @vector_transpose(%src: memref<256x128xf32>, %src1: memref<128x256xf32>) {
// CHECK: %[[TDESC_LD:.*]] = xegpu.create_nd_tdesc %[[ARG_0]] : memref<256x128xf32> ->
- // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>>
+ // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>>
// CHECK: %[[TDESC_ST:.*]] = xegpu.create_nd_tdesc %[[ARG_1]] : memref<128x256xf32> ->
- // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+ // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>>
- // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>}>
- // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], order = [0, 1]>> -> vector<256x128xf32>
+ // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC_LD]][0, 0] <{layout = #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>}>
+ // CHECK-SAME: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 16], order = [0, 1]>> -> vector<256x128xf32>
// CHECK: %[[TRANSPOSED:.*]] = vector.transpose %2, [1, 0]
- // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
+ // CHECK-SAME {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>} : vector<256x128xf32> to vector<128x256xf32>
// CHECK: xegpu.store_nd %[[TRANSPOSED]], %[[TDESC_ST]][0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>}> : vector<128x256xf32>,
- // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>}> : vector<128x256xf32>,
+ // CHECK-SAME: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>>
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32> -> !xegpu.tensor_desc<256x128xf32>
%tdesc1 = xegpu.create_nd_tdesc %src1 : memref<128x256xf32> -> !xegpu.tensor_desc<128x256xf32>
%load = xegpu.load_nd %tdesc[0, 0] : !xegpu.tensor_desc<256x128xf32> -> vector<256x128xf32>
%trans = vector.transpose %load, [1, 0] : vector<256x128xf32> to vector<128x256xf32>
- xegpu.store_nd %trans, %tdesc1[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], order = [1, 0]>}>
+ xegpu.store_nd %trans, %tdesc1[0, 0] <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 64], order = [1, 0]>}>
: vector<128x256xf32>, !xegpu.tensor_desc<128x256xf32>
return
}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index acba80d870253..3daa74d223946 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -83,7 +83,7 @@ module attributes {transform.with_named_sequence} {
func.func @set_anchor_layout(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// CHECK: = xegpu.load_nd %0[0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
return
}
@@ -92,7 +92,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_anchor_layout %{{.*}}
- transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
@@ -103,9 +103,9 @@ module attributes {transform.with_named_sequence} {
func.func @set_anchor_layout_multiple(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// CHECK: xegpu.prefetch_nd %0[0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
// CHECK: xegpu.prefetch_nd %0[16, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
xegpu.prefetch_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16>
xegpu.prefetch_nd %0[16, 0] : !xegpu.tensor_desc<256x32xf16>
return
@@ -115,7 +115,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.prefetch_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_anchor_layout %{{.*}}
- transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
@@ -126,7 +126,7 @@ module attributes {transform.with_named_sequence} {
func.func @set_anchor_layout_param(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// CHECK: = xegpu.load_nd %0[0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
return
}
@@ -136,7 +136,7 @@ module attributes {transform.with_named_sequence} {
%0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_anchor_layout %{{.*}}
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
- transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
transform.yield
}
}
@@ -147,7 +147,7 @@ module attributes {transform.with_named_sequence} {
func.func @set_anchor_layout_param2(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// CHECK: = xegpu.load_nd %0[0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>}>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
return
}
@@ -158,7 +158,7 @@ module attributes {transform.with_named_sequence} {
// CHECK: transform.xegpu.set_anchor_layout %{{.*}}
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
%layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
- transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [%layout0, %layout1] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
transform.yield
}
}
@@ -192,7 +192,7 @@ module attributes {transform.with_named_sequence} {
func.func @set_anchor_layout_order(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// CHECK: = xegpu.load_nd %0[0, 0]
- // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [1, 0]>}>
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>}>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
return
}
@@ -201,7 +201,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_anchor_layout %{{.*}}
- transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
+ transform.xegpu.set_anchor_layout %0 index = 0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] order = [1, 0] : !transform.any_op
transform.yield
}
}
>From aa7525b4beb0e35a95c3b78c74bd1f833e0d40a6 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Apr 2026 05:03:35 +0000
Subject: [PATCH 2/2] add isDistributableCheck for all anchor operations
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 52 ++++++++++++++++++++++----
1 file changed, 44 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index dae041544753c..9107cda30a8fa 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -218,6 +218,11 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
}
}
}
+
+ if (layout && !layout.isDistributable(
+ SmallVector<int64_t>(dataShape.begin(), dataShape.end())))
+ return emitError() << "Value shape is not distributable with the layout";
+
if (dataShape.size() == 2) {
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -638,7 +643,8 @@ LogicalResult LoadNdOp::verify() {
if (getAnchorLayout()) {
auto layout = getAnchorLayout();
- if (!layout.isDistributable(tdescShape))
+ auto origTdescShape = getShapeOf(tdescTy);
+ if (!layout.isDistributable(origTdescShape))
return emitOpError(
"TensorDesc shape is not distributable with the layout");
}
@@ -848,13 +854,14 @@ LogicalResult PrefetchOp::verify() {
if (getAnchorLayout()) {
auto layout = getAnchorLayout();
// get the offset operand and its shape
- auto offsets = getOffsets();
- auto offsetsTy = offsets.getType();
- if (!llvm::isa<VectorType>(offsetsTy))
- return emitOpError("Offsets should be a vector.");
- auto offsetShape = getShapeOf(offsetsTy);
- if (!layout.isDistributable(offsetShape))
- return emitOpError("offset shape is not distributable with the layout");
+ if (auto offsets = getOffsets()) {
+ auto offsetsTy = offsets.getType();
+ if (!llvm::isa<VectorType>(offsetsTy))
+ return emitOpError("Offsets should be a vector.");
+ auto offsetShape = getShapeOf(offsetsTy);
+ if (!layout.isDistributable(offsetShape))
+ return emitOpError("offset shape is not distributable with the layout");
+ }
}
return success();
@@ -904,6 +911,13 @@ LogicalResult LoadGatherOp::verify() {
if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
+ if (getAnchorLayout()) {
+ auto layout = getAnchorLayout();
+ auto valShape = getShapeOf(valueTy);
+ if (!layout.isDistributable(valShape))
+ return emitOpError("Value shape is not distributable with the layout");
+ }
+
auto offsetsTy = getOffsets().getType();
return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
@@ -988,6 +1002,13 @@ LogicalResult StoreScatterOp::verify() {
if (memTy && (getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";
+ if (getAnchorLayout()) {
+ auto layout = getAnchorLayout();
+ auto valShape = getShapeOf(valueTy);
+ if (!layout.isDistributable(valShape))
+ return emitOpError("Value shape is not distributable with the layout");
+ }
+
auto offsetsTy = getOffsets().getType();
return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
@@ -1086,6 +1107,21 @@ LogicalResult DpasOp::verify() {
auto rhsShape = getRhsType().getShape();
auto resShape = getResultType().getShape();
+ if (auto cdLayout = getLayoutCd())
+ if (!cdLayout->isDistributable(
+ SmallVector<int64_t>(resShape.begin(), resShape.end())))
+ return emitOpError("Value shape is not distributable with the layout");
+
+ if (auto aLayout = getLayoutA())
+ if (!aLayout->isDistributable(
+ SmallVector<int64_t>(lhsShape.begin(), lhsShape.end())))
+ return emitOpError("Value shape is not distributable with the layout");
+
+ if (auto bLayout = getLayoutB())
+ if (!bLayout->isDistributable(
+ SmallVector<int64_t>(rhsShape.begin(), rhsShape.end())))
+ return emitOpError("Value shape is not distributable with the layout");
+
if (getAcc() && getAcc().getType() != getResultType())
return emitOpError("Expecting the acc type to be the same as result.");
More information about the Mlir-commits
mailing list