[Mlir-commits] [mlir] ce22796 - [mlir][xegpu] Add support for setting `order` in `SetDescLayoutOp` and `SetOpLayoutAttrOp` transform ops. (#184705)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 9 08:37:34 PDT 2026
Author: Charitha Saumya
Date: 2026-03-09T08:37:28-07:00
New Revision: ce227964cc4de126c43f3458498ac70315809ce8
URL: https://github.com/llvm/llvm-project/commit/ce227964cc4de126c43f3458498ac70315809ce8
DIFF: https://github.com/llvm/llvm-project/commit/ce227964cc4de126c43f3458498ac70315809ce8.diff
LOG: [mlir][xegpu] Add support for setting `order` in `SetDescLayoutOp` and `SetOpLayoutAttrOp` transform ops. (#184705)
Currently the XeGPU transform dialect does not allow the user to set the
`order` attribute of a layout in `SetDescLayoutOp` and
`SetOpLayoutAttrOp`. This PR adds `order` as an optional argument to
these transform ops, and likewise adds optional `input_order` and
`target_order` arguments to `ConvertLayoutOp`.
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
mlir/python/mlir/dialects/transform/xegpu.py
mlir/test/Dialect/XeGPU/transform-ops.mlir
mlir/test/python/dialects/transform_xegpu_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 23dabe4eb380a..f7f45508b6a03 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -58,6 +58,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
);
@@ -67,6 +68,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
+ "ArrayRef<int32_t>":$order,
"ArrayRef<int64_t>":$sliceDims
)>,
];
@@ -76,6 +78,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ (`order` `=` $order^)?
(`slice_dims` `=` $slice_dims^)?
attr-dict `:` functional-type(operands, results)
}];
@@ -127,6 +130,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+ DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
DefaultValuedAttr<UnitAttr, "false">:$result,
DefaultValuedAttr<UnitAttr, "false">:$operand
@@ -139,6 +143,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
"ArrayRef<OpFoldResult>":$mixedSgLayout,
"ArrayRef<OpFoldResult>":$mixedSgData,
"ArrayRef<OpFoldResult>":$mixedInstData,
+ "ArrayRef<int32_t>":$order,
"ArrayRef<int64_t>":$sliceDims,
CArg<"bool", "false">:$result,
CArg<"bool", "false">:$operand
@@ -150,6 +155,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
`sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
(`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+ (`order` `=` $order^)?
(`slice_dims` `=` $slice_dims^)?
attr-dict `:` qualified(type(operands))
}];
@@ -281,12 +287,14 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_layout,
Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_data,
Variadic<TransformAnyParamTypeOrAnyHandle>:$target_inst_data,
+ DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$input_order,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_data,
- DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data,
+ DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$target_order
);
let results = (outs TransformHandleTypeInterface:$newConvertOp);
@@ -295,9 +303,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
"ArrayRef<OpFoldResult>":$mixedInputSgLayout,
"ArrayRef<OpFoldResult>":$mixedInputSgData,
"ArrayRef<OpFoldResult>":$mixedInputInstData,
+ "ArrayRef<int32_t>":$inputOrder,
"ArrayRef<OpFoldResult>":$mixedTargetSgLayout,
"ArrayRef<OpFoldResult>":$mixedTargetSgData,
- "ArrayRef<OpFoldResult>":$mixedTargetInstData
+ "ArrayRef<OpFoldResult>":$mixedTargetInstData,
+ "ArrayRef<int32_t>":$targetOrder
)>,
];
@@ -306,9 +316,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
`input_sg_layout` `=` custom<DynamicIndexList>($input_sg_layout, $static_input_sg_layout)
`input_sg_data` `=` custom<DynamicIndexList>($input_sg_data, $static_input_sg_data)
(`input_inst_data` `=` custom<DynamicIndexList>($input_inst_data, $static_input_inst_data)^)?
+ (`input_order` `=` $input_order^)?
`target_sg_layout` `=` custom<DynamicIndexList>($target_sg_layout, $static_target_sg_layout)
`target_sg_data` `=` custom<DynamicIndexList>($target_sg_data, $static_target_sg_data)
(`target_inst_data` `=` custom<DynamicIndexList>($target_inst_data, $static_target_inst_data)^)?
+ (`target_order` `=` $target_order^)?
attr-dict `:` functional-type(operands, results)
}];
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 7bc67da8263dc..39f9ae0bf1287 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -122,17 +122,16 @@ static std::optional<T> findProducerOfType(Value val) {
}
/// Create a layout attribute from the given parameters.
-static xegpu::LayoutAttr
-createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
- ArrayRef<int32_t> sgData,
- std::optional<ArrayRef<int32_t>> instData) {
+static xegpu::LayoutAttr createLayoutAttr(
+ MLIRContext *ctx, ArrayRef<int32_t> sgLayout, ArrayRef<int32_t> sgData,
+ std::optional<ArrayRef<int32_t>> instData, ArrayRef<int32_t> order) {
return xegpu::LayoutAttr::get(
ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
DenseI32ArrayAttr::get(ctx, sgData),
instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
/*lane_layout=*/nullptr,
/*lane_data=*/nullptr,
- /*order=*/nullptr);
+ /*order=*/order.empty() ? nullptr : DenseI32ArrayAttr::get(ctx, order));
}
/// Generate `xegpu::LayoutAttr` from op mixed layout values.
@@ -142,6 +141,7 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
ArrayRef<::mlir::OpFoldResult> mixedSgData,
ArrayRef<::mlir::OpFoldResult> mixedInstData,
+ ArrayRef<int32_t> order,
xegpu::LayoutAttr &layoutAttr) {
SmallVector<int32_t> sgLayout, sgData, instData;
auto status =
@@ -160,7 +160,7 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
? std::nullopt
: std::optional<ArrayRef<int32_t>>(instData);
- layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
+ layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData, order);
return DiagnosedSilenceableFailure::success();
}
@@ -215,6 +215,7 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
ArrayRef<OpFoldResult> mixedSgLayout,
ArrayRef<OpFoldResult> mixedSgData,
ArrayRef<OpFoldResult> mixedInstData,
+ ArrayRef<int32_t> order,
ArrayRef<int64_t> sliceDims) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
@@ -229,6 +230,7 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
+ /*order=*/order,
/*slice_dims=*/sliceDims);
}
@@ -244,9 +246,9 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
Operation *target = *targetOps.begin();
xegpu::LayoutAttr layoutAttr = nullptr;
- auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
- getMixedSgLayout(), getMixedSgData(),
- getMixedInstData(), layoutAttr);
+ auto status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), getOrder(), layoutAttr);
if (!status.succeeded())
return status;
@@ -290,8 +292,8 @@ void transform::SetDescLayoutOp::getEffects(
void transform::SetOpLayoutAttrOp::build(
OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
- ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
- bool result, bool operand) {
+ ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order,
+ ArrayRef<int64_t> sliceDims, bool result, bool operand) {
SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -306,6 +308,7 @@ void transform::SetOpLayoutAttrOp::build(
/*static_sg_layout=*/staticSgLayout,
/*static_sg_data=*/staticSgData,
/*static_inst_data=*/staticInstData,
+ /*order=*/order,
/*slice_dims=*/sliceDims,
/*result=*/result,
/*operand=*/operand);
@@ -336,9 +339,9 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
}
xegpu::LayoutAttr layoutAttr = nullptr;
- auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
- getMixedSgLayout(), getMixedSgData(),
- getMixedInstData(), layoutAttr);
+ auto status = getLayoutAttrFromOperands(
+ getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
+ getMixedInstData(), getOrder(), layoutAttr);
if (!status.succeeded())
return status;
@@ -600,10 +603,10 @@ void transform::ConvertLayoutOp::build(
OpBuilder &builder, OperationState &ostate, Value target,
ArrayRef<OpFoldResult> mixedInputSgLayout,
ArrayRef<OpFoldResult> mixedInputSgData,
- ArrayRef<OpFoldResult> mixedInputInstData,
+ ArrayRef<OpFoldResult> mixedInputInstData, ArrayRef<int32_t> inputOrder,
ArrayRef<OpFoldResult> mixedTargetSgLayout,
ArrayRef<OpFoldResult> mixedTargetSgData,
- ArrayRef<OpFoldResult> mixedTargetInstData) {
+ ArrayRef<OpFoldResult> mixedTargetInstData, ArrayRef<int32_t> targetOrder) {
SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
staticInputInstData;
SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
@@ -632,12 +635,14 @@ void transform::ConvertLayoutOp::build(
/*target_sg_layout=*/dynamicTargetSgLayout,
/*target_sg_data=*/dynamicTargetSgData,
/*target_inst_data=*/dynamicTargetInstData,
+ /*input_order=*/inputOrder,
/*static_input_sg_layout=*/staticInputSgLayout,
/*static_input_sg_data=*/staticInputSgData,
/*static_input_inst_data=*/staticInputInstData,
/*static_target_sg_layout=*/staticTargetSgLayout,
/*static_target_sg_data=*/staticTargetSgData,
- /*static_target_inst_data=*/staticTargetInstData);
+ /*static_target_inst_data=*/staticTargetInstData,
+ /*target_order=*/targetOrder);
}
DiagnosedSilenceableFailure
@@ -655,14 +660,16 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
xegpu::LayoutAttr inputLayoutAttr = nullptr;
auto status = getLayoutAttrFromOperands(
getContext(), state, (*this), getMixedInputSgLayout(),
- getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
+ getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
+ inputLayoutAttr);
if (!status.succeeded())
return status;
xegpu::LayoutAttr targetLayoutAttr = nullptr;
status = getLayoutAttrFromOperands(
getContext(), state, (*this), getMixedTargetSgLayout(),
- getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
+ getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
+ targetLayoutAttr);
if (!status.succeeded())
return status;
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index a768ce5f4e720..782c9a3f242a0 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -62,6 +62,7 @@ def __init__(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ order: Optional[MixedInt] = None,
slice_dims: Optional[MixedInt] = None,
loc=None,
ip=None,
@@ -93,6 +94,7 @@ def __init__(
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
+ order=order,
slice_dims=slice_dims,
loc=loc,
ip=ip,
@@ -105,6 +107,7 @@ def set_desc_layout(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ order: Optional[MixedInt] = None,
slice_dims: Optional[MixedInt] = None,
loc=None,
ip=None,
@@ -114,6 +117,7 @@ def set_desc_layout(
sg_layout,
sg_data,
inst_data=inst_data,
+ order=order,
slice_dims=slice_dims,
loc=loc,
ip=ip,
@@ -131,6 +135,7 @@ def __init__(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ order: Optional[MixedInt] = None,
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
@@ -162,6 +167,7 @@ def __init__(
static_sg_layout=static_sg_layout,
static_sg_data=static_sg_data,
static_inst_data=static_inst_data,
+ order=order,
slice_dims=slice_dims,
index=index,
result=result,
@@ -177,6 +183,7 @@ def set_op_layout_attr(
sg_data: MixedValues,
*,
inst_data: Optional[MixedValues] = None,
+ order: Optional[MixedInt] = None,
slice_dims: Optional[MixedInt] = None,
index: Optional[Union[int, Attribute]] = None,
result: Optional[Union[bool, Attribute]] = None,
@@ -189,6 +196,7 @@ def set_op_layout_attr(
sg_layout,
sg_data,
inst_data=inst_data,
+ order=order,
slice_dims=slice_dims,
index=index,
result=result,
@@ -290,6 +298,8 @@ def __init__(
*,
input_inst_data: Optional[MixedValues] = None,
target_inst_data: Optional[MixedValues] = None,
+ input_order: Optional[MixedInt] = None,
+ target_order: Optional[MixedInt] = None,
loc=None,
ip=None,
):
@@ -334,12 +344,14 @@ def __init__(
dynamic_target_sg_layout,
dynamic_target_sg_data,
dynamic_target_inst_data,
+ input_order=input_order,
static_input_sg_layout=static_input_sg_layout,
static_input_sg_data=static_input_sg_data,
static_input_inst_data=static_input_inst_data,
static_target_sg_layout=static_target_sg_layout,
static_target_sg_data=static_target_sg_data,
static_target_inst_data=static_target_inst_data,
+ target_order=target_order,
loc=loc,
ip=ip,
)
@@ -354,6 +366,8 @@ def convert_layout(
*,
input_inst_data: Optional[MixedValues] = None,
target_inst_data: Optional[MixedValues] = None,
+ input_order: Optional[MixedInt] = None,
+ target_order: Optional[MixedInt] = None,
loc=None,
ip=None,
) -> ConvertLayoutOp:
@@ -365,6 +379,8 @@ def convert_layout(
target_sg_data,
input_inst_data=input_inst_data,
target_inst_data=target_inst_data,
+ input_order=input_order,
+ target_order=target_order,
loc=loc,
ip=ip,
).result
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 9a278cbf7b498..5bb1ab708e301 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -140,6 +140,26 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @set_desc_layout_order
+func.func @set_desc_layout_order(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+ // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
+ // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+ %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] order = [1, 0] : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @set_op_layout_attr_result_default
func.func @set_op_layout_attr_result_default(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -225,6 +245,25 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-LABEL: @set_op_layout_attr_result_order
+func.func @set_op_layout_attr_result_order(%arg0: vector<256xf16>) {
+ // CHECK: = arith.extf
+ // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [0, 1]>}
+ %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [0, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
// CHECK-LABEL: @set_op_layout_attr_operand_minimal
func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
%0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -287,6 +326,27 @@ module attributes {transform.with_named_sequence} {
}
}
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_order
+func.func @set_op_layout_attr_anchor_order(%arg0: memref<4096x4096xf16>) {
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ // CHECK: = xegpu.load_nd %0[0, 0]
+ // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [1, 0]>}>
+ %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+ transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
+ transform.yield
+ }
+}
+
+
// -----
// CHECK-LABEL: @set_op_layout_attr_anchor_dpas_a
@@ -515,12 +575,12 @@ module attributes {transform.with_named_sequence} {
func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
%c0 = arith.constant 0 : index
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], order = [1, 0]>>
// CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
- %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], order = [1, 0]>> -> vector<256x32xf16>
// CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
- // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
- // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+ // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], order = [1, 0]>
+ // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>
%2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
%3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
%4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
@@ -536,8 +596,8 @@ module attributes {transform.with_named_sequence} {
%1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
// CHECK: transform.xegpu.convert_layout %{{.*}}
transform.xegpu.convert_layout %1
- input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16]
- target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16]
+ input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16] input_order = [1, 0]
+ target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16] target_order = [1, 0]
: (!transform.any_value) -> !transform.any_op
transform.yield
}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index e3e1313cf5f81..346e68eca9201 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -85,6 +85,25 @@ def setDescLayoutSlice():
# CHECK: slice_dims = [0]
+@run
+def setDescLayoutOrder():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.create_nd_tdesc"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_desc_layout(
+ sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], order=[0, 1]
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setDescLayoutOrder
+ # CHECK: %0 = transform.xegpu.set_desc_layout %
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: order = [0, 1]
+
+
@run
def setOpLayoutAttrOperandMinimal():
sequence = transform.SequenceOp(
@@ -164,6 +183,34 @@ def setOpLayoutAttrResultSlice():
# CHECK: slice_dims = [0]
+@run
+def setOpLayoutAttrResultOrder():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_op_layout_attr(
+ sequence.bodyTarget,
+ index=0,
+ sg_layout=[6, 4],
+ sg_data=[32, 16],
+ inst_data=[8, 16],
+ order=[0, 1],
+ result=True,
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setOpLayoutAttrResultOrder
+ # CHECK: transform.xegpu.set_op_layout_attr %
+ # CHECK: result
+ # CHECK-NOT: index = 0
+ # CHECK: sg_layout = [6, 4]
+ # CHECK: sg_data = [32, 16]
+ # CHECK: inst_data = [8, 16]
+ # CHECK: order = [0, 1]
+
+
@run
def setOpLayoutAttrAnchor():
sequence = transform.SequenceOp(
@@ -309,9 +356,11 @@ def ConvertLayout():
input_sg_layout=[6, 4],
input_sg_data=[32, 32],
input_inst_data=[32, 16],
+ input_order=[1, 0],
target_sg_layout=[6, 4],
target_sg_data=[32, 32],
target_inst_data=[8, 16],
+ target_order=[0, 1],
)
transform.YieldOp()
# CHECK-LABEL: TEST: ConvertLayout
@@ -319,6 +368,8 @@ def ConvertLayout():
# CHECK: input_sg_layout = [6, 4]
# CHECK: input_sg_data = [32, 32]
# CHECK: input_inst_data = [32, 16]
+ # CHECK: input_order = [1, 0]
# CHECK: target_sg_layout = [6, 4]
# CHECK: target_sg_data = [32, 32]
# CHECK: target_inst_data = [8, 16]
+ # CHECK: target_order = [0, 1]
More information about the Mlir-commits
mailing list