[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] set_op_layout_attr supports setting anchor layout (PR #172542)
Tuomas Kärnä
llvmlistbot at llvm.org
Tue Dec 16 11:49:08 PST 2025
https://github.com/tkarna created https://github.com/llvm/llvm-project/pull/172542
Changes `transform.xegpu.set_op_layout_attr` to support XeGPU anchor layouts. By default, i.e. when neither the `result` nor the `operand` unit attribute is set, this transform op sets the op's anchor layout if the op supports it, and emits a silenceable failure otherwise. Setting both `result` and `operand` is also rejected with a silenceable failure.
In contrast to the earlier implementation, setting the operand layout now requires setting the new `operand` argument.
>From 84024adcf70d2ed010d82760efc5a77d364a1a54 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Fri, 12 Dec 2025 19:31:01 +0200
Subject: [PATCH] set_op_layout_attr supports setting anchor layout
---
.../XeGPU/TransformOps/XeGPUTransformOps.td | 22 ++++++++-----
.../XeGPU/TransformOps/XeGPUTransformOps.cpp | 33 ++++++++++++++++---
mlir/python/mlir/dialects/transform/xegpu.py | 4 +++
.../Dialect/XeGPU/transform-ops-invalid.mlir | 21 +++++++++++-
mlir/test/Dialect/XeGPU/transform-ops.mlir | 24 ++++++++++++--
.../python/dialects/transform_xegpu_ext.py | 32 ++++++++++++++++--
6 files changed, 117 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 29579acc727ed..8c5539ad19382 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -109,12 +109,14 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
let summary = "Set xegpu.layout attribute of an op.";
let description = [{
- Sets the `xegpu.layout` attribute of an op. If `result=true`, sets the
- `layout_result_{index}`, otherwise `layout_operand_{index}` attribute. The
- target operand/result value is defined by the `index` argument. The layout
- is defined by the `sg_layout`, `sg_data` and optional `inst_data`
- attributes. If `slice_dims` is provided, the `xegpu.layout` attribute is
- wrapped in an `xegpu.slice<..., dims=slice_dims>` attribute.
+ Sets the `xegpu.layout` attribute of an op. By default it sets the anchor
+ layout of XeGPU ops that support it; other ops yield a silenceable failure.
+ If `result=true` or `operand=true` (mutually exclusive), it instead sets the
+ `layout_result_{index}` or `layout_operand_{index}` attribute, respectively,
+ on any op, where `index` selects the target result/operand. The layout is
+ defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+ If `slice_dims` is provided, the `xegpu.layout` attribute is wrapped in an
+ `xegpu.slice<..., dims=slice_dims>` attribute.
}];
let arguments = (ins TransformHandleTypeInterface:$target,
@@ -126,7 +128,8 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
- DefaultValuedAttr<UnitAttr, "false">:$result
+ DefaultValuedAttr<UnitAttr, "false">:$result,
+ DefaultValuedAttr<UnitAttr, "false">:$operand
);
let results = (outs);
@@ -137,12 +140,13 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
"ArrayRef<int64_t>":$sliceDims,
- CArg<"bool", "false">:$result
+ CArg<"bool", "false">:$result,
+ CArg<"bool", "false">:$operand
)>,
];
let assemblyFormat = [{
- $target (`result` $result^)? (`index` `=` $index^)?
+ $target (`result` $result^)? (`operand` $operand^)? (`index` `=` $index^)?
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index e6009d5afeab2..9f6c92ce78657 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -290,7 +290,7 @@ void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
- bool result) {
+ bool result, bool operand) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -306,7 +306,8 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
/*slice_dims=*/sliceDims,
- /*result=*/result);
+ /*result=*/result,
+ /*operand=*/operand);
}
DiagnosedSilenceableFailure
@@ -321,13 +322,19 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
Operation *target = *targetOps.begin();
bool resultTarget = getResult();
+ bool operandTarget = getOperand();
+
+ if (resultTarget && operandTarget) {
+ return emitSilenceableFailure(getLoc())
+ << "Cannot set both result and operand layout attributes.";
+ }
int64_t index = getIndex();
if (resultTarget && index >= target->getNumResults()) {
return emitSilenceableFailure(getLoc())
<< "Index exceeds the number of op results";
}
- if (!resultTarget && index >= target->getNumOperands()) {
+ if (operandTarget && index >= target->getNumOperands()) {
return emitSilenceableFailure(getLoc())
<< "Index exceeds the number of op operands";
}
@@ -348,10 +355,26 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
}
// Set layout attribute for the op result or operand
- if (resultTarget)
+ if (resultTarget) {
xegpu::setDistributeLayoutAttr(target->getResult(index), layout);
- else
+ } else if (operandTarget) {
xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layout);
+ } else {
+ // Neither `result` nor `operand` requested: set the op's anchor layout.
+ // Only ops that carry an anchor `layout` attribute are supported; any
+ // other op is reported as a silenceable failure.
+ // TODO use AnchorLayoutInterface when available.
+ if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(target)) {
+ loadOp.setLayoutAttr(layout);
+ } else if (auto prefetchOp = dyn_cast<xegpu::PrefetchOp>(target)) {
+ prefetchOp.setLayoutAttr(layout);
+ } else {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Cannot set anchor layout to op: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+ }
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 5aa6453b7cb8a..a768ce5f4e720 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -134,6 +134,7 @@ def __init__(
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
+ operand: Optional[Union[bool, Attribute]] = None,  # set layout_operand_{index}; exclusive with `result`
loc=None,
ip=None,
):
@@ -164,6 +165,7 @@ def __init__(
slice_dims=slice_dims,
index=index,
result=result,
+ operand=operand,
loc=loc,
ip=ip,
)
@@ -178,6 +180,7 @@ def set_op_layout_attr(
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
+ operand: Optional[Union[bool, Attribute]] = None,  # set layout_operand_{index}; exclusive with `result`
loc=None,
ip=None,
) -> SetOpLayoutAttrOp:
@@ -189,6 +192,7 @@ def set_op_layout_attr(
slice_dims=slice_dims,
index=index,
result=result,
+ operand=operand,
loc=loc,
ip=ip,
)
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index dce4a41982550..2a147497a893b 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -47,7 +47,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Index exceeds the number of op operands}}
- transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}
@@ -67,6 +67,42 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}

