[Mlir-commits] [mlir] [MLIR][XeGPU] add xegpu.set_desc_layout transform op (PR #165615)

Tuomas Kärnä llvmlistbot at llvm.org
Wed Nov 5 00:47:20 PST 2025


https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/165615

>From 409ac2f963aafd2439adedfc11ec09d5d7c5b6bc Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 1 Jul 2025 18:58:52 +0300
Subject: [PATCH 1/9] [mlir][xegpu] add xegpu.set_desc_layout transform op

---
 .../include/mlir/Dialect/XeGPU/CMakeLists.txt |   1 +
 .../Dialect/XeGPU/TransformOps/CMakeLists.txt |   6 +
 .../XeGPU/TransformOps/XeGPUTransformOps.h    |  28 ++
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  85 ++++++
 mlir/lib/Dialect/XeGPU/CMakeLists.txt         |   1 +
 .../Dialect/XeGPU/TransformOps/CMakeLists.txt |  17 ++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 250 ++++++++++++++++++
 mlir/lib/RegisterAllExtensions.cpp            |   2 +
 mlir/python/CMakeLists.txt                    |   9 +
 .../python/mlir/dialects/XeGPUTransformOps.td |  19 ++
 mlir/python/mlir/dialects/transform/xegpu.py  |  65 +++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    |  38 +++
 .../python/dialects/transform_xegpu_ext.py    |  54 ++++
 13 files changed, 575 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
 create mode 100644 mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
 create mode 100644 mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
 create mode 100644 mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
 create mode 100644 mlir/python/mlir/dialects/XeGPUTransformOps.td
 create mode 100644 mlir/python/mlir/dialects/transform/xegpu.py
 create mode 100644 mlir/test/Dialect/XeGPU/transform-ops.mlir
 create mode 100644 mlir/test/python/dialects/transform_xegpu_ext.py

diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..5924606402a02
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
+mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
+
+add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
new file mode 100644
index 0000000000000..dab0c3f35adda
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -0,0 +1,28 @@
+//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
+
+namespace mlir {
+class DialectRegistry;
+
+namespace xegpu {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..681b4861f0aeb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -0,0 +1,85 @@
+//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef XEGPU_EXTENSION
+#define XEGPU_EXTENSION
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def TransformAnyParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
+
+def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set the xegpu.layout attribute on an xegpu op result.";
+  let description = [{
+    Given an `xegpu.create_nd_desc` operation, this transform adds the `xegpu.layout`
+    attribute to the result tensor descriptor. The layout is defined by the
+    `sg_layout`, `sg_data` and `inst_data` attributes. Returns a handle to the transformed op.
+  }];
+
+  let arguments = (ins
+                   TransformHandleTypeInterface : $target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   );
+
+  let results = (outs TransformHandleTypeInterface : $transformed);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
+#endif // XEGPU_EXTENSION
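For reference, the op and assembly format defined above admit transform IR along the following lines. This is a minimal sketch adapted from the tests added later in this patch; %root stands for any handle to the payload root, and the concrete sg_layout, sg_data and inst_data values are illustrative only:

  %desc = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %root
      : (!transform.any_op) -> !transform.any_op
  %new = transform.xegpu.set_desc_layout %desc
      sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16]
      : (!transform.any_op) -> !transform.any_op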
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..46b8251a57797 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..48fe841afaa83
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+  XeGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+  DEPENDS
+  MLIRXeGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRXeGPUDialect
+  MLIRXeGPUTransforms
+  MLIRIR
+  MLIRTransformDialect
+  MLIRFuncDialect
+  MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000000000..1875f1050eb03
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,250 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <numeric>
+
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<gpu::GPUDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+      >();
+}
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Attribute case.
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        result.push_back(intAttr.getInt());
+      } else {
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      }
+      continue;
+    }
+
+    // Transform param case.
+    Value transformValue = cast<Value>(ofr);
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
+    }
+
+    // Payload value case.
+    auto payloadOps = state.getPayloadOps(transformValue);
+    if (!llvm::hasSingleElement(payloadOps)) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(transformValue.getLoc())
+          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+      return diag;
+    }
+
+    Operation *op = *payloadOps.begin();
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+
+    IntegerAttr intAttr;
+    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+      return transformOp.emitSilenceableError()
+             << "requires param or handle to be the result of a constant like "
+                "op";
+
+    result.push_back(intAttr.getInt());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                                   ArrayRef<int32_t> sgData,
+                                   std::optional<ArrayRef<int32_t>> instData) {
+  return xegpu::LayoutAttr::get(
+      ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+      DenseI32ArrayAttr::get(ctx, sgData),
+      instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+      /*lane_layout=*/nullptr,
+      /*lane_data=*/nullptr,
+      /*order=*/nullptr);
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
+                                    xegpu::CreateNdDescOp descOp,
+                                    xegpu::LayoutAttr layout) {
+  auto oldTensorDesc = descOp.getResult();
+  auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
+  auto descType = xegpu::TensorDescType::get(
+      descShapedType.getShape(), descShapedType.getElementType(),
+      /*array_length=*/1,
+      /*boundary_check=*/true,
+      /*memory_space=*/xegpu::MemorySpace::Global,
+      /*layout=*/layout);
+
+  rewriter.setInsertionPointAfter(descOp);
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+        descOp, descType, descOp.getSource(), descOp.getMixedOffsets(),
+        descOp.getMixedSizes(), descOp.getMixedStrides());
+    return newDescOp;
+  }
+  auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+      descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+      descOp.getMixedStrides());
+  return newDescOp;
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+                                       OperationState &result, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, result, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  auto transformOp = cast<TransformOpInterface>(getOperation());
+
+  SmallVector<int32_t> sgLayout;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> sgData;
+  status =
+      convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> instData;
+  status =
+      convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
+  if (!status.succeeded())
+    return status;
+
+  // For now only create_nd_desc op is supported.
+  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+  if (!descOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected a xegpu.create_nd_desc op, but got: "
+                << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Set layout attr in desc op's return type. Replaces old desc op.
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
+  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+  // Map result handles.
+  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
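As the tests below illustrate, each sg_layout / sg_data / inst_data entry handled by convertMixedValuesToInt may be given as a static integer, as a transform param, or as a handle mapped to a single constant-like op with one index result. A minimal sketch mixing a static value with a param (values illustrative; %0 is assumed to be a handle to a create_nd_tdesc op):

  %p = transform.param.constant 8 : i64 -> !transform.param<i64>
  %1 = transform.xegpu.set_desc_layout %0
      sg_layout = [%p, 4] sg_data = [32, 32] inst_data = [8, 16]
      : (!transform.any_op, !transform.param<i64>) -> !transform.any_op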
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..c857c38df717c 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  xegpu::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
 
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 20ed3ab41a0b4..51c75764faf3c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -322,6 +322,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 )
 
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/XeGPUTransformOps.td
+  SOURCES
+    dialects/transform/xegpu.py
+  DIALECT_NAME transform
+  EXTENSION_NAME xegpu_transform)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..5a5e7b912c4a5
--- /dev/null
+++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td
@@ -0,0 +1,19 @@
+//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the XeGPU transform ops.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td"
+
+#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
new file mode 100644
index 0000000000000..2f7e1793cf3e1
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -0,0 +1,65 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._xegpu_transform_ops_gen import *
+from .._xegpu_transform_ops_gen import _Dialect
+
+try:
+    from ...ir import *
+    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        MixedValues,
+        get_op_result_or_value as _get_op_result_or_value,
+        _dispatch_dynamic_index_list,
+    )
+
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetDescLayoutOp(SetDescLayoutOp):
+    """Specialization for SetDescLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        inst_data: MixedValues,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        target_value = _get_op_result_or_value(target)
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+
+        super().__init__(
+            target_value.type,
+            target_value,
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
new file mode 100644
index 0000000000000..7b34f46a71a1c
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @set_desc_layout
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout_param
+func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
new file mode 100644
index 0000000000000..230cc8d33b037
--- /dev/null
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -0,0 +1,54 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import xegpu
+from mlir.dialects.transform import structured
+
+
+def run(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+@run
+def setDescLayout():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayout
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]
+
+
+@run
+def setDescLayoutDefaultIndex():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutDefaultIndex
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]

>From 52a30586c23b8fe1cdea57589b97db14a6288195 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 17:48:31 +0200
Subject: [PATCH 2/9] address Adam's comments

---
 .../XeGPU/TransformOps/XeGPUTransformOps.h    |  2 +-
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 15 ++++----
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 29 +++++---------
 mlir/python/mlir/dialects/transform/xegpu.py  |  6 ++-
 .../Dialect/XeGPU/transform-ops-invalid.mlir  | 15 ++++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 38 +++++++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    | 33 ++++++++++------
 7 files changed, 96 insertions(+), 42 deletions(-)
 create mode 100644 mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
index dab0c3f35adda..3e16d1e4a7c94 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -15,7 +15,7 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 
 #define GET_OP_CLASSES
-#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc"
 
 namespace mlir {
 class DialectRegistry;
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 681b4861f0aeb..9ca8e99739d45 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef XEGPU_EXTENSION
-#define XEGPU_EXTENSION
+#ifndef XEGPU_TRANSFORM_OPS
+#define XEGPU_TRANSFORM_OPS
 
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -27,11 +27,12 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   TransformOpInterface
 ]> {
 
-  let summary = "Set the xegpu.layout attribute on an xegpu op result.";
+  let summary = "Set the xegpu.layout attribute on a xegpu.create_nd_desc op result.";
   let description = [{
     Given an `xegpu.create_nd_desc` operation, this transform adds the `xegpu.layout`
     attribute to the result tensor descriptor. The layout is defined by the
-    `sg_layout`, `sg_data` and `inst_data` attributes. Returns a handle to the transformed op.
+    `sg_layout` attribute, and the optional `sg_data` and `inst_data` attributes. Returns a handle
+    to the transformed op.
   }];
 
   let arguments = (ins
@@ -56,8 +57,8 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   let assemblyFormat = [{
     $target
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
-    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
-    `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
+    (`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)^)?
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
     attr-dict `:` functional-type(operands, results)
   }];
 
@@ -82,4 +83,4 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
-#endif // XEGPU_EXTENSION
+#endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 1875f1050eb03..4b3c99788ba3b 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,26 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
-#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/SCF/Transforms/Transforms.h"
-#include "mlir/Dialect/SCF/Utils/Utils.h"
-#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformTypes.h"
-#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
-#include "mlir/Dialect/Transform/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
-#include "mlir/IR/DialectRegistry.h"
-#include "mlir/IR/Operation.h"
-#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/ADT/StringRef.h"
 
 #include <numeric>
 
@@ -129,11 +113,11 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
 
 /// Create a layout attribute from the given parameters.
 xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
-                                   ArrayRef<int32_t> sgData,
+                                   std::optional<ArrayRef<int32_t>> sgData,
                                    std::optional<ArrayRef<int32_t>> instData) {
   return xegpu::LayoutAttr::get(
       ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
-      DenseI32ArrayAttr::get(ctx, sgData),
+      sgData ? DenseI32ArrayAttr::get(ctx, sgData.value()) : nullptr,
       instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
       /*lane_layout=*/nullptr,
       /*lane_data=*/nullptr,
@@ -211,12 +195,17 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
       convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
   if (!status.succeeded())
     return status;
+  auto maybeSgData =
+      sgData.empty() ? std::nullopt : std::optional<ArrayRef<int32_t>>(sgData);
 
   SmallVector<int32_t> instData;
   status =
       convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
   if (!status.succeeded())
     return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
 
   // For now only create_nd_desc op is supported.
   auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
@@ -229,8 +218,8 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr =
-      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
+  auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout,
+                                     maybeSgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2f7e1793cf3e1..74f321a3b4b27 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -28,13 +28,15 @@ def __init__(
         self,
         target: Union[Operation, Value],
         sg_layout: MixedValues,
-        sg_data: MixedValues,
-        inst_data: MixedValues,
         *,
+        sg_data: MixedValues = None,
+        inst_data: MixedValues = None,
         loc=None,
         ip=None,
     ):
         target_value = _get_op_result_or_value(target)
+        sg_data = [] if sg_data is None else sg_data
+        inst_data = [] if inst_data is None else inst_data
         (
             dynamic_sg_layout,
             static_sg_layout,
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
new file mode 100644
index 0000000000000..e28630f84aeb4
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
+
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 7b34f46a71a1c..256626368cea0 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -19,6 +19,44 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: @set_desc_layout_minimal
+func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4]
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout_sg_data
+func.func @set_desc_layout_sg_data(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: @set_desc_layout_param
 func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 230cc8d33b037..d6df15988522b 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -17,38 +17,47 @@ def run(f):
 
 
 @run
-def setDescLayout():
+def setDescLayoutMinimal():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(
-            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
-        )
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4])
         transform.YieldOp()
-    # CHECK-LABEL: TEST: setDescLayout
+    # CHECK-LABEL: TEST: setDescLayoutMinimal
     # CHECK: %0 = transform.xegpu.set_desc_layout %
     # CHECK: sg_layout = [6, 4]
-    # CHECK: sg_data = [32, 16]
-    # CHECK: inst_data = [8, 16]
 
 
 @run
-def setDescLayoutDefaultIndex():
+def setDescLayoutSgData():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(
-            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
-        )
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
         transform.YieldOp()
-    # CHECK-LABEL: TEST: setDescLayoutDefaultIndex
+    # CHECK-LABEL: TEST: setDescLayoutSgData
     # CHECK: %0 = transform.xegpu.set_desc_layout %
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
+
+
+@run
+def setDescLayoutInstData():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], inst_data=[8, 16])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutInstData
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
     # CHECK: inst_data = [8, 16]

>From 1be8cef9cdfebc1b86faa7770780a57e43f12325 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 19:38:54 +0200
Subject: [PATCH 3/9] nit comments

---
 .../lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 4b3c99788ba3b..5d4b906f50f10 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -39,12 +39,12 @@ void XeGPUTransformDialectExtension::init() {
 
   registerTransformOps<
 #define GET_OP_LIST
-#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
       >();
 }
 
 #define GET_OP_CLASSES
-#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
 
 void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
   registry.addExtensions<XeGPUTransformDialectExtension>();
@@ -61,10 +61,9 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
     if (auto attr = dyn_cast<Attribute>(ofr)) {
       if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
         result.push_back(intAttr.getInt());
-      } else {
-        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+        continue;
       }
-      continue;
+      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
     }
 
     // Transform param case.

>From a89af771994e9a5c2ccd2a8a7e6edc8231552792 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 4 Nov 2025 21:16:17 +0200
Subject: [PATCH 4/9] more nit comments

---
 .../Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp | 12 +++---------
 mlir/python/mlir/dialects/transform/xegpu.py         |  6 +++---
 2 files changed, 6 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5d4b906f50f10..0d7f0c505e1e8 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,7 +7,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -34,7 +33,6 @@ class XeGPUTransformDialectExtension
 void XeGPUTransformDialectExtension::init() {
   declareGeneratedDialect<scf::SCFDialect>();
   declareGeneratedDialect<arith::ArithDialect>();
-  declareGeneratedDialect<gpu::GPUDialect>();
   declareGeneratedDialect<xegpu::XeGPUDialect>();
 
   registerTransformOps<
@@ -173,7 +171,6 @@ DiagnosedSilenceableFailure
 transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
                                   transform::TransformResults &results,
                                   transform::TransformState &state) {
-
   auto targetOps = state.getPayloadOps(getTarget());
   if (!llvm::hasSingleElement(targetOps)) {
     return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
@@ -181,17 +178,14 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
   Operation *target = *targetOps.begin();
 
-  auto transformOp = cast<TransformOpInterface>(getOperation());
-
   SmallVector<int32_t> sgLayout;
   DiagnosedSilenceableFailure status =
-      convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
+      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
   if (!status.succeeded())
     return status;
 
   SmallVector<int32_t> sgData;
-  status =
-      convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
+  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
   if (!status.succeeded())
     return status;
   auto maybeSgData =
@@ -199,7 +193,7 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
 
   SmallVector<int32_t> instData;
   status =
-      convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
+      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
   if (!status.succeeded())
     return status;
   auto maybeInstData = instData.empty()
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 74f321a3b4b27..53fd984514b10 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -17,7 +17,7 @@
 except ImportError as e:
     raise RuntimeError("Error loading imports from extension module") from e
 
-from typing import Union
+from typing import Union, Optional
 
 
 @_ods_cext.register_operation(_Dialect, replace=True)
@@ -29,8 +29,8 @@ def __init__(
         target: Union[Operation, Value],
         sg_layout: MixedValues,
         *,
-        sg_data: MixedValues = None,
-        inst_data: MixedValues = None,
+        sg_data: Optional[MixedValues] = None,
+        inst_data: Optional[MixedValues] = None,
         loc=None,
         ip=None,
     ):

>From 8543b916e26d4bbf749a9d0a35bfbb6097409545 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 4 Nov 2025 21:21:20 +0200
Subject: [PATCH 5/9] move TransformAnyParamTypeOrAnyHandle to transform
 dialect

---
 .../mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 7 -------
 mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td   | 5 +++++
 .../mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td   | 5 -----
 3 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8728e666cd59d..70d424bae9285 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -21,13 +21,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
-// This is roughly similar to OpFoldResult assuming the handle produces a single
-// value in the payload IR.
-def TransformAnyParamTypeOrAnyHandle : Type<
-    Or<[TransformHandleTypeInterface.predicate,
-        TransformParamTypeInterface.predicate]>,
-    "transform any param type or any handle type">;
-
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
index 2d9a26e165b67..3e3fff4c63d2b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
@@ -103,4 +103,9 @@ def TransformAnyHandle : Type<
         TransformValueHandleTypeInterface.predicate]>,
     "transform operation or value handle">;
 
+def TransformAnyParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
+
 #endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 9ca8e99739d45..dfbe61558e7fa 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -16,11 +16,6 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
-def TransformAnyParamTypeOrAnyHandle : Type<
-    Or<[TransformHandleTypeInterface.predicate,
-        TransformParamTypeInterface.predicate]>,
-    "transform any param type or any handle type">;
-
 def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   AttrSizedOperandSegments,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

>From 6992b1414af617aa36df239518170ed6f9332ab6 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Tue, 4 Nov 2025 22:44:29 +0200
Subject: [PATCH 6/9] xegpu: setDescLayout retains TensorDesc
 BlockTensorDescAttrs

---
 .../Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp  | 11 +++++------
 mlir/test/Dialect/XeGPU/transform-ops.mlir            |  3 ++-
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 0d7f0c505e1e8..fd44f46f5226e 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -125,13 +125,12 @@ xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
 xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
                                     xegpu::CreateNdDescOp descOp,
                                     xegpu::LayoutAttr layout) {
-  auto oldTensorDesc = descOp.getResult();
-  auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
+  auto oldTensorDesc = descOp.getType();
   auto descType = xegpu::TensorDescType::get(
-      descShapedType.getShape(), descShapedType.getElementType(),
-      /*array_length=*/1,
-      /*boundary_check=*/true,
-      /*memory_space=*/xegpu::MemorySpace::Global,
+      oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
+      /*array_length=*/oldTensorDesc.getArrayLength(),
+      /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
+      /*memory_space=*/oldTensorDesc.getMemorySpace(),
       /*layout=*/layout);
 
   rewriter.setInsertionPointAfter(descOp);
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 256626368cea0..d6d68c2ebb894 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -3,8 +3,9 @@
 // CHECK-LABEL: @set_desc_layout
 func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
   // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
-  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
   return
 }
 

>From 05250fb9a591cb6a9c0f4cf27f9a84779f4f3b15 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Wed, 5 Nov 2025 10:22:53 +0200
Subject: [PATCH 7/9] move extension registration to the end + minor updates

---
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 86 +++++++++----------
 1 file changed, 41 insertions(+), 45 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index fd44f46f5226e..92960c969b716 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -11,43 +11,11 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 
-#include <numeric>
-
-#include "llvm/Support/Debug.h"
-#define DEBUG_TYPE "xegpu-transforms"
+#include <optional>
 
 using namespace mlir;
 using namespace mlir::transform;
 
-class XeGPUTransformDialectExtension
-    : public transform::TransformDialectExtension<
-          XeGPUTransformDialectExtension> {
-public:
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
-
-  using Base::Base;
-
-  void init();
-};
-
-void XeGPUTransformDialectExtension::init() {
-  declareGeneratedDialect<scf::SCFDialect>();
-  declareGeneratedDialect<arith::ArithDialect>();
-  declareGeneratedDialect<xegpu::XeGPUDialect>();
-
-  registerTransformOps<
-#define GET_OP_LIST
-#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
-      >();
-}
-
-#define GET_OP_CLASSES
-#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
-
-void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
-  registry.addExtensions<XeGPUTransformDialectExtension>();
-}
-
 /// Assuming that `ofr` is an index attr or a param of index type
 /// or a transform dialect handle mapped to exactly one op
 /// with one index result, get that value and cast it to int type.
@@ -109,9 +77,10 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
 }
 
 /// Create a layout attribute from the given parameters.
-xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
-                                   std::optional<ArrayRef<int32_t>> sgData,
-                                   std::optional<ArrayRef<int32_t>> instData) {
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                 std::optional<ArrayRef<int32_t>> sgData,
+                 std::optional<ArrayRef<int32_t>> instData) {
   return xegpu::LayoutAttr::get(
       ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
       sgData ? DenseI32ArrayAttr::get(ctx, sgData.value()) : nullptr,
@@ -122,9 +91,9 @@ xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
 }
 
 /// Replace xegpu.create_nd_desc op with a new one with the given layout.
-xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
-                                    xegpu::CreateNdDescOp descOp,
-                                    xegpu::LayoutAttr layout) {
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+              xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
   auto oldTensorDesc = descOp.getType();
   auto descType = xegpu::TensorDescType::get(
       oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
@@ -134,12 +103,8 @@ xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
       /*layout=*/layout);
 
   rewriter.setInsertionPointAfter(descOp);
-  if (descOp.getMixedOffsets().size() > 0) {
-    auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
-        descOp, descType, descOp.getSource(), descOp.getMixedOffsets(),
-        descOp.getMixedSizes(), descOp.getMixedStrides());
-    return newDescOp;
-  }
+  assert(descOp.getMixedOffsets().size() == 0 &&
+         "create desc op with offsets is not supported");
   auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
       descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
       descOp.getMixedStrides());
@@ -229,3 +194,34 @@ void transform::SetDescLayoutOp::getEffects(
   producesHandle(getOperation()->getOpResults(), effects);
   modifiesPayload(effects);
 }
+
+namespace {
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+      >();
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}

>From 3a2fa81c6e093b32f94c032a4eaa06316dcb9cea Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Wed, 5 Nov 2025 10:34:41 +0200
Subject: [PATCH 8/9] sg_data is now a required arg

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   |  4 ++--
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 10 ++++-----
 mlir/python/mlir/dialects/transform/xegpu.py  |  3 +--
 .../Dialect/XeGPU/transform-ops-invalid.mlir  |  2 +-
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 19 ----------------
 .../python/dialects/transform_xegpu_ext.py    | 22 +++++--------------
 6 files changed, 13 insertions(+), 47 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index dfbe61558e7fa..b985d5450be0e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -26,7 +26,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   let description = [{
     Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
     attribute to the result tensor descriptor. The layout is defined by the
-    `sg_layout`, and optional `sg_data` and `inst_data` attributes. Returns a handle
+    `sg_layout` and `sg_data` attributes, and the optional `inst_data` attribute. Returns a handle
     to the transformed op.
   }];
 
@@ -52,7 +52,7 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   let assemblyFormat = [{
     $target
     `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
-    (`sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)^)?
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
     (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
     attr-dict `:` functional-type(operands, results)
   }];
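
With sg_data now mandatory in the assembly format, the printed form always carries
both sg_layout and sg_data, while inst_data remains optional. A sketch of the
updated syntax (the list values are illustrative):

  %1 = transform.xegpu.set_desc_layout %0
         sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16]
         : (!transform.any_op) -> !transform.any_op
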
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 92960c969b716..ea6fd1fbd6d69 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -79,11 +79,11 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
 /// Create a layout attribute from the given parameters.
 static xegpu::LayoutAttr
 createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
-                 std::optional<ArrayRef<int32_t>> sgData,
+                 ArrayRef<int32_t> sgData,
                  std::optional<ArrayRef<int32_t>> instData) {
   return xegpu::LayoutAttr::get(
       ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
-      sgData ? DenseI32ArrayAttr::get(ctx, sgData.value()) : nullptr,
+      DenseI32ArrayAttr::get(ctx, sgData),
       instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
       /*lane_layout=*/nullptr,
       /*lane_data=*/nullptr,
@@ -152,8 +152,6 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
   if (!status.succeeded())
     return status;
-  auto maybeSgData =
-      sgData.empty() ? std::nullopt : std::optional<ArrayRef<int32_t>>(sgData);
 
   SmallVector<int32_t> instData;
   status =
@@ -175,8 +173,8 @@ transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
   }
 
   // Set layout attr in desc op's return type. Replaces old desc op.
-  auto layoutAttr = createLayoutAttr(rewriter.getContext(), sgLayout,
-                                     maybeSgData, maybeInstData);
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
   auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
 
   // Map result handles.
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 53fd984514b10..cef5eb60c1a53 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -28,14 +28,13 @@ def __init__(
         self,
         target: Union[Operation, Value],
         sg_layout: MixedValues,
+        sg_data: MixedValues,
         *,
-        sg_data: Optional[MixedValues] = None,
         inst_data: Optional[MixedValues] = None,
         loc=None,
         ip=None,
     ):
         target_value = _get_op_result_or_value(target)
-        sg_data = [] if sg_data is None else sg_data
         inst_data = [] if inst_data is None else inst_data
         (
             dynamic_sg_layout,
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index e28630f84aeb4..303584518f9f4 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -9,7 +9,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error@below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
-    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index d6d68c2ebb894..23e1cd946b4cd 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -22,25 +22,6 @@ module attributes {transform.with_named_sequence} {
 
 // CHECK-LABEL: @set_desc_layout_minimal
 func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
-  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
-  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4]
-  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
-  return
-}
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
-    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] : (!transform.any_op) -> !transform.any_op
-    transform.yield
-  }
-}
-
-// -----
-
-// CHECK-LABEL: @set_desc_layout_sg_data
-func.func @set_desc_layout_sg_data(%arg0: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
   // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
   %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index d6df15988522b..1c8a2bcc6a2fb 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -18,21 +18,6 @@ def run(f):
 
 @run
 def setDescLayoutMinimal():
-    sequence = transform.SequenceOp(
-        transform.FailurePropagationMode.Propagate,
-        [],
-        transform.OperationType.get("xegpu.create_nd_tdesc"),
-    )
-    with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4])
-        transform.YieldOp()
-    # CHECK-LABEL: TEST: setDescLayoutMinimal
-    # CHECK: %0 = transform.xegpu.set_desc_layout %
-    # CHECK: sg_layout = [6, 4]
-
-
-@run
-def setDescLayoutSgData():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
@@ -41,7 +26,7 @@ def setDescLayoutSgData():
     with InsertionPoint(sequence.body):
         xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
         transform.YieldOp()
-    # CHECK-LABEL: TEST: setDescLayoutSgData
+    # CHECK-LABEL: TEST: setDescLayoutMinimal
     # CHECK: %0 = transform.xegpu.set_desc_layout %
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
@@ -55,9 +40,12 @@ def setDescLayoutInstData():
         transform.OperationType.get("xegpu.create_nd_tdesc"),
     )
     with InsertionPoint(sequence.body):
-        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], inst_data=[8, 16])
+        xegpu.SetDescLayoutOp(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+        )
         transform.YieldOp()
     # CHECK-LABEL: TEST: setDescLayoutInstData
     # CHECK: %0 = transform.xegpu.set_desc_layout %
     # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
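
For reference, the payload-IR effect that the updated test checks for, written out
as a before/after sketch (the layout printing follows the CHECK-SAME line in
transform-ops.mlir above):

  // Before:
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
  // After applying set_desc_layout with sg_layout = [8, 4] and sg_data = [32, 32]:
  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16>
         -> !xegpu.tensor_desc<256x32xf16, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>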

>From 7ea25289cf6e1b3196f6827d5a0bc78679085b0a Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Wed, 5 Nov 2025 10:46:33 +0200
Subject: [PATCH 9/9] py bindings: target_value -> target_handle

---
 mlir/python/mlir/dialects/transform/xegpu.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index cef5eb60c1a53..2918bf592880a 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -34,7 +34,7 @@ def __init__(
         loc=None,
         ip=None,
     ):
-        target_value = _get_op_result_or_value(target)
+        target_handle = _get_op_result_or_value(target)
         inst_data = [] if inst_data is None else inst_data
         (
             dynamic_sg_layout,
@@ -53,8 +53,8 @@ def __init__(
         ) = _dispatch_dynamic_index_list(inst_data)
 
         super().__init__(
-            target_value.type,
-            target_value,
+            target_handle.type,
+            target_handle,
             dynamic_sg_layout,
             dynamic_sg_data,
             dynamic_inst_data,


