[Mlir-commits] [mlir] [mlir][xegpu] Add support for setting `order` in `SetDescLayoutOp` and `SetOpLayoutAttrOp` transform ops. (PR #184705)
Charitha Saumya
llvmlistbot at llvm.org
Wed Mar 4 15:45:58 PST 2026
https://github.com/charithaintc created https://github.com/llvm/llvm-project/pull/184705
Currently, the XeGPU transform dialect does not let the user set the `order` attribute of the layout produced by `SetDescLayoutOp` and `SetOpLayoutAttrOp`. This PR adds `order` as an optional argument to both transform ops.
>From e1010fe019940630728e7c9bdeaa5c85e07e3d7e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 4 Mar 2026 23:02:12 +0000
Subject: [PATCH 1/2] [mlir][xegpu] Add optional `order` to SetDescLayoutOp and SetOpLayoutAttrOp
---
.../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td | 9 +++
.../XeGPU/TransformOps/XeGPUTransformOps.td | 10 +++-
.../XeGPU/TransformOps/XeGPUTransformOps.cpp | 22 ++++++-
mlir/test/Dialect/XeGPU/transform-ops.mlir | 60 +++++++++++++++++++
4 files changed, 96 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 6f667f4801673..fd48411e51b02 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -494,6 +494,15 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
return 0;
}
+ /// Return a copy of this layout with only `order` replaced by newOrder.
+ /// Precondition (asserted, so callers must validate): newOrder.size() == getRank().
+ LayoutAttr cloneWithOrder(DenseI32ArrayAttr newOrder) const {
+ assert(getRank() == static_cast<int64_t>(newOrder.size())
+ && "The size of new order must match the layout rank.");
+ return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), getInstData(),
+ getLaneLayout(), getLaneData(), newOrder);
+ }
+
LayoutAttr dropSgLayoutAndData() const{
// avoid every field of the attribute is nullptr, which may lead to segment fault
if (!getInstData() && !getLaneLayout())
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 23dabe4eb380a..bee7eea79dce9 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -58,7 +58,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
+ DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order
);
let results = (outs TransformHandleTypeInterface:$transformed);
@@ -67,7 +68,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
- "ArrayRef<int64_t>":$sliceDims
+ "ArrayRef<int64_t>":$sliceDims,
+ CArg<"ArrayRef<int32_t>", "{}">:$order
)>,
];
@@ -77,6 +79,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
(`slice_dims` `=` $slice_dims^)?
+ (`order` `=` $order^)?
attr-dict `:` functional-type(operands, results)
}];
@@ -128,6 +131,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
+ DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
DefaultValuedAttr<UnitAttr, "false">:$result,
DefaultValuedAttr<UnitAttr, "false">:$operand
);
@@ -140,6 +144,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
"ArrayRef<int64_t>":$sliceDims,
CArg<"bool", "false">:$result,
- CArg<"bool", "false">:$operand
+ CArg<"bool", "false">:$operand,
+ CArg<"ArrayRef<int32_t>", "{}">:$order
)>,
@@ -151,6 +156,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
(`slice_dims` `=` $slice_dims^)?
+ (`order` `=` $order^)?
attr-dict `:` qualified(type(operands))
}];
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 7bc67da8263dc..58b103614454c 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -12,6 +12,8 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include <optional>
@@ -215,7 +217,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
ArrayRef<OpFoldResult> mixedSgLayout,
ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData,
- ArrayRef<int64_t> sliceDims) {
+ ArrayRef<int64_t> sliceDims,
+ ArrayRef<int32_t> order) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -229,7 +232,8 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
- /*slice_dims=*/sliceDims);
+ /*slice_dims=*/sliceDims,
+ /*order=*/order);
}
DiagnosedSilenceableFailure
@@ -250,6 +254,21 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
if (!status.succeeded())
return status;
+ // If order is provided, clone the layout with the provided order. Reject an
+ // order that is not a permutation of [0, rank) with a diagnostic instead of
+ // tripping the assertion in LayoutAttr::cloneWithOrder.
+ ArrayRef<int32_t> order = getOrder();
+ if (!order.empty()) {
+ SmallVector<int64_t> order64(order.begin(), order.end());
+ if (layoutAttr.getRank() != static_cast<int64_t>(order.size()) ||
+ !isPermutationVector(order64))
+ return emitSilenceableError()
+ << "expected 'order' to be a permutation of [0, "
+ << layoutAttr.getRank() << ")";
+ layoutAttr =
+ layoutAttr.cloneWithOrder(DenseI32ArrayAttr::get(getContext(), order));
+ }
+
xegpu::DistributeLayoutAttr layout = layoutAttr;
auto sliceDims = getSliceDims();
if (sliceDims.size() > 0) {
@@ -291,7 +310,7 @@ void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
- bool result, bool operand) {
+ bool result, bool operand, ArrayRef<int32_t> order) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -307,6 +326,7 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
/*slice_dims=*/sliceDims,
+ /*order=*/order,
/*result=*/result,
/*operand=*/operand);
}
@@ -342,6 +362,21 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
if (!status.succeeded())
return status;
+ // If order is provided, clone the layout with the provided order. Reject an
+ // order that is not a permutation of [0, rank) with a diagnostic instead of
+ // tripping the assertion in LayoutAttr::cloneWithOrder.
+ ArrayRef<int32_t> order = getOrder();
+ if (!order.empty()) {
+ SmallVector<int64_t> order64(order.begin(), order.end());
+ if (layoutAttr.getRank() != static_cast<int64_t>(order.size()) ||
+ !isPermutationVector(order64))
+ return emitSilenceableError()
+ << "expected 'order' to be a permutation of [0, "
+ << layoutAttr.getRank() << ")";
+ layoutAttr =
+ layoutAttr.cloneWithOrder(DenseI32ArrayAttr::get(getContext(), order));
+ }
+
xegpu::DistributeLayoutAttr layout = layoutAttr;
auto sliceDims = getSliceDims();
if (sliceDims.size() > 0) {
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 9a278cbf7b498..5b2dc47246463 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -140,6 +140,26 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @set_desc_layout_order
+func.func @set_desc_layout_order(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
+ // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] order = [1, 0] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @set_op_layout_attr_result_default
func.func @set_op_layout_attr_result_default(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -225,6 +245,25 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @set_op_layout_attr_result_order
+func.func @set_op_layout_attr_result_order(%arg0: vector<256x32xf16>) {
+ // CHECK: = arith.extf
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [0, 1]>}
+ %2 = arith.extf %arg0 : vector<256x32xf16> to vector<256x32xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [0, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @set_op_layout_attr_operand_minimal
func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -287,6 +326,26 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_order
+func.func @set_op_layout_attr_anchor_order(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [1, 0]>}>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
+ transform.yield
+ }
+}
+
// -----
// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_a
>From 29dbdc18a4832e99564f753614c6053297939b4e Mon Sep 17 00:00:00 2001
From: Charitha Saumya <charitha.saumya.gusthinna.waduge at intel.com>
Date: Wed, 4 Mar 2026 23:42:24 +0000
Subject: [PATCH 2/2] [mlir][xegpu] Expose `order` in the XeGPU Python transform bindings
---
mlir/python/mlir/dialects/transform/xegpu.py | 8 ++++
.../python/dialects/transform_xegpu_ext.py | 45 +++++++++++++++++++
2 files changed, 53 insertions(+)
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index a768ce5f4e720..03a5239dceff1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -63,6 +63,7 @@ def __init__(
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
+ order: Optional[MixedInt] = None,  # layout dim order, forwarded to the op's `order` attr (e.g. [1, 0])
loc=None,
ip=None,
):
@@ -94,6 +95,7 @@ def __init__(
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
slice_dims=slice_dims,
+ order=order,
loc=loc,
ip=ip,
)
@@ -106,6 +108,7 @@ def set_desc_layout(
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
+ order: Optional[MixedInt] = None,  # passed through to SetDescLayoutOp(order=...)
loc=None,
ip=None,
) -> OpResult:
@@ -115,6 +118,7 @@ def set_desc_layout(
sg_data,
inst_data=inst_data,
slice_dims=slice_dims,
+ order=order,
loc=loc,
ip=ip,
).result
@@ -132,6 +136,7 @@ def __init__(
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
+ order: Optional[MixedInt] = None,  # layout dim order, forwarded to the op's `order` attr (e.g. [1, 0])
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
operand: Optional[Union[bool, Attribute]] = None,
@@ -163,6 +168,7 @@ def __init__(
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
slice_dims=slice_dims,
+ order=order,
index=index,
result=result,
operand=operand,
@@ -178,6 +184,7 @@ def set_op_layout_attr(
*,
inst_data: Optional[MixedValues] = None,
slice_dims: Optional[MixedInt] = None,
+ order: Optional[MixedInt] = None,  # passed through to SetOpLayoutAttrOp(order=...)
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
operand: Optional[Union[bool, Attribute]] = None,
@@ -190,6 +197,7 @@ def set_op_layout_attr(
sg_data,
inst_data=inst_data,
slice_dims=slice_dims,
+ order=order,
index=index,
result=result,
operand=operand,
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index e3e1313cf5f81..afb8ef9514354 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -84,6 +84,24 @@ def setDescLayoutSlice():
# CHECK: sg_data = [32, 16]
# CHECK: slice_dims = [0]
+@run
+def setDescLayoutOrder():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.create_nd_tdesc"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_desc_layout(
+ sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], order=[0, 1]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setDescLayoutOrder
+ # CHECK: %0 = transform.xegpu.set_desc_layout %
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: order = [0, 1]
+
@run
def setOpLayoutAttrOperandMinimal():
@@ -163,6 +181,33 @@ def setOpLayoutAttrResultSlice():
# CHECK: inst_data = [8, 16]
# CHECK: slice_dims = [0]
+@run
+def setOpLayoutAttrResultOrder():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ index=0,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ order=[0, 1],
+ result=True,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrResultOrder
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # CHECK: result
+ # CHECK-NOT: index = 0
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+ # CHECK: order = [0, 1]
+
@run
def setOpLayoutAttrAnchor():
More information about the Mlir-commits
mailing list