[Mlir-commits] [mlir] [MLIR][XeGPU] add xegpu.set_desc_layout transform op (PR #165615)

Tuomas Kärnä llvmlistbot at llvm.org
Thu Oct 30 10:39:15 PDT 2025


https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/165615

>From 1bbe829c40e38343fff440980e144a62c12cb267 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 1 Jul 2025 18:58:52 +0300
Subject: [PATCH 1/3] [mlir][xegpu] add xegpu.set_desc_layout transform op

---
 .../include/mlir/Dialect/XeGPU/CMakeLists.txt |   1 +
 .../Dialect/XeGPU/TransformOps/CMakeLists.txt |   6 +
 .../XeGPU/TransformOps/XeGPUTransformOps.h    |  28 ++
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  85 ++++++
 mlir/lib/Dialect/XeGPU/CMakeLists.txt         |   1 +
 .../Dialect/XeGPU/TransformOps/CMakeLists.txt |  17 ++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 250 ++++++++++++++++++
 mlir/lib/RegisterAllExtensions.cpp            |   2 +
 mlir/python/CMakeLists.txt                    |   9 +
 .../python/mlir/dialects/XeGPUTransformOps.td |  19 ++
 mlir/python/mlir/dialects/transform/xegpu.py  |  65 +++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    |  38 +++
 .../python/dialects/transform_xegpu_ext.py    |  54 ++++
 13 files changed, 575 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
 create mode 100644 mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
 create mode 100644 mlir/python/mlir/dialects/XeGPUTransformOps.td
 create mode 100644 mlir/python/mlir/dialects/transform/xegpu.py
 create mode 100644 mlir/test/Dialect/XeGPU/transform-ops.mlir
 create mode 100644 mlir/test/python/dialects/transform_xegpu_ext.py

diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..5924606402a02
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+# Generate the C++ op declarations/definitions for the XeGPU transform ops
+# from XeGPUTransformOps.td, plus the corresponding op documentation.
+set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
+mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
+
+add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
new file mode 100644
index 0000000000000..dab0c3f35adda
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -0,0 +1,28 @@
+//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
+
+namespace mlir {
+class DialectRegistry;
+
+namespace xegpu {
+/// Adds the transform dialect extension that provides the XeGPU transform ops
+/// (e.g. `transform.xegpu.set_desc_layout`) to `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..681b4861f0aeb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -0,0 +1,85 @@
+//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef XEGPU_EXTENSION
+#define XEGPU_EXTENSION
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+// Type constraint for the dynamic layout operands: a value is accepted if it
+// is either a transform handle or a transform param. Either kind is resolved
+// to an integer when the op is applied.
+def TransformAnyParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
+
+def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set the xegpu.layout attribute on a xegpu.create_nd_tdesc result.";
+  let description = [{
+    Given an `xegpu.create_nd_tdesc` operation, this transform replaces it with
+    an equivalent op whose result tensor descriptor type carries an
+    `xegpu.layout` attribute. The layout is defined by the `sg_layout`,
+    `sg_data` and `inst_data` lists; each entry may be a static integer, a
+    transform param, or a handle to an op producing a constant index.
+
+    The `target` handle must be mapped to exactly one op and is consumed.
+    Returns a handle to the new op. Produces a silenceable failure if the
+    target is not an `xegpu.create_nd_tdesc` op.
+  }];
+
+  let arguments = (ins
+                   TransformHandleTypeInterface : $target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   );
+
+  let results = (outs TransformHandleTypeInterface : $transformed);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
+#endif // XEGPU_EXTENSION
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..46b8251a57797 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..48fe841afaa83
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+# XeGPU transform dialect extension. Link every dialect whose C++ classes the
+# extension references (including the dialects declared as generated in
+# XeGPUTransformDialectExtension::init) rather than relying on transitive deps.
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+  XeGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/XeGPU/TransformOps
+
+  DEPENDS
+  MLIRXeGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRArithDialect
+  MLIRGPUDialect
+  MLIRXeGPUDialect
+  MLIRXeGPUTransforms
+  MLIRIR
+  MLIRTransformDialect
+  MLIRFuncDialect
+  MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000000000..1875f1050eb03
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,250 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <numeric>
+
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+namespace {
+/// Transform dialect extension that registers the XeGPU transform ops.
+/// Kept in an anonymous namespace: the class is only used in this file.
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+} // namespace
+
+void XeGPUTransformDialectExtension::init() {
+  // Dialects whose ops the transforms may generate in the payload IR.
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<gpu::GPUDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}
+
+/// Extracts the integer carried by a single mixed value `ofr`. `ofr` must be
+/// an IntegerAttr, a transform param associated with exactly one IntegerAttr,
+/// or a transform handle mapped to exactly one payload op whose single index
+/// result is produced by a constant-like op. On success the value is stored in
+/// `value`; otherwise a (definite or silenceable) failure is returned.
+static DiagnosedSilenceableFailure
+getIntFromMixedValue(transform::TransformState &state,
+                     TransformOpInterface transformOp, OpFoldResult ofr,
+                     int64_t &value) {
+  // Attribute case.
+  if (auto attr = dyn_cast<Attribute>(ofr)) {
+    auto intAttr = dyn_cast<IntegerAttr>(attr);
+    if (!intAttr)
+      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+    value = intAttr.getInt();
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Transform param case.
+  Value transformValue = cast<Value>(ofr);
+  if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+    ArrayRef<Attribute> params = state.getParams(transformValue);
+    if (params.size() != 1)
+      return transformOp.emitDefiniteFailure()
+             << "requires exactly one parameter associated";
+    // A param may carry any attribute; diagnose non-integers instead of
+    // asserting inside cast<>.
+    auto intAttr = dyn_cast<IntegerAttr>(params.front());
+    if (!intAttr)
+      return transformOp.emitSilenceableError()
+             << "expected parameter to be an IntegerAttr, got "
+             << params.front();
+    value = intAttr.getValue().getSExtValue();
+    return DiagnosedSilenceableFailure::success();
+  }
+
+  // Payload value case.
+  auto payloadOps = state.getPayloadOps(transformValue);
+  if (!llvm::hasSingleElement(payloadOps)) {
+    DiagnosedSilenceableFailure diag =
+        transformOp.emitSilenceableError()
+        << "handle must be mapped to exactly one payload op";
+    diag.attachNote(transformValue.getLoc())
+        << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+    return diag;
+  }
+
+  Operation *op = *payloadOps.begin();
+  if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+    DiagnosedSilenceableFailure diag =
+        transformOp.emitSilenceableError()
+        << "payload op must have exactly 1 index result";
+    diag.attachNote(op->getLoc())
+        << "has " << op->getNumResults() << " results";
+    return diag;
+  }
+
+  IntegerAttr intAttr;
+  if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+    return transformOp.emitSilenceableError()
+           << "requires param or handle to be the result of a constant like "
+              "op";
+
+  value = intAttr.getInt();
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Resolves every entry of `ofrs` to an integer (see getIntFromMixedValue)
+/// and appends it to `result`. Stops at the first entry that cannot be
+/// resolved or whose value does not fit in int32_t.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    int64_t value = 0;
+    DiagnosedSilenceableFailure status =
+        getIntFromMixedValue(state, transformOp, ofr, value);
+    if (!status.succeeded())
+      return status;
+    // The layout arrays are DenseI32ArrayAttrs; refuse values that would be
+    // silently truncated when narrowed to int32_t.
+    if (static_cast<int64_t>(static_cast<int32_t>(value)) != value)
+      return transformOp.emitSilenceableError()
+             << "layout value " << value << " does not fit in 32 bits";
+    result.push_back(static_cast<int32_t>(value));
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+/// `sgLayout` is always encoded. An empty `sgData`, or an absent or empty
+/// `instData`, is treated as "not specified" and encoded as a null field,
+/// the same way lane_layout/lane_data/order are left unset below. A
+/// zero-length array could never match the rank of `sgLayout`.
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                 ArrayRef<int32_t> sgData,
+                 std::optional<ArrayRef<int32_t>> instData) {
+  auto toOptionalAttr = [ctx](ArrayRef<int32_t> values) -> DenseI32ArrayAttr {
+    if (values.empty())
+      return nullptr;
+    return DenseI32ArrayAttr::get(ctx, values);
+  };
+  return xegpu::LayoutAttr::get(
+      ctx, DenseI32ArrayAttr::get(ctx, sgLayout), toOptionalAttr(sgData),
+      instData ? toOptionalAttr(*instData) : nullptr,
+      /*lane_layout=*/nullptr,
+      /*lane_data=*/nullptr,
+      /*order=*/nullptr);
+}
+
+/// Replace an xegpu.create_nd_tdesc op with a new one whose result tensor
+/// descriptor type carries `layout`. Shape, element type, array length,
+/// boundary check and memory space are copied from the original type. The
+/// source, offsets, sizes and strides are forwarded unchanged.
+/// Returns the replacement op.
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+              xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+  auto oldDescType =
+      cast<xegpu::TensorDescType>(descOp.getResult().getType());
+  // Only the layout may change; everything else is preserved so the transform
+  // does not silently alter the descriptor's semantics.
+  auto newDescType = xegpu::TensorDescType::get(
+      oldDescType.getShape(), oldDescType.getElementType(),
+      /*array_length=*/oldDescType.getArrayLength(),
+      /*boundary_check=*/oldDescType.getBoundaryCheck(),
+      /*memory_space=*/oldDescType.getMemorySpace(),
+      /*layout=*/layout);
+
+  rewriter.setInsertionPointAfter(descOp);
+  // Offsets are optional on the op; pick the builder matching the original
+  // operand structure.
+  if (!descOp.getMixedOffsets().empty())
+    return rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+        descOp, newDescType, descOp.getSource(), descOp.getMixedOffsets(),
+        descOp.getMixedSizes(), descOp.getMixedStrides());
+  return rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+      descOp, newDescType, descOp.getSource(), descOp.getMixedSizes(),
+      descOp.getMixedStrides());
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+                                       OperationState &result, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  // Split every mixed list into its dynamic SSA operands and its static
+  // integer array, then forward both halves to the generated builder.
+  SmallVector<Value> sgLayoutOperands;
+  SmallVector<int64_t> sgLayoutStatic;
+  dispatchIndexOpFoldResults(mixedSgLayout, sgLayoutOperands, sgLayoutStatic);
+
+  SmallVector<Value> sgDataOperands;
+  SmallVector<int64_t> sgDataStatic;
+  dispatchIndexOpFoldResults(mixedSgData, sgDataOperands, sgDataStatic);
+
+  SmallVector<Value> instDataOperands;
+  SmallVector<int64_t> instDataStatic;
+  dispatchIndexOpFoldResults(mixedInstData, instDataOperands, instDataStatic);
+
+  // The produced handle has the same type as the target handle.
+  Type resultType = target.getType();
+  build(builder, result, resultType, target, sgLayoutOperands, sgDataOperands,
+        instDataOperands, sgLayoutStatic, sgDataStatic, instDataStatic);
+}
+
+/// Resolves the layout lists to integers, checks them against the rank of
+/// the target xegpu.create_nd_tdesc result, and replaces the target with an
+/// op whose descriptor type carries the layout. The `transformed` result is
+/// mapped to the replacement op.
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  auto transformOp = cast<TransformOpInterface>(getOperation());
+
+  // Resolve the (possibly dynamic) layout lists to concrete integers.
+  SmallVector<int32_t> sgLayout;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> sgData;
+  status =
+      convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> instData;
+  status =
+      convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
+  if (!status.succeeded())
+    return status;
+
+  // For now only the xegpu.create_nd_tdesc op is supported.
+  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+  if (!descOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected a xegpu.create_nd_desc op, but got: "
+                << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Each specified layout list needs one entry per descriptor dimension.
+  // Diagnose mismatches here rather than building an inconsistent
+  // layout/type pair. Only sg_layout is mandatory; sg_data and inst_data may
+  // be empty.
+  int64_t rank = cast<ShapedType>(descOp.getResult().getType()).getRank();
+  auto checkRank = [&](StringRef name, ArrayRef<int32_t> values,
+                       bool mayBeEmpty) -> DiagnosedSilenceableFailure {
+    if ((mayBeEmpty && values.empty()) ||
+        static_cast<int64_t>(values.size()) == rank)
+      return DiagnosedSilenceableFailure::success();
+    return emitSilenceableFailure(getLoc())
+           << "expected " << name << " to have " << rank
+           << " entries to match the tensor descriptor rank, got "
+           << values.size();
+  };
+  status = checkRank("sg_layout", sgLayout, /*mayBeEmpty=*/false);
+  if (!status.succeeded())
+    return status;
+  status = checkRank("sg_data", sgData, /*mayBeEmpty=*/true);
+  if (!status.succeeded())
+    return status;
+  status = checkRank("inst_data", instData, /*mayBeEmpty=*/true);
+  if (!status.succeeded())
+    return status;
+
+  // An empty inst_data list means "not specified".
+  std::optional<ArrayRef<int32_t>> optInstData;
+  if (!instData.empty())
+    optInstData = ArrayRef<int32_t>(instData);
+
+  // Set layout attr in desc op's return type. Replaces old desc op.
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, optInstData);
+  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+  // Map result handles.
+  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  // The target op gets replaced, so its handle is consumed.
+  consumesHandle(getTargetMutable(), effects);
+  // Layout params/handles are only read to obtain integer values.
+  for (MutableOperandRange layoutOperands :
+       {getSgLayoutMutable(), getSgDataMutable(), getInstDataMutable()})
+    onlyReadsHandle(layoutOperands, effects);
+  // A new handle is returned and the payload IR is rewritten.
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..c857c38df717c 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  xegpu::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
 
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 20ed3ab41a0b4..51c75764faf3c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -322,6 +322,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 )
 
