[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add get_desc_op (PR #166801)

Tuomas Kärnä llvmlistbot at llvm.org
Thu Nov 6 09:00:45 PST 2025


https://github.com/tkarna created https://github.com/llvm/llvm-project/pull/166801

Add the `transform.xegpu.get_desc_op` transform op, which finds the `xegpu.create_nd_tdesc` op that produces a given `Value`.

For reference, the rationale behind xegpu transform ops is outlined in [this RFC document](https://github.com/tkarna/llvm-project/blob/xegpu-transform-ops-doc/mlir/docs/XeGPUTransformOps.md).

Unlike the RFC, `get_desc_op` takes a value handle instead of an operation handle and an operand index. The value handle for an operand can be obtained with `transform.get_operand`:

```mlir
%tile_c = transform.get_operand %dpas_op[2] : (!transform.any_op) -> !transform.any_value
%desc_op_c = transform.xegpu.get_desc_op %tile_c : (!transform.any_value) -> !transform.any_op
```



>From 88978917df16af46442ddf6ce16f10de6a9046f0 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 08:35:32 +0200
Subject: [PATCH] [mlir][xegpu][transformops] add get_desc_op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 17 +++++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 66 +++++++++++++++++++
 mlir/python/mlir/dialects/transform/xegpu.py  | 21 ++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 25 +++++++
 .../python/dialects/transform_xegpu_ext.py    | 17 ++++-
 5 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..199bd2024c373 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -16,6 +16,23 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
+  DeclareOpInterfaceMethods<TransformOpInterface>,
+  NavigationTransformOpTrait, MemoryEffectsOpInterface
+]> {
+  let summary = "Get a handle to the descriptor op of a value.";
+  let description = [{
+    Traces the producers of the given value until an `xegpu.create_nd_tdesc`
+    descriptor op is found and returns a handle to it. Producers are followed
+    through their first operand, and loop-carried block arguments are followed
+    to the tied loop init value. The target handle must map to exactly one
+    payload value; a silenceable failure is emitted if no such op is found.
+  }];
+  let arguments = (ins TransformValueHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$descHandle);
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
 def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   AttrSizedOperandSegments,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..a4aaed14dfff6 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -13,6 +13,9 @@
 
 #include <optional>
 
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
 using namespace mlir;
 using namespace mlir::transform;
 
@@ -76,6 +79,47 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Find producer operation of type T for the given value.
+/// It's assumed that producer ops are chained through their first operand.
+/// The producer chain is traced through loop block arguments (init values).
+/// Returns std::nullopt if the chain ends before an op of type T is found.
+template <typename T>
+static std::optional<T> findProducerOfType(Value val) {
+  Value currentValue = val;
+  // Walk the chain iteratively instead of recursing on each producer.
+  while (currentValue) {
+    Operation *producerOp = currentValue.getDefiningOp();
+    if (!producerOp) {
+      // Value may be a loop-carried block argument initialized outside the
+      // loop: continue from the init value tied to that argument.
+      auto blockArg = llvm::cast<BlockArgument>(currentValue);
+      auto parentLoop = llvm::dyn_cast_if_present<LoopLikeOpInterface>(
+          blockArg.getOwner()->getParentOp());
+      if (!parentLoop) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Failed to find producer op, not in a loop.\n");
+        return std::nullopt;
+      }
+      // getTiedLoopInit() is null for non-iter_args (e.g. induction vars).
+      OpOperand *tiedInit = parentLoop.getTiedLoopInit(blockArg);
+      if (!tiedInit) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Failed to find producer op, value not in init values.\n");
+        return std::nullopt;
+      }
+      currentValue = tiedInit->get();
+      continue;
+    }
+    if (auto matchingOp = dyn_cast<T>(producerOp))
+      return matchingOp;
+    // Follow the chain through the first operand of the producer.
+    if (producerOp->getNumOperands() == 0)
+      return std::nullopt;
+    currentValue = producerOp->getOperand(0);
+  }
+  return std::nullopt;
+}
+
 /// Create a layout attribute from the given parameters.
 static xegpu::LayoutAttr
 createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -111,6 +155,28 @@ setDescLayout(transform::TransformRewriter &rewriter,
   return newDescOp;
 }
 
+/// Maps the result to the `xegpu.create_nd_tdesc` op producing the single
+/// payload value of `target`; silenceable failure if no such op is found.
+DiagnosedSilenceableFailure
+transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+                            transform::TransformResults &results,
+                            transform::TransformState &state) {
+  auto payloadValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(payloadValues))
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(payloadValues) << ")";
+
+  Value payload = *payloadValues.begin();
+  std::optional<xegpu::CreateNdDescOp> descOp =
+      findProducerOfType<xegpu::CreateNdDescOp>(payload);
+  if (!descOp)
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  Operation *descPayload = descOp->getOperation();
+  results.set(llvm::cast<OpResult>(getResult()), {descPayload});
+  return DiagnosedSilenceableFailure::success();
+}
+
 void transform::SetDescLayoutOp::build(OpBuilder &builder,
                                        OperationState &result, Value target,
                                        ArrayRef<OpFoldResult> mixedSgLayout,
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..d23f2ac16429f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -7,6 +7,7 @@
 
 try:
     from ...ir import *
+    from ...dialects import transform
     from .._ods_common import _cext as _ods_cext
     from .._ods_common import (
         MixedValues,
@@ -20,6 +21,26 @@
 from typing import Union, Optional
 
 
+ at _ods_cext.register_operation(_Dialect, replace=True)
+class GetDescOp(GetDescOp):
+    """Specialization for GetDescOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        desc_type = transform.AnyOpType.get()
+        super().__init__(
+            desc_type,
+            target,
+            loc=loc,
+            ip=ip,
+        )
+
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class SetDescLayoutOp(SetDescLayoutOp):
     """Specialization for SetDescLayoutOp class."""
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..be8b3155be270 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -1,5 +1,30 @@
 // RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK-LABEL: @get_desc_op
+func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+  // expected-remark @below {{found desc op}}
+  %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+  %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+  return
+}
+// dpas operand 1 (%3) is traced through load_nd to the %2 descriptor op.
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value
+    %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
+    transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
 // CHECK-LABEL: @set_desc_layout
 func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
   // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..f83c807f571e1 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
 from mlir.ir import *
 from mlir.dialects import transform
 from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import structured
+from mlir.dialects.transform import AnyValueType
 
 
 def run(f):
@@ -16,6 +16,21 @@ def run(f):
     return f
 
 
+ at run
+def getDescOpDefaultIndex():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("xegpu.dpas"),
+    )
+    with InsertionPoint(sequence.body):
+        operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+        desc_handle = xegpu.GetDescOp(operand)
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: getDescOpDefaultIndex
+    # CHECK: transform.xegpu.get_desc_op %
+
+
 @run
 def setDescLayoutMinimal():
     sequence = transform.SequenceOp(



More information about the Mlir-commits mailing list