[Mlir-commits] [mlir] ce22796 - [mlir][xegpu] Add support for setting `order` in `SetDescLayoutOp` and `SetOpLayoutAttrOp` transform ops. (#184705)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 9 08:37:34 PDT 2026


Author: Charitha Saumya
Date: 2026-03-09T08:37:28-07:00
New Revision: ce227964cc4de126c43f3458498ac70315809ce8

URL: https://github.com/llvm/llvm-project/commit/ce227964cc4de126c43f3458498ac70315809ce8
DIFF: https://github.com/llvm/llvm-project/commit/ce227964cc4de126c43f3458498ac70315809ce8.diff

LOG: [mlir][xegpu] Add support for setting `order` in `SetDescLayoutOp` and `SetOpLayoutAttrOp` transform ops. (#184705)

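Currently, the XeGPU transform dialect does not allow the user to set the
`order` attribute of a layout in `SetDescLayoutOp`, `SetOpLayoutAttrOp`,
or `ConvertLayoutOp`. This PR adds `order` as an optional argument to
`SetDescLayoutOp` and `SetOpLayoutAttrOp`, and `input_order`/`target_order`
as optional arguments to `ConvertLayoutOp`, in both the C++ and Python APIs.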

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
    mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
    mlir/python/mlir/dialects/transform/xegpu.py
    mlir/test/Dialect/XeGPU/transform-ops.mlir
    mlir/test/python/dialects/transform_xegpu_ext.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 23dabe4eb380a..f7f45508b6a03 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -58,6 +58,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims
                    );
 
@@ -67,6 +68,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
                    "ArrayRef<OpFoldResult>":$mixedSgData,
                    "ArrayRef<OpFoldResult>":$mixedInstData,
+                   "ArrayRef<int32_t>":$order,
                    "ArrayRef<int64_t>":$sliceDims
                    )>,
   ];
@@ -76,6 +78,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
     `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    (`order` `=` $order^)?
     (`slice_dims` `=` $slice_dims^)?
     attr-dict `:` functional-type(operands, results)
   }];
@@ -127,6 +130,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$order,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$slice_dims,
                    DefaultValuedAttr<UnitAttr, "false">:$result,
                    DefaultValuedAttr<UnitAttr, "false">:$operand
@@ -139,6 +143,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
                    "ArrayRef<OpFoldResult>":$mixedSgData,
                    "ArrayRef<OpFoldResult>":$mixedInstData,
+                   "ArrayRef<int32_t>":$order,
                    "ArrayRef<int64_t>":$sliceDims,
                    CArg<"bool", "false">:$result,
                    CArg<"bool", "false">:$operand
@@ -150,6 +155,7 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
     `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    (`order` `=` $order^)?
     (`slice_dims` `=` $slice_dims^)?
     attr-dict `:` qualified(type(operands))
   }];
@@ -281,12 +287,14 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
                    Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_layout,
                    Variadic<TransformAnyParamTypeOrAnyHandle>:$target_sg_data,
                    Variadic<TransformAnyParamTypeOrAnyHandle>:$target_inst_data,
+                   DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$input_order,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_input_inst_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_sg_data,
-                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_target_inst_data,
+                   DefaultValuedOptionalAttr<DenseI32ArrayAttr, "{}">:$target_order
                    );
 
   let results = (outs TransformHandleTypeInterface:$newConvertOp);
@@ -295,9 +303,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
                    "ArrayRef<OpFoldResult>":$mixedInputSgLayout,
                    "ArrayRef<OpFoldResult>":$mixedInputSgData,
                    "ArrayRef<OpFoldResult>":$mixedInputInstData,
+                   "ArrayRef<int32_t>":$inputOrder,
                    "ArrayRef<OpFoldResult>":$mixedTargetSgLayout,
                    "ArrayRef<OpFoldResult>":$mixedTargetSgData,
-                   "ArrayRef<OpFoldResult>":$mixedTargetInstData
+                   "ArrayRef<OpFoldResult>":$mixedTargetInstData,
+                   "ArrayRef<int32_t>":$targetOrder
                    )>,
   ];
 
@@ -306,9 +316,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
     `input_sg_layout` `=` custom<DynamicIndexList>($input_sg_layout, $static_input_sg_layout)
     `input_sg_data` `=` custom<DynamicIndexList>($input_sg_data, $static_input_sg_data)
     (`input_inst_data` `=` custom<DynamicIndexList>($input_inst_data, $static_input_inst_data)^)?
+    (`input_order` `=` $input_order^)?
     `target_sg_layout` `=` custom<DynamicIndexList>($target_sg_layout, $static_target_sg_layout)
     `target_sg_data` `=` custom<DynamicIndexList>($target_sg_data, $static_target_sg_data)
     (`target_inst_data` `=` custom<DynamicIndexList>($target_inst_data, $static_target_inst_data)^)?
+    (`target_order` `=` $target_order^)?
     attr-dict `:` functional-type(operands, results)
   }];
 

diff  --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 7bc67da8263dc..39f9ae0bf1287 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -122,17 +122,16 @@ static std::optional<T> findProducerOfType(Value val) {
 }
 
 /// Create a layout attribute from the given parameters.
-static xegpu::LayoutAttr
-createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
-                 ArrayRef<int32_t> sgData,
-                 std::optional<ArrayRef<int32_t>> instData) {
+static xegpu::LayoutAttr createLayoutAttr(
+    MLIRContext *ctx, ArrayRef<int32_t> sgLayout, ArrayRef<int32_t> sgData,
+    std::optional<ArrayRef<int32_t>> instData, ArrayRef<int32_t> order) {
   return xegpu::LayoutAttr::get(
       ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
       DenseI32ArrayAttr::get(ctx, sgData),
       instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
       /*lane_layout=*/nullptr,
       /*lane_data=*/nullptr,
-      /*order=*/nullptr);
+      /*order=*/order.empty() ? nullptr : DenseI32ArrayAttr::get(ctx, order));
 }
 
 /// Generate `xegpu::LayoutAttr` from op mixed layout values.
@@ -142,6 +141,7 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
                           ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
                           ArrayRef<::mlir::OpFoldResult> mixedSgData,
                           ArrayRef<::mlir::OpFoldResult> mixedInstData,
+                          ArrayRef<int32_t> order,
                           xegpu::LayoutAttr &layoutAttr) {
   SmallVector<int32_t> sgLayout, sgData, instData;
   auto status =
@@ -160,7 +160,7 @@ getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
                            ? std::nullopt
                            : std::optional<ArrayRef<int32_t>>(instData);
 
-  layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
+  layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData, order);
 
   return DiagnosedSilenceableFailure::success();
 }
@@ -215,6 +215,7 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                        ArrayRef<OpFoldResult> mixedSgLayout,
                                        ArrayRef<OpFoldResult> mixedSgData,
                                        ArrayRef<OpFoldResult> mixedInstData,
+                                       ArrayRef<int32_t> order,
                                        ArrayRef<int64_t> sliceDims) {
   SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
   SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
@@ -229,6 +230,7 @@ void transform::SetDescLayoutOp::build(OpBuilder &builder,
         /*static_sg_layout=*/staticSgLayout,
         /*static_sg_data=*/staticSgData,
         /*static_inst_data=*/staticInstData,
+        /*order=*/order,
         /*slice_dims=*/sliceDims);
 }
 
@@ -244,9 +246,9 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   Operation *target = *targetOps.begin();
 
   xegpu::LayoutAttr layoutAttr = nullptr;
-  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
-                                          getMixedSgLayout(), getMixedSgData(),
-                                          getMixedInstData(), layoutAttr);
+  auto status = getLayoutAttrFromOperands(
+      getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
+      getMixedInstData(), getOrder(), layoutAttr);
   if (!status.succeeded())
     return status;
 
@@ -290,8 +292,8 @@ void transform::SetDescLayoutOp::getEffects(
 void transform::SetOpLayoutAttrOp::build(
     OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
     ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
-    ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int64_t> sliceDims,
-    bool result, bool operand) {
+    ArrayRef<OpFoldResult> mixedInstData, ArrayRef<int32_t> order,
+    ArrayRef<int64_t> sliceDims, bool result, bool operand) {
   SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
   SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
   dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
@@ -306,6 +308,7 @@ void transform::SetOpLayoutAttrOp::build(
         /*static_sg_layout=*/staticSgLayout,
         /*static_sg_data=*/staticSgData,
         /*static_inst_data=*/staticInstData,
+        /*order=*/order,
         /*slice_dims=*/sliceDims,
         /*result=*/result,
         /*operand=*/operand);
@@ -336,9 +339,9 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
   }
 
   xegpu::LayoutAttr layoutAttr = nullptr;
-  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
-                                          getMixedSgLayout(), getMixedSgData(),
-                                          getMixedInstData(), layoutAttr);
+  auto status = getLayoutAttrFromOperands(
+      getContext(), state, (*this), getMixedSgLayout(), getMixedSgData(),
+      getMixedInstData(), getOrder(), layoutAttr);
   if (!status.succeeded())
     return status;
 
@@ -600,10 +603,10 @@ void transform::ConvertLayoutOp::build(
     OpBuilder &builder, OperationState &ostate, Value target,
     ArrayRef<OpFoldResult> mixedInputSgLayout,
     ArrayRef<OpFoldResult> mixedInputSgData,
-    ArrayRef<OpFoldResult> mixedInputInstData,
+    ArrayRef<OpFoldResult> mixedInputInstData, ArrayRef<int32_t> inputOrder,
     ArrayRef<OpFoldResult> mixedTargetSgLayout,
     ArrayRef<OpFoldResult> mixedTargetSgData,
-    ArrayRef<OpFoldResult> mixedTargetInstData) {
+    ArrayRef<OpFoldResult> mixedTargetInstData, ArrayRef<int32_t> targetOrder) {
   SmallVector<int64_t> staticInputSgLayout, staticInputSgData,
       staticInputInstData;
   SmallVector<Value> dynamicInputSgLayout, dynamicInputSgData,
@@ -632,12 +635,14 @@ void transform::ConvertLayoutOp::build(
         /*target_sg_layout=*/dynamicTargetSgLayout,
         /*target_sg_data=*/dynamicTargetSgData,
         /*target_inst_data=*/dynamicTargetInstData,
+        /*input_order=*/inputOrder,
         /*static_input_sg_layout=*/staticInputSgLayout,
         /*static_input_sg_data=*/staticInputSgData,
         /*static_input_inst_data=*/staticInputInstData,
         /*static_target_sg_layout=*/staticTargetSgLayout,
         /*static_target_sg_data=*/staticTargetSgData,
-        /*static_target_inst_data=*/staticTargetInstData);
+        /*static_target_inst_data=*/staticTargetInstData,
+        /*target_order=*/targetOrder);
 }
 
 DiagnosedSilenceableFailure
@@ -655,14 +660,16 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
   xegpu::LayoutAttr inputLayoutAttr = nullptr;
   auto status = getLayoutAttrFromOperands(
       getContext(), state, (*this), getMixedInputSgLayout(),
-      getMixedInputSgData(), getMixedInputInstData(), inputLayoutAttr);
+      getMixedInputSgData(), getMixedInputInstData(), getInputOrder(),
+      inputLayoutAttr);
   if (!status.succeeded())
     return status;
 
   xegpu::LayoutAttr targetLayoutAttr = nullptr;
   status = getLayoutAttrFromOperands(
       getContext(), state, (*this), getMixedTargetSgLayout(),
-      getMixedTargetSgData(), getMixedTargetInstData(), targetLayoutAttr);
+      getMixedTargetSgData(), getMixedTargetInstData(), getTargetOrder(),
+      targetLayoutAttr);
   if (!status.succeeded())
     return status;
 

diff  --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index a768ce5f4e720..782c9a3f242a0 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -62,6 +62,7 @@ def __init__(
         sg_data: MixedValues,
         *,
         inst_data: Optional[MixedValues] = None,
+        order: Optional[MixedInt] = None,
         slice_dims: Optional[MixedInt] = None,
         loc=None,
         ip=None,
@@ -93,6 +94,7 @@ def __init__(
             static_sg_layout=static_sg_layout,
             static_sg_data=static_sg_data,
             static_inst_data=static_inst_data,
+            order=order,
             slice_dims=slice_dims,
             loc=loc,
             ip=ip,
@@ -105,6 +107,7 @@ def set_desc_layout(
     sg_data: MixedValues,
     *,
     inst_data: Optional[MixedValues] = None,
+    order: Optional[MixedInt] = None,
     slice_dims: Optional[MixedInt] = None,
     loc=None,
     ip=None,
@@ -114,6 +117,7 @@ def set_desc_layout(
         sg_layout,
         sg_data,
         inst_data=inst_data,
+        order=order,
         slice_dims=slice_dims,
         loc=loc,
         ip=ip,
@@ -131,6 +135,7 @@ def __init__(
         sg_data: MixedValues,
         *,
         inst_data: Optional[MixedValues] = None,
+        order: Optional[MixedInt] = None,
         slice_dims: Optional[MixedInt] = None,
         index: Optional[Union[int, Attribute]] = None,
         result: Optional[Union[bool, Attribute]] = None,
@@ -162,6 +167,7 @@ def __init__(
             static_sg_layout=static_sg_layout,
             static_sg_data=static_sg_data,
             static_inst_data=static_inst_data,
+            order=order,
             slice_dims=slice_dims,
             index=index,
             result=result,
@@ -177,6 +183,7 @@ def set_op_layout_attr(
     sg_data: MixedValues,
     *,
     inst_data: Optional[MixedValues] = None,
+    order: Optional[MixedInt] = None,
     slice_dims: Optional[MixedInt] = None,
     index: Optional[Union[int, Attribute]] = None,
     result: Optional[Union[bool, Attribute]] = None,
@@ -189,6 +196,7 @@ def set_op_layout_attr(
         sg_layout,
         sg_data,
         inst_data=inst_data,
+        order=order,
         slice_dims=slice_dims,
         index=index,
         result=result,
@@ -290,6 +298,8 @@ def __init__(
         *,
         input_inst_data: Optional[MixedValues] = None,
         target_inst_data: Optional[MixedValues] = None,
+        input_order: Optional[MixedInt] = None,
+        target_order: Optional[MixedInt] = None,
         loc=None,
         ip=None,
     ):
@@ -334,12 +344,14 @@ def __init__(
             dynamic_target_sg_layout,
             dynamic_target_sg_data,
             dynamic_target_inst_data,
+            input_order=input_order,
             static_input_sg_layout=static_input_sg_layout,
             static_input_sg_data=static_input_sg_data,
             static_input_inst_data=static_input_inst_data,
             static_target_sg_layout=static_target_sg_layout,
             static_target_sg_data=static_target_sg_data,
             static_target_inst_data=static_target_inst_data,
+            target_order=target_order,
             loc=loc,
             ip=ip,
         )
@@ -354,6 +366,8 @@ def convert_layout(
     *,
     input_inst_data: Optional[MixedValues] = None,
     target_inst_data: Optional[MixedValues] = None,
+    input_order: Optional[MixedInt] = None,
+    target_order: Optional[MixedInt] = None,
     loc=None,
     ip=None,
 ) -> ConvertLayoutOp:
@@ -365,6 +379,8 @@ def convert_layout(
         target_sg_data,
         input_inst_data=input_inst_data,
         target_inst_data=target_inst_data,
+        input_order=input_order,
+        target_order=target_order,
         loc=loc,
         ip=ip,
     ).result

diff  --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 9a278cbf7b498..5bb1ab708e301 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -140,6 +140,26 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @set_desc_layout_order
+func.func @set_desc_layout_order(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] order = [1, 0] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_op_layout_attr_result_default
 func.func @set_op_layout_attr_result_default(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -225,6 +245,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @set_op_layout_attr_result_order
+func.func @set_op_layout_attr_result_order(%arg0: vector<256xf16>) {
+  // CHECK: = arith.extf
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [0, 1]>}
+  %2 = arith.extf %arg0 : vector<256xf16> to vector<256xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [0, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_op_layout_attr_operand_minimal
 func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
@@ -287,6 +326,27 @@ module attributes {transform.with_named_sequence} {
   }
 }
 
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_anchor_order
+func.func @set_op_layout_attr_anchor_order(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  // CHECK: = xegpu.load_nd %0[0, 0]
+  // CHECK-SAME: <{layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16], order = [1, 0]>}>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.load_nd"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] order = [1, 0] : !transform.any_op
+    transform.yield
+  }
+}
+
+
 // -----
 
 // CHECK-LABEL: @set_op_layout_attr_anchor_dpas_a
@@ -515,12 +575,12 @@ module attributes {transform.with_named_sequence} {
 func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
   %c0 = arith.constant 0 : index
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
-  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], order = [1, 0]>>
   // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
-  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], order = [1, 0]>> -> vector<256x32xf16>
   // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
-  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
-  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16], order = [1, 0]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16], order = [1, 0]>
   %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
   %3 = xegpu.load_nd %2[%c0, %c0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
   %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
@@ -536,8 +596,8 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
     // CHECK: transform.xegpu.convert_layout %{{.*}}
     transform.xegpu.convert_layout %1
-      input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16]
-      target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16]
+      input_sg_layout = [8, 4] input_sg_data = [32, 32] input_inst_data = [32, 16] input_order = [1, 0]
+      target_sg_layout = [8, 4] target_sg_data = [32, 32] target_inst_data = [8, 16] target_order = [1, 0]
       : (!transform.any_value) -> !transform.any_op
     transform.yield
   }

diff  --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index e3e1313cf5f81..346e68eca9201 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -85,6 +85,25 @@ def setDescLayoutSlice():
     # CHECK: slice_dims = [0]
 
 
+@run
+def setDescLayoutOrder():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.set_desc_layout(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], order=[0, 1]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutOrder
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: order = [0, 1]
+
+
 @run
 def setOpLayoutAttrOperandMinimal():
     sequence = transform.SequenceOp(
@@ -164,6 +183,34 @@ def setOpLayoutAttrResultSlice():
     # CHECK: slice_dims = [0]
 
 
+@run
+def setOpLayoutAttrResultOrder():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.set_op_layout_attr(
+            sequence.bodyTarget,
+            index=0,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+            order=[0, 1],
+            result=True,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrResultOrder
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # CHECK: result
+    # CHECK-NOT: index = 0
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+    # CHECK: order = [0, 1]
+
+
 @run
 def setOpLayoutAttrAnchor():
     sequence = transform.SequenceOp(
@@ -309,9 +356,11 @@ def ConvertLayout():
             input_sg_layout=[6, 4],
             input_sg_data=[32, 32],
             input_inst_data=[32, 16],
+            input_order=[1, 0],
             target_sg_layout=[6, 4],
             target_sg_data=[32, 32],
             target_inst_data=[8, 16],
+            target_order=[0, 1],
         )
         transform.YieldOp()
     # CHECK-LABEL: TEST: ConvertLayout
@@ -319,6 +368,8 @@ def ConvertLayout():
     # CHECK: input_sg_layout = [6, 4]
     # CHECK: input_sg_data = [32, 32]
     # CHECK: input_inst_data = [32, 16]
+    # CHECK: input_order = [1, 0]
     # CHECK: target_sg_layout = [6, 4]
     # CHECK: target_sg_data = [32, 32]
     # CHECK: target_inst_data = [8, 16]
+    # CHECK: target_order = [0, 1]


        


More information about the Mlir-commits mailing list