[Mlir-commits] [mlir] 3a68751 - [MLIR][XeGPU][Transform] add xegpu.set_desc_layout transform op (#165615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Nov 6 06:25:39 PST 2025


Author: Tuomas Kärnä
Date: 2025-11-06T14:25:34Z
New Revision: 3a6875119080ea31d318017673cbaf8c95f0a084

URL: https://github.com/llvm/llvm-project/commit/3a6875119080ea31d318017673cbaf8c95f0a084
DIFF: https://github.com/llvm/llvm-project/commit/3a6875119080ea31d318017673cbaf8c95f0a084.diff

LOG: [MLIR][XeGPU][Transform] add xegpu.set_desc_layout transform op (#165615)

Adds the first XeGPU transform op, `xegpu.set_desc_layout`, which attaches a `xegpu.layout` attribute to the descriptor that a `xegpu.create_nd_tdesc` op returns.

Added: 
    mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
    mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
    mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
    mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
    mlir/python/mlir/dialects/XeGPUTransformOps.td
    mlir/python/mlir/dialects/transform/xegpu.py
    mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
    mlir/test/Dialect/XeGPU/transform-ops.mlir
    mlir/test/python/dialects/transform_xegpu_ext.py

Modified: 
    mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
    mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
    mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
    mlir/lib/Dialect/XeGPU/CMakeLists.txt
    mlir/lib/RegisterAllExtensions.cpp
    mlir/python/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 8728e666cd59d..70d424bae9285 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -21,13 +21,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/RegionKindInterface.td"
 
-// This is roughly similar to OpFoldResult assuming the handle produces a single
-// value in the payload IR.
-def TransformAnyParamTypeOrAnyHandle : Type<
-    Or<[TransformHandleTypeInterface.predicate,
-        TransformParamTypeInterface.predicate]>,
-    "transform any param type or any handle type">;
-
 //===----------------------------------------------------------------------===//
 // Apply...PatternsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
index 2d9a26e165b67..3e3fff4c63d2b 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformTypes.td
@@ -103,4 +103,9 @@ def TransformAnyHandle : Type<
         TransformValueHandleTypeInterface.predicate]>,
     "transform operation or value handle">;
 
+def TransformAnyParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
+
 #endif  // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMTYPES

diff  --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)

diff  --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..5924606402a02
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
+mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
+
+add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)

