[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add convert_layout op (PR #167342)

Tuomas Kärnä llvmlistbot at llvm.org
Mon Nov 10 08:49:20 PST 2025


https://github.com/tkarna created https://github.com/llvm/llvm-project/pull/167342

Adds a `transform.xegpu.convert_layout` transform op that inserts an `xegpu.convert_layout` op for a given `Value`.

For reference, the rationale behind xegpu transform ops is outlined in [this RFC document](https://github.com/tkarna/llvm-project/blob/xegpu-transform-ops-doc/mlir/docs/XeGPUTransformOps.md).

Unlike the `convert_operand_layout` op proposed in the RFC, which took an operation handle and an operand index, the `convert_layout` op takes a value handle. The operand value handle can be obtained with `transform.get_operand`:

```mlir
%tile_a = transform.get_operand %dpas_op[0] : (!transform.any_op) -> !transform.any_value
transform.xegpu.convert_layout %tile_a sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
```
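For illustration, here is a schematic sketch of the payload rewrite (layouts taken from the tests added below; the unrelated `dpas` operand types are elided with `...`, so this is not verbatim IR; `#load_layout` and `#target_layout` stand for the producer's and the requested layout):

```mlir
// Before: the annotated load feeds xegpu.dpas directly.
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #load_layout> -> vector<256x32xf16>
%6 = xegpu.dpas %1, %3, %5 : ...

// After: a convert_layout is inserted right after the load, and all other
// uses of %1 are redirected to its result.
%1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16, #load_layout> -> vector<256x32xf16>
%2 = xegpu.convert_layout %1 <{input_layout = #load_layout, target_layout = #target_layout}> : vector<256x32xf16>
%6 = xegpu.dpas %2, %3, %5 : ...
```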

From b54b8c6361247065bfe4d1e158395ac3d25e81eb Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 09:48:02 +0200
Subject: [PATCH] [mlir][xegpu][transformops] add convert_layout op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 62 +++++++++++++++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 79 +++++++++++++++++++
 mlir/python/mlir/dialects/transform/xegpu.py  | 43 ++++++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 63 +++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    | 44 +++++++++++
 5 files changed, 291 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..b33b0a6110b1e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,66 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
   }];
 }
 
+def ConvertLayoutOp : Op<Transform_Dialect, "xegpu.convert_layout", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Convert xegpu.layout attribute for a value.";
+  let description = [{
+    Adds an `xegpu.convert_layout` op to convert the `xegpu.layout` attribute
+    of a value. First, the `xegpu.load_nd` producer op of the value is found;
+    it must already be annotated with a layout. An `xegpu.convert_layout` op,
+    whose target layout is defined by the `sg_layout`, `sg_data` and
+    optional `inst_data` attributes, is inserted after the load op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   );
+
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
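Each entry of `sg_layout`, `sg_data` and `inst_data` may be either a static integer or a dynamic transform param/handle, courtesy of the `DynamicIndexList` custom directive. For example, mirroring the added `@convert_layout_a_sg_param` test:

```mlir
%layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
```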
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..45c76a7859a19 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -341,6 +341,85 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::ConvertLayoutOp::build(OpBuilder &builder,
+                                       OperationState &ostate, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::ConvertLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+
+  auto value = *targetValues.begin();
+
+  xegpu::LayoutAttr layoutAttr = nullptr;
+  auto status = getLayoutAttrFromOperands(getContext(), state, (*this),
+                                          getMixedSgLayout(), getMixedSgData(),
+                                          getMixedInstData(), layoutAttr);
+  if (!status.succeeded())
+    return status;
+
+  // Find the xegpu.load_nd op that produces the target value.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "could not find a load op";
+  }
+  auto loadOp = *maybeLoadOp;
+  // Get the layout of the load op's tensor descriptor operand.
+  auto producerLayoutAttr =
+      xegpu::getDistributeLayoutAttr(loadOp.getOperand(0));
+  if (!producerLayoutAttr) {
+    return emitSilenceableFailure(getLoc())
+           << "operand producer op does not have a layout attribute";
+  }
+
+  if (producerLayoutAttr != layoutAttr) {
+    rewriter.setInsertionPointAfter(loadOp.getOperation());
+    auto source = loadOp.getResult();
+    auto convLayoutOp = xegpu::ConvertLayoutOp::create(
+        rewriter, loadOp.getLoc(), source.getType(), source, producerLayoutAttr,
+        layoutAttr);
+    // Replace all other uses of the load result with the converted value.
+    rewriter.replaceUsesWithIf(
+        source, convLayoutOp.getResult(), [&](OpOperand &use) {
+          return use.getOwner() != convLayoutOp.getOperation();
+        });
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ConvertLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
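Note that the rewrite is guarded by the `producerLayoutAttr != layoutAttr` check: when the producer's layout already matches the requested one, the op is a no-op and the payload is left unchanged. A hypothetical sketch of that case (the requested layout equals the one already on the load's descriptor):

```mlir
// The load's tensor_desc already carries sg_layout = [8, 4],
// sg_data = [32, 32], inst_data = [32, 16]; requesting the same layout
// inserts no xegpu.convert_layout op.
transform.xegpu.convert_layout %tile_a sg_layout = [8, 4] sg_data = [32, 32] inst_data = [32, 16] : !transform.any_value
```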
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..6bf8ad3064be1 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -132,3 +132,46 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConvertLayoutOp(ConvertLayoutOp):
+    """Specialization for ConvertLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: Optional[MixedValues] = None,
+        loc=None,
+        ip=None,
+    ):
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+        super().__init__(
+            target,
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            loc=loc,
+            ip=ip,
+        )
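Assuming these bindings, the minimal invocation exercised by the Python test below (`ConvertLayoutMinimal`) would round-trip to transform IR along these lines (a sketch; SSA names are illustrative):

```mlir
%operand = transform.get_operand %arg0[0] : (!transform.op<"xegpu.dpas">) -> !transform.any_value
transform.xegpu.convert_layout %operand sg_layout = [6, 4] sg_data = [32, 16] : !transform.any_value
```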
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..2a914d7604ba9 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -252,3 +252,66 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a
+func.func @convert_layout_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V2]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    transform.xegpu.convert_layout %1 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @convert_layout_a_sg_param
+func.func @convert_layout_a_sg_param(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>>
+  // CHECK: %[[V1:.+]] = xegpu.load_nd %[[V0]]
+  %1 = xegpu.load_nd %0[%c0, %c0]  : !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>> -> vector<256x32xf16>
+  // CHECK: %[[V2:.+]] = xegpu.convert_layout %[[V1]]
+  // CHECK: input_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [32, 16]>
+  // CHECK: target_layout = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0]  : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0]  : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: = xegpu.dpas %[[V2]]
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.convert_layout %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    transform.xegpu.convert_layout %1 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : !transform.any_value, !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..4fda801e48964 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -113,3 +113,47 @@ def setOpLayoutAttrResult():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def ConvertLayoutMinimal():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        xegpu.ConvertLayoutOp(
+            operand,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: ConvertLayoutMinimal
+    # CHECK: transform.xegpu.convert_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+
+
+@run
+def ConvertLayout():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [1])
+        xegpu.ConvertLayoutOp(
+            operand,
+            sg_layout=[6, 4],
+            sg_data=[32, 16],
+            inst_data=[8, 16],
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: ConvertLayout
+    # CHECK: transform.xegpu.convert_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]


