[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add convert_layout op (PR #167342)

Tuomas Kärnä llvmlistbot at llvm.org
Tue Nov 11 04:38:17 PST 2025


https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/167342

From b54b8c6361247065bfe4d1e158395ac3d25e81eb Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 09:48:02 +0200
Subject: [PATCH 1/6] [mlir][xegpu][transformops] add convert_layout op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 62 +++++++++++++++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 79 +++++++++++++++++++
 mlir/python/mlir/dialects/transform/xegpu.py  | 43 ++++++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 63 +++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    | 44 +++++++++++
 5 files changed, 291 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..b33b0a6110b1e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,66 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
   }];
 }
 
+def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Convert xegpu.layout attribute for a value.";
+  let description = [{
+    Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
+    of a value. First, the `xegpu.load_nd` producer op of the value is found.
+    It must already be annotated with a layout. An `xegpu.convert_layout` op,
+    whose destination layout is defined by the `sg_layout`, `sg_data` and
+    optional `inst_data` attributes, is inserted after the load op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..45c76a7859a19 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -341,6 +341,85 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::ConvertLayoutOp::build(OpBuilder &builder,
+                                       OperationState &ostate, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+
+  auto value = *targetValues.begin();
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Get load op.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  // Get load op operand value layout
+  auto producerLayoutAttr =
+      xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
+  if (!producerLayoutAttr) {
+    return emitSilenceableFailure(getLoc())
+           << "Operand producer op does not have a layout attr.";
+  }
+
+  if (producerLayoutAttr != layoutAttr) {
+    rewriter.setInsertionPointAfter(loadOp.getOperation());
+    auto source = loadOp.getResult();
+    auto convLayoutOp = xegpu::ConvertLayoutOp::create(
+        rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
+        layoutAttr);
+    // Replace load op result with the converted layout.
+    rewriter.replaceUsesWithIf(
+        source, convLayoutOp.getResult(), [&](OpOperand &use) {
+          return use.getOwner() != convLayoutOp.getOperation();
+        });
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ConvertLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..6bf8ad3064be1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -132,3 +132,46 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConvertLayoutOp(ConvertLayoutOp):
+    """Specialization for ConvertLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: Optional[MixedValues] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            target,
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..2a914d7604ba9 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -252,3 +252,66 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a
+func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V2]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a_sg_param
+func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V2]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..4fda801e48964 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -113,3 +113,47 @@ def setOpLayoutAttrResult():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def ConvertLayoutMinimal():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.ConvertLayoutOp(
+            operand,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: ConvertLayoutMinimal
+    # CHECK: transform.xegpu.convert_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+
+
+@run
+def ConvertLayout():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
+        xegpu.ConvertLayoutOp(
+            operand,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: ConvertLayout
+    # CHECK: transform.xegpu.convert_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]

From cbac6d9cb6f7ffe65792b42baee89da5c4c278a6 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 11 Nov 2025 10:39:01 +0200
Subject: [PATCH 2/6] remove braces

---
 .../Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp   | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 45c76a7859a19..7494ed9b8f622 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -366,12 +366,10 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
   auto targetValues = state.getPayloadValues(getTarget());
-  if (!llvm::hasSingleElement(targetValues)) {
+  if (!llvm::hasSingleElement(targetValues))
     return emitDefiniteFailure()
            << "requires exactly one target value handle (got "
            << llvm::range_size(targetValues) << ")";
-  }
-
   auto value = *targetValues.begin();
 
   xegpu::LayoutAttr layoutAttr = nullptr;
@@ -383,17 +381,15 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
 
   // Get load op.
   auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
-  if (!maybeLoadOp) {
+  if (!maybeLoadOp)
     return emitSilenceableFailure(getLoc()) << "Could not find load op.";
-  }
   auto loadOp = *maybeLoadOp;
   // Get load op operand value layout
   auto producerLayoutAttr =
       xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
-  if (!producerLayoutAttr) {
+  if (!producerLayoutAttr)
     return emitSilenceableFailure(getLoc())
            << "Operand producer op does not have a layout attr.";
-  }
 
   if (producerLayoutAttr != layoutAttr) {
     rewriter.setInsertionPointAfter(loadOp.getOperation());

>From 10b4fd6bec15dbbb82f0d9b145aed2463aa77718 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 11 Nov 2025 10:46:10 +0200
Subject: [PATCH 3/6] add snake_case python binding wrappers

---
 mlir/python/mlir/dialects/transform/xegpu.py  | 70 +++++++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    | 14 ++--
 2 files changed, 77 insertions(+), 7 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 6bf8ad3064be1..fb2d0d307326c 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -41,6 +41,15 @@ def __init__(
         )
 
 
+def get_desc_op(
+    target: Value,
+    *,
+    loc=None,
+    ip=None,
+) -> GetDescOp:
+    return GetDescOp(target, loc=loc, ip=ip)
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class SetDescLayoutOp(SetDescLayoutOp):
     """Specialization for SetDescLayoutOp class."""
@@ -87,6 +96,25 @@ def __init__(
         )
 
 
+def set_desc_layout(
+    target: Union[Operation, Value],
+    sg_layout: MixedValues,
+    sg_data: MixedValues,
+    *,
+    inst_data: Optional[MixedValues] = None,
+    loc=None,
+    ip=None,
+) -> SetDescLayoutOp:
+    return SetDescLayoutOp(
+        target,
+        sg_layout,
+        sg_data,
+        inst_data=inst_data,
+        loc=loc,
+        ip=ip,
+    )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class SetOpLayoutAttrOp(SetOpLayoutAttrOp):
     """Specialization for SetOpLayoutAttrOp class."""
@@ -134,6 +162,29 @@ def __init__(
         )
 
 
+def set_op_layout_attr(
+    target: Union[Operation, Value],
+    sg_layout: MixedValues,
+    sg_data: MixedValues,
+    *,
+    inst_data: Optional[MixedValues] = None,
+    index: Optional[Union[int, Attribute]] = None,
+    result: Optional[Union[bool, Attribute]] = None,
+    loc=None,
+    ip=None,
+) -> SetOpLayoutAttrOp:
+    return SetOpLayoutAttrOp(
+        target,
+        sg_layout,
+        sg_data,
+        inst_data=inst_data,
+        index=index,
+        result=result,
+        loc=loc,
+        ip=ip,
+    )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class ConvertLayoutOp(ConvertLayoutOp):
     """Specialization for ConvertLayoutOp class."""
@@ -175,3 +226,22 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+def convert_layout(
+    target: Value,
+    sg_layout: MixedValues,
+    sg_data: MixedValues,
+    *,
+    inst_data: Optional[MixedValues] = None,
+    loc=None,
+    ip=None,
+) -> ConvertLayoutOp:
+    return ConvertLayoutOp(
+        target,
+        sg_layout,
+        sg_data,
+        inst_data=inst_data,
+        loc=loc,
+        ip=ip,
+    )
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 4fda801e48964..a578b6465aa74 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -25,7 +25,7 @@ def getDescOpDefaultIndex():
     )
     with InsertionPoint(sequence.body):
         operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
-        desc_handle = xegpu.GetDescOp(operand)
+        desc_handle = xegpu.get_desc_op(operand)
         transform.YieldOp()
     # CHECK-LABEL: TEST: getDescOpDefaultIndex
     # CHECK: transform.xegpu.get_desc_op %
@@ -39,7 +39,7 @@ def setDescLayoutMinimal():
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
+        xegpu.set_desc_layout(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
         transform.YieldOp()
     # CHECK-LABEL: TEST: setDescLayoutMinimal
     # CHECK: %0 = transform.xegpu.set_desc_layout %
@@ -55,7 +55,7 @@ def setDescLayoutInstData():
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(
+        xegpu.set_desc_layout(
             sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
         )
         transform.YieldOp()
@@ -74,7 +74,7 @@ def setOpLayoutAttrOperandMinimal():
         transform.OperationType.get("xegpu.dpas"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetOpLayoutAttrOp(
+        xegpu.set_op_layout_attr(
             sequence.bodyTarget,
             sg_layout=[6, 4],
             sg_data=[32, 16],
@@ -97,7 +97,7 @@ def setOpLayoutAttrResult():
         transform.OperationType.get("xegpu.dpas"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetOpLayoutAttrOp(
+        xegpu.set_op_layout_attr(
             sequence.bodyTarget,
             index=0,
             sg_layout=[6, 4],
@@ -124,7 +124,7 @@ def ConvertLayoutMinimal():
     )
     with InsertionPoint(sequence.body):
         operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
-        xegpu.ConvertLayoutOp(
+        xegpu.convert_layout(
             operand,
             sg_layout=[6, 4],
             sg_data=[32, 16],
@@ -145,7 +145,7 @@ def ConvertLayout():
     )
     with InsertionPoint(sequence.body):
         operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
-        xegpu.ConvertLayoutOp(
+        xegpu.convert_layout(
             operand,
             sg_layout=[6, 4],
             sg_data=[32, 16],

>From 2caf598edbe3779e0a9313c24c3f000fa8f82079 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 11 Nov 2025 12:34:05 +0200
Subject: [PATCH 4/6] generalize mechanism to find producer layout

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  9 +--
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 60 +++++++++++++------
 .../Dialect/XeGPU/transform-ops-invalid.mlir  | 21 +++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 30 +++++++++-
 4 files changed, 98 insertions(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b33b0a6110b1e..2549d333252c8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -170,10 +170,11 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
   let summary = "Convert xegpu.layout attribute for a value.";
   let description = [{
     Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
-    of a value. First, the `xegpu.load_nd` producer op of the value is found.
-    It must already be annotated with a layout. An `xegpu.convert_layout` op,
-    whose destination layout is defined by the `sg_layout`, `sg_data` and
-    optional `inst_data` attributes, is inserted after the load op.
+    of a value. The source layout is inferred by inspecting the producer ops. A
+    failure is emitted if source layout cannot be found. An
+    `xegpu.convert_layout` op, whose destination layout is defined by the
+    `sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
+    before the first use of the value.
   }];
 
   let arguments = (ins TransformValueHandleTypeInterface:$target,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 7494ed9b8f622..f701ab1c753db 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -118,6 +118,33 @@ static std::optional<T> findProducerOfType(Value val) {
   return findProducerOfType<T>(producerOp->getOperand(0));
 }
 
+/// Find layout attribute in producer chain.
+/// Traces producer ops until a layout attribute is found. Only traces through
+/// ops with a single operand, in other cases the op's result layout attribute
+/// must be set. Returns std::nullopt if no layout attribute is found.
+xegpu::LayoutAttr findProducerLayout(Value val) {
+  // Get layout attr from value or producer's attribute or operand.
+  if (auto layoutAttr = dyn_cast_if_present<xegpu::LayoutAttr>(
+          xegpu::getDistributeLayoutAttr(val)))
+    return layoutAttr;
+
+  // Recurse up the producer chain.
+  Operation *producerOp = val.getDefiningOp();
+  if (!producerOp) {
+    LDBG() << "Failed to find producer op.";
+    return nullptr;
+  }
+  if (producerOp->getNumOperands() == 0) {
+    LDBG() << "Producer has no operands.";
+    return nullptr;
+  }
+  if (producerOp->getNumOperands() > 1) {
+    LDBG() << "Producer has multiple operands.";
+    return nullptr;
+  }
+  return findProducerLayout(producerOp->getOperand(0));
+}
+
 /// Create a layout attribute from the given parameters.
 static xegpu::LayoutAttr
 createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -372,34 +399,33 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
            << llvm::range_size(targetValues) << ")";
   auto value = *targetValues.begin();
 
-  xegpu::LayoutAttr layoutAttr = nullptr;
+  xegpu::LayoutAttr targetLayoutAttr = nullptr;
   auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
                                           getMixedSgLayout(), getMixedSgData(),
-                                          getMixedInstData(), layoutAttr);
+                                          getMixedInstData(), targetLayoutAttr);
   if (!status.succeeded())
     return status;
 
-  // Get load op.
-  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
-  if (!maybeLoadOp)
-    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
-  auto loadOp = *maybeLoadOp;
-  // Get load op operand value layout
-  auto producerLayoutAttr =
-      xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
+  // Find source layout attribute from the producer chain.
+  auto producerLayoutAttr = findProducerLayout(value);
   if (!producerLayoutAttr)
     return emitSilenceableFailure(getLoc())
-           << "Operand producer op does not have a layout attr.";
+           << "Could not find a layout attribute in the producer chain.";
+
+  // Find first user op to define insertion point for layout conversion.
+  if (value.use_empty())
+    return emitSilenceableFailure(getLoc())
+           << "Value has no users to insert layout conversion.";
+  Operation *userOp = *value.getUsers().begin();
 
-  if (producerLayoutAttr != layoutAttr) {
-    rewriter.setInsertionPointAfter(loadOp.getOperation());
-    auto source = loadOp.getResult();
+  if (producerLayoutAttr != targetLayoutAttr) {
+    rewriter.setInsertionPoint(userOp);
     auto convLayoutOp = xegpu::ConvertLayoutOp::create(
-        rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
-        layoutAttr);
+        rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
+        targetLayoutAttr);
     // Replace load op result with the converted layout.
     rewriter.replaceUsesWithIf(
-        source, convLayoutOp.getResult(), [&](OpOperand &use) {
+        value, convLayoutOp.getResult(), [&](OpOperand &use) {
           return use.getOwner() != convLayoutOp.getOperation();
         });
   }
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 726b6748452ae..73cb74a701d3f 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -71,3 +71,24 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @convert_layout_no_producer_attr
+func.func @convert_layout_no_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
+  %c0 = arith.constant 0 : index
+  %0 = arith.addf %arg0, %arg1 : vector<32x32xf16>
+  %1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
+  %2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // expected-error at below {{Could not find a layout attribute in the producer chain.}}
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 2a914d7604ba9..bc8772e391bd7 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -309,9 +309,37 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
-    // CHECK: transform.xegpu.convert_layout %{{.*}}
     %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
     transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @convert_layout_producer_attr
+func.func @convert_layout_producer_attr(%arg0: vector<32x32xf16>, %arg1: vector<32x32xf16>) {
+  %c0 = arith.constant 0 : index
+  %0 = arith.addf %arg0, %arg1 {layout_result_0 =
+        #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>} :
+        vector<32x32xf16>
+  // CHECK: %[[V0:.+]] = arith.extf
+  %1 = arith.extf %0 : vector<32x32xf16> to vector<32x32xf32>
+  // CHECK: %[[V1:.+]] = xegpu.convert_layout %[[V0]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  // CHECK: %[[V0:.+]] = arith.truncf %[[V1]]
+  %2 = arith.truncf %1 : vector<32x32xf32> to vector<32x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.yield
+  }
+}

>From 3cbe6cb31d2b7829d3e4069793d7e697926e8a96 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 11 Nov 2025 14:09:36 +0200
Subject: [PATCH 5/6] snake_case python functions return op result value(s) if
 any

---
 mlir/python/mlir/dialects/transform/xegpu.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index fb2d0d307326c..3ca902922b850 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -46,8 +46,8 @@ def get_desc_op(
     *,
     loc=None,
     ip=None,
-) -> GetDescOp:
-    return GetDescOp(target, loc=loc, ip=ip)
+) -> OpResult:
+    return GetDescOp(target, loc=loc, ip=ip).result
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -104,7 +104,7 @@ def set_desc_layout(
     inst_data: Optional[MixedValues] = None,
     loc=None,
     ip=None,
-) -> SetDescLayoutOp:
+) -> OpResult:
     return SetDescLayoutOp(
         target,
         sg_layout,
@@ -112,7 +112,7 @@ def set_desc_layout(
         inst_data=inst_data,
         loc=loc,
         ip=ip,
-    )
+    ).result
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)

>From 20e59c53f56a2ffb2f08ec7ec9ea8181dc303631 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 11 Nov 2025 14:27:01 +0200
Subject: [PATCH 6/6] convert_layout transform op returns a handle to created
 convert op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  7 +++---
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 25 ++++++++++---------
 mlir/python/mlir/dialects/transform/xegpu.py  |  3 ++-
 .../Dialect/XeGPU/transform-ops-invalid.mlir  |  2 +-
 mlir/test/Dialect/XeGPU/transform-ops.mlir    |  6 ++---
 5 files changed, 23 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 2549d333252c8..53e9230a9a1f1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -174,7 +174,8 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
     failure is emitted if source layout cannot be found. An
     `xegpu.convert_layout` op, whose destination layout is defined by the
     `sg_layout`, `sg_data` and optional `inst_data` attributes, is emitted
-    before the first use of the value.
+    before the first use of the value. Returns a handle to the emitted
+    `xegpu.convert_layout` op.
   }];
 
   let arguments = (ins TransformValueHandleTypeInterface:$target,
@@ -186,7 +187,7 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
                    );
 
-  let results = (outs);
+  let results = (outs TransformHandleTypeInterface:$newConvertOp);
   let builders = [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -200,7 +201,7 @@ def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
     `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
-    attr-dict `:` qualified(type(operands))
+    attr-dict `:` functional-type(operands, results)
   }];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index f701ab1c753db..e2edc7d630fd3 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -418,18 +418,18 @@ transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
            << "Value has no users to insert layout conversion.";
   Operation *userOp = *value.getUsers().begin();
 
-  if (producerLayoutAttr != targetLayoutAttr) {
-    rewriter.setInsertionPoint(userOp);
-    auto convLayoutOp = xegpu::ConvertLayoutOp::create(
-        rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
-        targetLayoutAttr);
-    // Replace load op result with the converted layout.
-    rewriter.replaceUsesWithIf(
-        value, convLayoutOp.getResult(), [&](OpOperand &use) {
-          return use.getOwner() != convLayoutOp.getOperation();
-        });
-  }
-
+  // Emit convert_layout op.
+  rewriter.setInsertionPoint(userOp);
+  auto convLayoutOp = xegpu::ConvertLayoutOp::create(
+      rewriter, value.getLoc(), value.getType(), value, producerLayoutAttr,
+      targetLayoutAttr);
+  // Replace load op result with the converted layout.
+  rewriter.replaceUsesWithIf(
+      value, convLayoutOp.getResult(), [&](OpOperand &use) {
+        return use.getOwner() != convLayoutOp.getOperation();
+      });
+
+  results.set(llvm::cast<OpResult>(getResult()), {convLayoutOp});
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -439,6 +439,7 @@ void transform::ConvertLayoutOp::getEffects(
   onlyReadsHandle(getSgLayoutMutable(), effects);
   onlyReadsHandle(getSgDataMutable(), effects);
   onlyReadsHandle(getInstDataMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);
 }
 
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 3ca902922b850..c1e2fe7d9ee3e 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -216,6 +216,7 @@ def __init__(
             _,
         ) = _dispatch_dynamic_index_list(inst_data)
         super().__init__(
+            transform.AnyOpType.get(),
             target,
             dynamic_sg_layout,
             dynamic_sg_data,
@@ -244,4 +245,4 @@ def convert_layout(
         inst_data=inst_data,
         loc=loc,
         ip=ip,
-    )
+    ).result
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 73cb74a701d3f..45305f22e370b 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -88,7 +88,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
    // expected-error@below {{Could not find a layout attribute in the producer chain.}}
-    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bc8772e391bd7..f0e0b4d401035 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -279,7 +279,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
     // CHECK: transform.xegpu.convert_layout %{{.*}}
-    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
     transform.yield
   }
 }
@@ -311,7 +311,7 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
     %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
     // CHECK: transform.xegpu.convert_layout %{{.*}}
-    transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
+    transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
     transform.yield
   }
 }
@@ -339,7 +339,7 @@ module attributes {transform.with_named_sequence} {
     %0 = transform.structured.match ops{["arith.truncf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
     // CHECK: transform.xegpu.convert_layout %{{.*}}
-    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_value) -> !transform.any_op
     transform.yield
   }
 }



More information about the Mlir-commits mailing list