diff  --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
new file mode 100644
index 0000000000000..3e16d1e4a7c94
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -0,0 +1,28 @@
+//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc"
+
+namespace mlir {
+class DialectRegistry;
+
+namespace xegpu {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H

diff  --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..b985d5450be0e
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -0,0 +1,81 @@
+//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef XEGPU_TRANSFORM_OPS
+#define XEGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set xegpu.layout attribute to a xegpu.create_nd_desc op result.";
+  let description = [{
+    Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
+    attribute to the result tensor descriptor. The layout is defined by the
+    `sg_layout`, and `sg_data` and optional `inst_data` attributes. Returns a handle
+    to the transformed op.
+  }];
+
+  let arguments = (ins
+                   TransformHandleTypeInterface : $target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   );
+
+  let results = (outs TransformHandleTypeInterface : $transformed);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    (`inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
+#endif // XEGPU_TRANSFORM_OPS

diff  --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..46b8251a57797 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)

diff  --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..48fe841afaa83
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+  XeGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+  DEPENDS
+  MLIRXeGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRXeGPUDialect
+  MLIRXeGPUTransforms
+  MLIRIR
+  MLIRTransformDialect
+  MLIRFuncDialect
+  MLIRSCFDialect
+)

diff  --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000000000..8943ba09d9c34
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,225 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::transform;
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Attribute case.
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        result.push_back(intAttr.getInt());
+        continue;
+      }
+      return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+    }
+
+    // Transform param case.
+    Value transformValue = cast<Value>(ofr);
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
+    }
+
+    // Payload value case.
+    auto payloadOps = state.getPayloadOps(transformValue);
+    if (!llvm::hasSingleElement(payloadOps)) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(transformValue.getLoc())
+          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+      return diag;
+    }
+
+    Operation *op = *payloadOps.begin();
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+
+    IntegerAttr intAttr;
+    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+      return transformOp.emitSilenceableError()
+             << "requires param or handle to be the result of a constant like "
+                "op";
+
+    result.push_back(intAttr.getInt());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+static xegpu::LayoutAttr
+createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                 ArrayRef<int32_t> sgData,
+                 std::optional<ArrayRef<int32_t>> instData) {
+  return xegpu::LayoutAttr::get(
+      ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+      DenseI32ArrayAttr::get(ctx, sgData),
+      instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+      /*lane_layout=*/nullptr,
+      /*lane_data=*/nullptr,
+      /*order=*/nullptr);
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+static xegpu::CreateNdDescOp
+setDescLayout(transform::TransformRewriter &rewriter,
+              xegpu::CreateNdDescOp descOp, xegpu::LayoutAttr layout) {
+  assert(descOp.getMixedOffsets().size() == 0 &&
+         "create desc op with offsets is not supported");
+  auto oldTensorDesc = descOp.getType();
+  auto descType = xegpu::TensorDescType::get(
+      oldTensorDesc.getShape(), oldTensorDesc.getElementType(),
+      /*array_length=*/oldTensorDesc.getArrayLength(),
+      /*boundary_check=*/oldTensorDesc.getBoundaryCheck(),
+      /*memory_space=*/oldTensorDesc.getMemorySpace(),
+      /*layout=*/layout);
+
+  rewriter.setInsertionPointAfter(descOp);
+  auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+      descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+      descOp.getMixedStrides());
+  return newDescOp;
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+                                       OperationState &result, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, result, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  SmallVector<int32_t> sgLayout;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, (*this), sgLayout, getMixedSgLayout());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> sgData;
+  status = convertMixedValuesToInt(state, (*this), sgData, getMixedSgData());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> instData;
+  status =
+      convertMixedValuesToInt(state, (*this), instData, getMixedInstData());
+  if (!status.succeeded())
+    return status;
+  auto maybeInstData = instData.empty()
+                           ? std::nullopt
+                           : std::optional<ArrayRef<int32_t>>(instData);
+
+  // For now only create_nd_desc op is supported.
+  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+  if (!descOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected a xegpu.create_nd_desc op, but got: "
+                << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Set layout attr in desc op's return type. Replaces old desc op.
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, maybeInstData);
+  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+  // Map result handles.
+  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+namespace {
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+      >();
+}
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc"
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}

diff  --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..c857c38df717c 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  xegpu::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
 

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 20ed3ab41a0b4..51c75764faf3c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -322,6 +322,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 )
 
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/XeGPUTransformOps.td
+  SOURCES
+    dialects/transform/xegpu.py
+  DIALECT_NAME transform
+  EXTENSION_NAME xegpu_transform)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

diff  --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..5a5e7b912c4a5
--- /dev/null
+++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td
@@ -0,0 +1,19 @@
+//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the XeGPU transform ops.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td"
+
+#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS

diff  --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
new file mode 100644
index 0000000000000..2918bf592880a
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -0,0 +1,66 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from .._xegpu_transform_ops_gen import *
+from .._xegpu_transform_ops_gen import _Dialect
+
+try:
+    from ...ir import *
+    from .._ods_common import _cext as _ods_cext
+    from .._ods_common import (
+        MixedValues,
+        get_op_result_or_value as _get_op_result_or_value,
+        _dispatch_dynamic_index_list,
+    )
+
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union, Optional
+
+
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class SetDescLayoutOp(SetDescLayoutOp):
+    """Specialization for SetDescLayoutOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, Value],
+        sg_layout: MixedValues,
+        sg_data: MixedValues,
+        *,
+        inst_data: Optional[MixedValues] = None,
+        loc=None,
+        ip=None,
+    ):
+        target_handle = _get_op_result_or_value(target)
+        inst_data = [] if inst_data is None else inst_data
+        (
+            dynamic_sg_layout,
+            static_sg_layout,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_layout)
+        (
+            dynamic_sg_data,
+            static_sg_data,
+            _,
+        ) = _dispatch_dynamic_index_list(sg_data)
+        (
+            dynamic_inst_data,
+            static_inst_data,
+            _,
+        ) = _dispatch_dynamic_index_list(inst_data)
+
+        super().__init__(
+            target_handle.type,
+            target_handle,
+            dynamic_sg_layout,
+            dynamic_sg_data,
+            dynamic_inst_data,
+            static_sg_layout=static_sg_layout,
+            static_sg_data=static_sg_data,
+            static_inst_data=static_inst_data,
+            loc=loc,
+            ip=ip,
+        )

diff  --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
new file mode 100644
index 0000000000000..303584518f9f4
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics
+
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error at below {{Expected a xegpu.create_nd_desc op, but got: arith.constant}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}

diff  --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
new file mode 100644
index 0000000000000..23e1cd946b4cd
--- /dev/null
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @set_desc_layout
+func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.block_tdesc_attr<boundary_check = false>
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout_minimal
+func.func @set_desc_layout_minimal(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 4] sg_data = [32, 32] : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_desc_layout_param
+func.func @set_desc_layout_param(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], inst_data = [8, 16]>>
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_desc_layout %{{.*}}
+    %layout0 = transform.param.constant 8 : i64 -> !transform.param<i64>
+    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [%layout0, 4] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}

diff  --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
new file mode 100644
index 0000000000000..1c8a2bcc6a2fb
--- /dev/null
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -0,0 +1,51 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import xegpu
+from mlir.dialects.transform import structured
+
+
+def run(f):
+    with Context(), Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            print("\nTEST:", f.__name__)
+            f()
+        print(module)
+    return f
+
+
+ at run
+def setDescLayoutMinimal():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutMinimal
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+
+
+ at run
+def setDescLayoutInstData():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.create_nd_tdesc"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetDescLayoutOp(
+            sequence.bodyTarget, sg_layout=[6, 4], sg_data=[32, 16], inst_data=[8, 16]
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setDescLayoutInstData
+    # CHECK: %0 = transform.xegpu.set_desc_layout %
+    # CHECK: sg_layout = [6, 4]
+    # CHECK: sg_data = [32, 16]
+    # CHECK: inst_data = [8, 16]


        


More information about the Mlir-commits mailing list