[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add set_op_layout_attr op (PR #166854)

Tuomas Kärnä llvmlistbot at llvm.org
Mon Nov 10 03:13:37 PST 2025


https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/166854

From a03869960f28ea0bc6c1ed62785d57a26e08b555 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 09:23:17 +0200
Subject: [PATCH 1/2] [mlir][xegpu][transformops] add set_op_layout_attr op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  65 +++++++++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 125 +++++++++++++---
 mlir/python/mlir/dialects/transform/xegpu.py  |  47 ++++++
 .../Dialect/XeGPU/transform-ops-invalid.mlir  |  58 ++++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 134 ++++++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    |  49 +++++++
 6 files changed, 459 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..4e0eae1007c8f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,69 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
+def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set the xegpu.layout attribute of an op.";
+  let description = [{
+    Sets an `xegpu.layout` attribute on an op. If the `result` unit attribute
+    is set, the attribute is attached as `layout_result_{index}`, otherwise as
+    `layout_operand_{index}`. The `index` argument selects the target operand
+    or result (default 0). The layout
+    is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                   DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
+                   DefaultValuedAttr<UnitAttr, "false">:$result
+                   );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "int64_t":$index,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData,
+                   CArg<"bool", "false">:$result
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target (`result` $result^)? (`index` `=` $index^)?
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
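For illustration (this mirrors the tests added below), a minimal transform
script that annotates the sole result of a matched `xegpu.dpas` op reads:

  %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op

After interpretation, the payload op carries
`layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>`;
without the `result` keyword it would instead carry `layout_operand_0` on operand 0.
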
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..456cfb9ddd2bc 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -90,6 +90,38 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
       /*order=*/nullptr);
 }
 
+/// Build an `xegpu::LayoutAttr` from the op's mixed layout values.
+static DiagnosedSilenceableFailure
+getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
+                          transform::TransformState &state,
+                          TransformOpInterface transformOp,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
+                          ArrayRef<::mlir::OpFoldResult> mixedSgData,
+                          ArrayRef<::mlir::OpFoldResult> mixedInstData,
+                          xegpu::LayoutAttr &layoutAttr) {
+  SmallVector<int32_t> sgLayout, sgData, instData;
+  auto status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, mixedSgLayout);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, sgData, mixedSgData);
+  if (!status.succeeded())
+    return status;
+
+  status = convertMixedValuesToInt(state, transformOp, instData, mixedInstData);
+  if (!status.succeeded())
+    return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
+
+  layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 /// Replace xegpu.create_nd_desc op with a new one with the given layout.
 static xegpu::CreateNdDescOp
 setDescLayout(transform::TransformRewriter &rewriter,
@@ -142,26 +174,13 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
   Operation *target = *targetOps.begin();
 
-  SmallVector<int32_t> sgLayout;
-  DiagnosedSilenceableFailure status =
-      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
   if (!status.succeeded())
     return status;
 
-  SmallVector<int32_t> sgData;
-  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
-  if (!status.succeeded())
-    return status;
-
-  SmallVector<int32_t> instData;
-  status =
-      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
-  if (!status.succeeded())
-    return status;
-  auto maybeInstData = instData.empty()
-                           ? std::nullopt
-                           : std::optional<ArrayRef<int32_t>>(instData);
-
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
   if (!descOp) {
@@ -173,8 +192,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
@@ -193,6 +210,76 @@ void transform::SetDescLayoutOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::SetOpLayoutAttrOp::build(
+    OpBuilder &builder, OperationState &ostate, Value target, int64_t index,
+    ArrayRef<OpFoldResult> mixedSgLayout, ArrayRef<OpFoldResult> mixedSgData,
+    ArrayRef<OpFoldResult> mixedInstData, bool result) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*index=*/index,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData,
+        /*result=*/result);
+}
+
+DiagnosedSilenceableFailure
+transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
+                                    transform::TransformResults &results,
+                                    transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  bool resultTarget = getResult();
+
+  int64_t index = getIndex();
+  if (resultTarget && index >= target->getNumResults()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op results";
+  }
+  if (!resultTarget && index >= target->getNumOperands()) {
+    return emitSilenceableFailure(getLoc())
+           << "Index exceeds the number of op operands";
+  }
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Set layout attribute for the op result or operand.
+  if (resultTarget) {
+    xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
+  } else {
+    xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetOpLayoutAttrOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..46a1f032630d1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,50 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
+    """Specialization for SetOpLayoutAttrOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: MixedValues = None,
+        index: Union[int, Attribute] = None,
+        result: Union[bool, Attribute] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            _get_op_result_or_value(target),
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            index=index,
+            result=result,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..726b6748452ae 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,61 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_result_index
+func.func @set_op_layout_attr_bad_result_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{Index exceeds the number of op results}}
+    transform.xegpu.set_op_layout_attr %0 result index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_bad_operand_index
+func.func @set_op_layout_attr_bad_operand_index(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{Index exceeds the number of op operands}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_multiple
+func.func @set_op_layout_attr_multiple(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  %3 = arith.extf %2 : vector<256x32xf32> to vector<256x32xf64>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..089a8fb4fd9b6 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,137 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_default_index
+func.func @set_op_layout_attr_result_default_index(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[0, 0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param
+func.func @set_op_layout_attr_result_sg_param(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result_sg_param2
+func.func @set_op_layout_attr_result_sg_param2(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %layout1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_op_layout_attr %0 result sg_layout = [%layout0, %layout1] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op, !transform.param<i64>, !transform.param<i64>
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_result0
+func.func @set_op_layout_attr_result0(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 result index = 0 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand_minimal
+func.func @set_op_layout_attr_operand_minimal(%arg0: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.extf %1
+  // CHECK-SAME: {layout_operand_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64]>}
+  %2 = arith.extf %1 : vector<256x32xf16> to vector<256x32xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.extf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 sg_layout = [8, 4] sg_data = [32, 64] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_op_layout_attr_operand1
+func.func @set_op_layout_attr_operand1(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>) {
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %3 = xegpu.load_nd %2[0, 0]  : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // CHECK: = arith.addf %1, %3
+  // CHECK-SAME: {layout_operand_1 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], inst_data = [8, 16]>}
+  %6 = arith.addf %1, %3 : vector<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_op_layout_attr %{{.*}}
+    transform.xegpu.set_op_layout_attr %0 index = 1 sg_layout = [8, 4] sg_data = [32, 64] inst_data = [8, 16] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..0f48ef9dc529f 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -49,3 +49,52 @@ def setDescLayoutInstData():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def setOpLayoutAttrOperandMinimal():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetOpLayoutAttrOp(
+            sequence.bodyTarget,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrOperandMinimal
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # NO-CHECK: index = 0
+    # NO-CHECK: result
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # NO-CHECK: inst_data
+
+
+@run
+def setOpLayoutAttrResult():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetOpLayoutAttrOp(
+            sequence.bodyTarget,
+            index=0,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+            result=True,
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setOpLayoutAttrResult
+    # CHECK: transform.xegpu.set_op_layout_attr %
+    # NO-CHECK: index = 0
+    # CHECK: result
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]

From 9dddbbc6abe3c7b99d5681e5d77457f5e69169c9 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Mon, 10 Nov 2025 13:08:44 +0200
Subject: [PATCH 2/2] address review comments

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 20 +++++++++----------
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 16 ++++++---------
 mlir/python/mlir/dialects/transform/xegpu.py  |  6 +++---
 3 files changed, 19 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 4e0eae1007c8f..b3a905baff823 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -31,16 +31,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 
   let arguments = (ins
-                   TransformHandleTypeInterface : $target,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   TransformHandleTypeInterface:$target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
                    );
 
-  let results = (outs TransformHandleTypeInterface : $transformed);
+  let results = (outs TransformHandleTypeInterface:$transformed);
   let builders = [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -92,11 +92,11 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
     is defined by the `sg_layout`, `sg_data` and optional `inst_data` attributes.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface : $target,
-                   DefaultValuedOptionalAttr<I64Attr, "0"> : $index,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   DefaultValuedOptionalAttr<I64Attr, "0">:$index,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 456cfb9ddd2bc..4b8824ba743d6 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -92,8 +92,7 @@ createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
 
 /// Build an `xegpu::LayoutAttr` from the op's mixed layout values.
 static DiagnosedSilenceableFailure
-getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
-                          transform::TransformState &state,
+getLayoutAttrFromOperands(MLIRContext *ctx, transform::TransformState &state,
                           TransformOpInterface transformOp,
                           ArrayRef<::mlir::OpFoldResult> mixedSgLayout,
                           ArrayRef<::mlir::OpFoldResult> mixedSgData,
@@ -116,8 +115,7 @@ getLayoutAttrFromOperands(transform::TransformRewriter &rewriter,
                            ? std::nullopt
                            : std::optional<ArrayRef<int32_t>>(instData);
 
-  layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+  layoutAttr = createLayoutAttr(ctx, sgLayout, sgData, maybeInstData);
 
   return DiagnosedSilenceableFailure::success();
 }
@@ -175,7 +173,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   Operation *target = *targetOps.begin();
 
   xegpu::LayoutAttr layoutAttr = nullptr;
-  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                           getMixedSgLayout(), getMixedSgData(),
                                           getMixedInstData(), layoutAttr);
   if (!status.succeeded())
@@ -235,7 +233,6 @@ DiagnosedSilenceableFailure
 transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
-
   auto targetOps = state.getPayloadOps(getTarget());
   if (!llvm::hasSingleElement(targetOps)) {
     return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
@@ -256,18 +253,17 @@ transform::SetOpLayoutAttrOp::apply(transform::TransformRewriter &rewriter,
   }
 
   xegpu::LayoutAttr layoutAttr = nullptr;
-  auto status = getLayoutAttrFromOperands(rewriter, state, (*this),
+  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                           getMixedSgLayout(), getMixedSgData(),
                                           getMixedInstData(), layoutAttr);
   if (!status.succeeded())
     return status;
 
   // Set layout attribute for the op result or operand.
-  if (resultTarget) {
+  if (resultTarget)
     xegpu::setDistributeLayoutAttr(target->getResult(index), layoutAttr);
-  } else {
+  else
     xegpu::setDistributeLayoutAttr(target->getOpOperand(index), layoutAttr);
-  }
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 46a1f032630d1..ffec10ed6439f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -76,9 +76,9 @@ def __init__(
         sg_layout: MixedValues,
         sg_data: MixedValues,
         *,
-        inst_data: MixedValues = None,
-        index: Union[int, Attribute] = None,
-        result: Union[bool, Attribute] = None,
+        inst_data: Optional[MixedValues] = None,
+        index: Optional[Union[int, Attribute]] = None,
+        result: Optional[Union[bool, Attribute]] = None,
         loc=None,
         ip=None,
     ):