+// -----

+// CHECK-LABEL: @set_op_layout_attr_result_and_operand
+func.func @set_op_layout_attr_result_and_operand(%arg0: vector<256x32xf16>) {
+ %0 = arith.extf %arg0 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}

+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Cannot set both result and operand layout attributes.}}
+ transform.xegpu.set_op_layout_attr %0 result operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.yield
+ }
+}

+// -----

+// CHECK-LABEL: @set_op_layout_attr_not_anchor_op
+func.func @set_op_layout_attr_not_anchor_op(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32> // expected-note {{target op}}
+ return
+}

+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error@below {{Cannot set anchor layout to op: arith.extf}}
transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 561034fb5880b..dd51f1e50c62d 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -264,7 +264,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+ transform.xegpu.set_op_layout_attr %0 operand sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
transform.yield
}
}
@@ -287,7 +287,27 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
- transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.xegpu.set_op_layout_attr %0 operand index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor
+func.func @set_op_layout_attr_anchor(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %{{.*}}[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
transform.yield
}
}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 2b11acb04ed5b..d8d2b9cb33ca0 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -97,10 +97,12 @@ def setOpLayoutAttrOperandMinimal():
sequence.bodyTarget,
sg_layout=[6, 4],
sg_data=[32, 16],
+ operand=True,
)
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttr
# CHECK: transform.xegpu.set_op_layout_attr %
+ # CHECK: operand
# NO-CHECK: index = 0
# NO-CHECK: result
# CHECK: sg_layout = [6, 4]
@@ -127,8 +129,8 @@ def setOpLayoutAttrResult():
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttrResult
# CHECK: transform.xegpu.set_op_layout_attr %
- # NO-CHECK: index = 0
# CHECK: result
+ # NO-CHECK: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
@@ -154,14 +156,39 @@ def setOpLayoutAttrResultSlice():
transform.YieldOp()
# CHECK-LABEL: TEST: setOpLayoutAttrResultSlice
# CHECK: transform.xegpu.set_op_layout_attr %
- # NO-CHECK: index = 0
# CHECK: result
+ # NO-CHECK: index = 0
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
# CHECK: slice_dims = [0]
+@run
+def setOpLayoutAttrAnchor():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.load_nd"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrAnchor
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # NO-CHECK: result
+ # NO-CHECK: operand
+ # NO-CHECK: index = 0
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+
+
@run
def setGPULaunchThreadsOp():
sequence = transform.SequenceOp(
More information about the Mlir-commits
mailing list