[Mlir-commits] [mlir] [MLIR][XeGPU] add xegpu.set_desc_layout transform op (PR #165615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 29 12:26:22 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Tuomas Kärnä (tkarna)

<details>
<summary>Changes</summary>

Adds the first XeGPU transform op, `xegpu.set_desc_layout`.

Given a handle to an `xegpu.create_nd_tdesc` op, this transform op adds an `xegpu.layout` attribute to the returned descriptor:

```mlir
module  {
  func.func @run(%arg0: memref<4096x4096xf32>) {
    %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> ->
        !xegpu.tensor_desc<256x256xf32>
    return
  }
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["xegpu.create_nd_tdesc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.xegpu.set_desc_layout %0 sg_layout = [8, 8] sg_data = [32, 32] inst_data = [8, 16] : (!transform.any_op) -> !transform.any_op
    transform.yield 
  }
}
```

Applying the transform op produces:

```mlir
    func.func @run(%arg0: memref<4096x4096xf32>) {
      %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf32> ->
          !xegpu.tensor_desc<256x256xf32,
            #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>>
      return
    }
```

For reference, the rationale behind xegpu transform ops is outlined in [this RFC document](https://github.com/tkarna/llvm-project/blob/xegpu-transform-ops-doc/mlir/docs/XeGPUTransformOps.md).

---

Patch is 25.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165615.diff


13 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt (+1) 
- (added) mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt (+6) 
- (added) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h (+28) 
- (added) mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td (+85) 
- (modified) mlir/lib/Dialect/XeGPU/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt (+17) 
- (added) mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp (+250) 
- (modified) mlir/lib/RegisterAllExtensions.cpp (+2) 
- (modified) mlir/python/CMakeLists.txt (+9) 
- (added) mlir/python/mlir/dialects/XeGPUTransformOps.td (+19) 
- (added) mlir/python/mlir/dialects/transform/xegpu.py (+65) 
- (added) mlir/test/Dialect/XeGPU/transform-ops.mlir (+38) 
- (added) mlir/test/python/dialects/transform_xegpu_ext.py (+60) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
index 9f57627c321fb..cb1e9d01821a2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/XeGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..5924606402a02
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS XeGPUTransformOps.td)
+mlir_tablegen(XeGPUTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(XeGPUTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRXeGPUTransformOpsIncGen)
+
+add_mlir_doc(XeGPUTransformOps XeGPUTransformOps Dialects/ -gen-op-doc)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
new file mode 100644
index 0000000000000..dab0c3f35adda
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h
@@ -0,0 +1,28 @@
+//===- XeGPUTransformOps.h - XeGPU transformation ops -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+#define MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h.inc>
+
+namespace mlir {
+class DialectRegistry;
+
+namespace xegpu {
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace xegpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_XEGPU_TRANSFORMOPS_XEGPUTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..681b4861f0aeb
--- /dev/null
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -0,0 +1,85 @@
+//===- XeGPUTransformOps.td - XeGPU transformation ops -----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef XEGPU_EXTENSION
+#define XEGPU_EXTENSION
+
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def TransformAnyParamTypeOrAnyHandle : Type<
+    Or<[TransformHandleTypeInterface.predicate,
+        TransformParamTypeInterface.predicate]>,
+    "transform any param type or any handle type">;
+
+def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Set xegpu.layout attribute to an xegpu op result.";
+  let description = [{
+    Given an `xegpu.create_nd_desc` operation, this transform adds `xegpu.layout`
+    attribute to the result tensor descriptor. The layout is defined by the
+    `sg_layout`, `sg_data` and `inst_data` attributes. Returns a handle to the transformed op.
+  }];
+
+  let arguments = (ins
+                   TransformHandleTypeInterface : $target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
+                   );
+
+  let results = (outs TransformHandleTypeInterface : $transformed);
+  let builders = [
+    OpBuilder<(ins "Value":$target,
+                   "ArrayRef<OpFoldResult>":$mixedSgLayout,
+                   "ArrayRef<OpFoldResult>":$mixedSgData,
+                   "ArrayRef<OpFoldResult>":$mixedInstData
+                   )>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `sg_layout` `=` custom<DynamicIndexList>($sg_layout, $static_sg_layout)
+    `sg_data` `=` custom<DynamicIndexList>($sg_data, $static_sg_data)
+    `inst_data` `=` custom<DynamicIndexList>($inst_data, $static_inst_data)
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgLayout() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgLayout(), getSgLayout(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedSgData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticSgData(), getSgData(), b);
+    }
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedInstData() {
+      Builder b(getContext());
+      return getMixedValues(getStaticInstData(), getInstData(), b);
+    }
+  }];
+}
+
+#endif // XEGPU_EXTENSION
diff --git a/mlir/lib/Dialect/XeGPU/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
index 31167e6af908b..46b8251a57797 100644
--- a/mlir/lib/Dialect/XeGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
+add_subdirectory(TransformOps)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
new file mode 100644
index 0000000000000..48fe841afaa83
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRXeGPUTransformOps
+  XeGPUTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/mlir/Dialect/XeGPU/TransformOps/
+
+  DEPENDS
+  MLIRXeGPUTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRXeGPUDialect
+  MLIRXeGPUTransforms
+  MLIRIR
+  MLIRTransformDialect
+  MLIRFuncDialect
+  MLIRSCFDialect
+)
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
new file mode 100644
index 0000000000000..1875f1050eb03
--- /dev/null
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -0,0 +1,250 @@
+//===- XeGPUTransformOps.cpp - Implementation of XeGPU transformation ops -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
+#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <numeric>
+
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+class XeGPUTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          XeGPUTransformDialectExtension> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XeGPUTransformDialectExtension)
+
+  using Base::Base;
+
+  void init();
+};
+
+void XeGPUTransformDialectExtension::init() {
+  declareGeneratedDialect<scf::SCFDialect>();
+  declareGeneratedDialect<arith::ArithDialect>();
+  declareGeneratedDialect<gpu::GPUDialect>();
+  declareGeneratedDialect<xegpu::XeGPUDialect>();
+
+  registerTransformOps<
+#define GET_OP_LIST
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+      >();
+}
+
+#define GET_OP_CLASSES
+#include <mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp.inc>
+
+void mlir::xegpu::registerTransformDialectExtension(DialectRegistry &registry) {
+  registry.addExtensions<XeGPUTransformDialectExtension>();
+}
+
+/// Assuming that `ofr` is an index attr or a param of index type
+/// or a transform dialect handle mapped to exactly one op
+/// with one index result, get that value and cast it to int type.
+static DiagnosedSilenceableFailure convertMixedValuesToInt(
+    transform::TransformState &state, TransformOpInterface transformOp,
+    SmallVectorImpl<int32_t> &result, ArrayRef<OpFoldResult> ofrs) {
+  for (OpFoldResult ofr : ofrs) {
+    // Attribute case.
+    if (auto attr = dyn_cast<Attribute>(ofr)) {
+      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+        result.push_back(intAttr.getInt());
+      } else {
+        return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
+      }
+      continue;
+    }
+
+    // Transform param case.
+    Value transformValue = cast<Value>(ofr);
+    if (isa<TransformParamTypeInterface>(transformValue.getType())) {
+      ArrayRef<Attribute> params = state.getParams(transformValue);
+      if (params.size() != 1)
+        return transformOp.emitDefiniteFailure()
+               << "requires exactly one parameter associated";
+      result.push_back(
+          cast<IntegerAttr>(params.front()).getValue().getSExtValue());
+      continue;
+    }
+
+    // Payload value case.
+    auto payloadOps = state.getPayloadOps(transformValue);
+    if (!llvm::hasSingleElement(payloadOps)) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "handle must be mapped to exactly one payload op";
+      diag.attachNote(transformValue.getLoc())
+          << "mapped to " << llvm::range_size(payloadOps) << " payload ops";
+      return diag;
+    }
+
+    Operation *op = *payloadOps.begin();
+    if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
+      DiagnosedSilenceableFailure diag =
+          transformOp.emitSilenceableError()
+          << "payload op must have exactly 1 index result";
+      diag.attachNote(op->getLoc())
+          << "has " << op->getNumResults() << " results";
+      return diag;
+    }
+
+    IntegerAttr intAttr;
+    if (!matchPattern(op->getResult(0), m_Constant(&intAttr)))
+      return transformOp.emitSilenceableError()
+             << "requires param or handle to be the result of a constant like "
+                "op";
+
+    result.push_back(intAttr.getInt());
+  }
+  return DiagnosedSilenceableFailure::success();
+}
+
+/// Create a layout attribute from the given parameters.
+xegpu::LayoutAttr createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
+                                   ArrayRef<int32_t> sgData,
+                                   std::optional<ArrayRef<int32_t>> instData) {
+  return xegpu::LayoutAttr::get(
+      ctx, DenseI32ArrayAttr::get(ctx, sgLayout),
+      DenseI32ArrayAttr::get(ctx, sgData),
+      instData ? DenseI32ArrayAttr::get(ctx, instData.value()) : nullptr,
+      /*lane_layout=*/nullptr,
+      /*lane_data=*/nullptr,
+      /*order=*/nullptr);
+}
+
+/// Replace xegpu.create_nd_desc op with a new one with the given layout.
+xegpu::CreateNdDescOp setDescLayout(transform::TransformRewriter &rewriter,
+                                    xegpu::CreateNdDescOp descOp,
+                                    xegpu::LayoutAttr layout) {
+  auto oldTensorDesc = descOp.getResult();
+  auto descShapedType = cast<ShapedType>(oldTensorDesc.getType());
+  auto descType = xegpu::TensorDescType::get(
+      descShapedType.getShape(), descShapedType.getElementType(),
+      /*array_length=*/1,
+      /*boundary_check=*/true,
+      /*memory_space=*/xegpu::MemorySpace::Global,
+      /*layout=*/layout);
+
+  rewriter.setInsertionPointAfter(descOp);
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+        descOp, descType, descOp.getSource(), descOp.getMixedOffsets(),
+        descOp.getMixedSizes(), descOp.getMixedStrides());
+    return newDescOp;
+  }
+  auto newDescOp = rewriter.replaceOpWithNewOp<xegpu::CreateNdDescOp>(
+      descOp, descType, descOp.getSource(), descOp.getMixedSizes(),
+      descOp.getMixedStrides());
+  return newDescOp;
+}
+
+void transform::SetDescLayoutOp::build(OpBuilder &builder,
+                                       OperationState &result, Value target,
+                                       ArrayRef<OpFoldResult> mixedSgLayout,
+                                       ArrayRef<OpFoldResult> mixedSgData,
+                                       ArrayRef<OpFoldResult> mixedInstData) {
+  SmallVector<int64_t> staticSgLayout, staticSgData, staticInstData;
+  SmallVector<Value> dynamicSgLayout, dynamicSgData, dynamicInstData;
+  dispatchIndexOpFoldResults(mixedSgLayout, dynamicSgLayout, staticSgLayout);
+  dispatchIndexOpFoldResults(mixedSgData, dynamicSgData, staticSgData);
+  dispatchIndexOpFoldResults(mixedInstData, dynamicInstData, staticInstData);
+  build(builder, result, target.getType(),
+        /*target=*/target,
+        /*sg_layout=*/dynamicSgLayout,
+        /*sg_data=*/dynamicSgData,
+        /*inst_data=*/dynamicInstData,
+        /*static_sg_layout=*/staticSgLayout,
+        /*static_sg_data=*/staticSgData,
+        /*static_inst_data=*/staticInstData);
+}
+
+DiagnosedSilenceableFailure
+transform::SetDescLayoutOp::apply(transform::TransformRewriter &rewriter,
+                                  transform::TransformResults &results,
+                                  transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  auto transformOp = cast<TransformOpInterface>(getOperation());
+
+  SmallVector<int32_t> sgLayout;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, transformOp, sgLayout, getMixedSgLayout());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> sgData;
+  status =
+      convertMixedValuesToInt(state, transformOp, sgData, getMixedSgData());
+  if (!status.succeeded())
+    return status;
+
+  SmallVector<int32_t> instData;
+  status =
+      convertMixedValuesToInt(state, transformOp, instData, getMixedInstData());
+  if (!status.succeeded())
+    return status;
+
+  // For now only create_nd_desc op is supported.
+  auto descOp = dyn_cast<xegpu::CreateNdDescOp>(target);
+  if (!descOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected a xegpu.create_nd_desc op, but got: "
+                << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  // Set layout attr in desc op's return type. Replaces old desc op.
+  auto layoutAttr =
+      createLayoutAttr(rewriter.getContext(), sgLayout, sgData, instData);
+  auto newdescOp = setDescLayout(rewriter, descOp, layoutAttr);
+
+  // Map result handles.
+  results.set(cast<OpResult>(getTransformed()), {newdescOp.getOperation()});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetDescLayoutOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getSgLayoutMutable(), effects);
+  onlyReadsHandle(getSgDataMutable(), effects);
+  onlyReadsHandle(getInstDataMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 3839172fd0b42..c857c38df717c 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -56,6 +56,7 @@
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
 #include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
 #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
@@ -112,6 +113,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
   transform::registerSMTExtension(registry);
   transform::registerTuneExtension(registry);
   vector::registerTransformDialectExtension(registry);
+  xegpu::registerTransformDialectExtension(registry);
   arm_neon::registerTransformDialectExtension(registry);
   arm_sve::registerTransformDialectExtension(registry);
 
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 20ed3ab41a0b4..51c75764faf3c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -322,6 +322,15 @@ declare_mlir_dialect_extension_python_bindings(
     "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
 )
 
+declare_mlir_dialect_extension_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/XeGPUTransformOps.td
+  SOURCES
+    dialects/transform/xegpu.py
+  DIALECT_NAME transform
+  EXTENSION_NAME xegpu_transform)
+
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/XeGPUTransformOps.td b/mlir/python/mlir/dialects/XeGPUTransformOps.td
new file mode 100644
index 0000000000000..5a5e7b912c4a5
--- /dev/null
+++ b/mlir/python/mlir/dialects/XeGPUTransformOps.td
@@ -0,0 +1,19 @@
+//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the Python bindings generator for the XeGPU transform ops.
+//
+//===----------------------------------------------------------------------===//
+
+
+#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
+
+include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td"
+
+#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
new file mode 100644
index 0000000000000..720ee4070fbec
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/xe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/165615


More information about the Mlir-commits mailing list