+# XeGPU transform extension bindings: op classes are generated from
+# dialects/XeGPUTransformOps.td and extended by dialects/transform/xegpu.py.
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/XeGPUTransformOps.td
+  SOURCES
+    dialects/transform/xegpu.py
+  DIALECT_NAME transform
+  EXTENSION_NAME xegpu_transform)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..5a5e7b912c4a5
--- /dev/null
+++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td
@@ -0,0 +1,19 @@
+//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the XeGPU transform ops.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td"
+
+#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
new file mode 100644
index 0000000000000..2f7e1793cf3e1
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -0,0 +1,65 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._xegpu_transform_ops_gen import *
+from .._xegpu_transform_ops_gen import _Dialect
+
+try:
+    from ...ir import *
+    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        MixedValues,
+        get_op_result_or_value as _get_op_result_or_value,
+        _dispatch_dynamic_index_list,
+    )
+
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class SetDescLayoutOp(SetDescLayoutOp):
+    """Specialization for SetDescLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        inst_data: MixedValues,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        target_value = _get_op_result_or_value(target)
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+
+        super().__init__(
+            target_value.type,
+            target_value,
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
new file mode 100644
index 0000000000000..7b34f46a71a1c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// Static layout: every sg_layout/sg_data/inst_data entry is a literal; the
+// create_nd_tdesc result type is expected to gain the requested layout.
+// CHECK-LABEL: @set_desc_layout
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Mixed layout: sg_layout[0] comes from a !transform.param instead of a
+// literal; the resulting layout is expected to equal the static case above.
+// CHECK-LABEL: @set_desc_layout_param
+func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
new file mode 100644
index 0000000000000..230cc8d33b037
--- /dev/null
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -0,0 +1,54 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import xegpu
+from mlir.dialects.transform import structured
+
+
+def run(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+ at run
+def setDescLayout():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayout
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+
+
+ at run
+def setDescLayoutDefaultIndex():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutDefaultIndex
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]

>From 0c82c93611b8eaff73a5360f6e8d47e2d36451b9 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 17:48:31 +0200
Subject: [PATCH 2/3] address Adam's comments

---
 .../XeGPU/TransformOps/XeGPUTransformOps.h    |  2 +-
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 15 ++++----
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 29 +++++---------
 mlir/python/mlir/dialects/transform/xegpu.py  |  6 ++-
 .../Dialect/XeGPU/transform-ops-invalid.mlir  | 15 ++++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 38 +++++++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    | 33 ++++++++++------
 7 files changed, 96 insertions(+), 42 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
index dab0c3f35adda..3e16d1e4a7c94 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 
 #define GET_OP_CLASSES
-#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc"
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 681b4861f0aeb..9ca8e99739d45 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef XEGPU_EXTENSION
-#define XEGPU_EXTENSION
+#ifndef XEGPU_TRANSFORM_OPS
+#define XEGPU_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -27,11 +27,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   TransformOpInterface
 ]> {
 
-  let summary = "Set xegpu.layout attribute to an xegpu op result.";
+  let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
   let description = [{
     Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
     attribute to the result tensor descriptor. The layout is defined by the
-    `sg_layout`, `sg_data` and `inst_data` attributes. Returns a handle to the transformed op.
+    `sg_layout`, and optional `sg_data` and `inst_data` attributes. Returns a handle
+    to the transformed op.
   }];
 
   let arguments = (ins
@@ -56,8 +57,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   let assemblyFormat = [{
     $target
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
-    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
-    `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
+    (`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)^)?
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
     attr-dict `:` functional-type(operands, results)
   }];
 
@@ -82,4 +83,4 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
-#endif // XEGPU_EXTENSION
+#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 1875f1050eb03..4b3c99788ba3b 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,26 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
-#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformTypes.h"
-#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-#include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
-#include "mlir/IR/DialectRegistry.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
 
 #include <numeric>
 
@@ -129,11 +113,11 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
 
 /// Create a layout attribute from the given parameters.
 xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
-                                   ArrayRef<int32_t> sgData,
+                                   std::optional<ArrayRef<int32_t>> sgData,
                                    std::optional<ArrayRef<int32_t>> instData) {
   return xegpu::LayoutAttr::get(
       ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
-      DenseI32ArrayAttr::get(ctx, sgData),
+      sgData ? DenseI32ArrayAttr::get(ctx, sgData.value()) : nullptr,
       instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
       /*lane_layout=*/nullptr,
       /*lane_data=*/nullptr,
@@ -211,12 +195,17 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
       convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
   if (!status.succeeded())
     return status;
+  auto maybeSgData =
+      sgData.empty() ? std::nullopt : std::optional<ArrayRef<int32_t>>(sgData);
 
   SmallVector<int32_t> instData;
   status =
       convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
   if (!status.succeeded())
     return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
 
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
@@ -229,8 +218,8 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
+  auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout,
+                                     maybeSgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2f7e1793cf3e1..74f321a3b4b27 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -28,13 +28,15 @@ def __init__(
         self,
         target: Union[Operation, Value],
         sg_layout: MixedValues,
-        sg_data: MixedValues,
-        inst_data: MixedValues,
         *,
+        sg_data: MixedValues = None,
+        inst_data: MixedValues = None,
         loc=None,
         ip=None,
     ):
         target_value = _get_op_result_or_value(target)
+        sg_data = [] if sg_data is None else sg_data
+        inst_data = [] if inst_data is None else inst_data
         (
             dynamic_sg_layout,
             static_sg_layout,
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
new file mode 100644
index 0000000000000..e28630f84aeb4
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
+
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 7b34f46a71a1c..256626368cea0 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -19,6 +19,44 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @set_desc_layout_minimal
+func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4]
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout_sg_data
+func.func @set_desc_layout_sg_data(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_desc_layout_param
 func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 230cc8d33b037..d6df15988522b 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -17,38 +17,47 @@ def run(f):
 
 
 @run
-def setDescLayout():
+def setDescLayoutMinimal():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(
-            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
-        )
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4])
         transform.YieldOp()
-    # CHECK-LABEL: TEST: setDescLayout
+    # CHECK-LABEL: TEST: setDescLayoutMinimal
     # CHECK: %0 = transform.xegpu.set_desc_layout %
     # CHECK: sg_layout = [6, 4]
-    # CHECK: sg_data = [32, 16]
-    # CHECK: inst_data = [8, 16]
 
 
 @run
-def setDescLayoutDefaultIndex():
+def setDescLayoutSgData():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(
-            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
-        )
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
         transform.YieldOp()
-    # CHECK-LABEL: TEST: setDescLayoutDefaultIndex
+    # CHECK-LABEL: TEST: setDescLayoutSgData
     # CHECK: %0 = transform.xegpu.set_desc_layout %
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
+
+
+@run
+def setDescLayoutInstData():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], inst_data=[8, 16])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutInstData
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
     # CHECK: inst_data = [8, 16]

>From 86b0ee1d47e446be0c9ada9ed62504daa7fea0e2 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 19:38:54 +0200
Subject: [PATCH 3/3] nit comments

---
 .../lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 4b3c99788ba3b..5d4b906f50f10 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -39,12 +39,12 @@ void XeGPUTransformDialectExtension::init() {
 
   registerTransformOps<
 #define GET_OP_LIST
-#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
       >();
 }
 
 #define GET_OP_CLASSES
-#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
 
 void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
   registry.addExtensions<XeGPUTransformDialectExtension>();
@@ -61,10 +61,9 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
     if (auto attr = dyn_cast<Attribute>(ofr)) {
       if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
         result.push_back(intAttr.getInt());
-      } else {
-        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+        continue;
       }
-      continue;
+      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
     }
 
     // Transform param case.



More information about the Mlir-commits mailing list