[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add get_desc_op (PR #166801)
Tuomas Kärnä
llvmlistbot at llvm.org
Mon Nov 10 02:37:52 PST 2025
https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/166801
>From 88978917df16af46442ddf6ce16f10de6a9046f0 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Thu, 30 Oct 2025 08:35:32 +0200
Subject: [PATCH 1/2] [mlir][xegpu][transformops] add get_desc_op
---
.../XeGPU/TransformOps/XeGPUTransformOps.td | 17 +++++
.../XeGPU/TransformOps/XeGPUTransformOps.cpp | 66 +++++++++++++++++++
mlir/python/mlir/dialects/transform/xegpu.py | 21 ++++++
mlir/test/Dialect/XeGPU/transform-ops.mlir | 25 +++++++
.../python/dialects/transform_xegpu_ext.py | 17 ++++-
5 files changed, 145 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..199bd2024c373 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -16,6 +16,23 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"
+def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ NavigationTransformOpTrait, MemoryEffectsOpInterface
+]> {
+
+ let summary = "Get a handle to the descriptor op of a value.";
+ let description = [{
+ Traces the producers of the given value until an `xegpu.create_nd_tdesc`
+ descriptor op is found. Returns a handle to it.
+ }];
+
+ let arguments = (ins TransformValueHandleTypeInterface : $target);
+
+ let results = (outs TransformHandleTypeInterface : $descHandle);
+ let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+}
+
def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..a4aaed14dfff6 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -13,6 +13,9 @@
#include <optional>
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "xegpu-transforms"
+
using namespace mlir;
using namespace mlir::transform;
@@ -76,6 +79,47 @@ static DiagnosedSilenceableFailure convertMixedValuesToInt(
return DiagnosedSilenceableFailure::success();
}
+/// Find producer operation of type T for the given value.
+/// It's assumed that producer ops are chained through their first operand.
+/// Producer chain is traced through loop block arguments (init values).
+template <typename T>
+static std::optional<T> findProducerOfType(Value val) {
+ Value currentValue = val;
+ if (!currentValue.getDefiningOp()) {
+ // Value may be a block argument initialized outside a loop.
+ if (val.getNumUses() == 0) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Failed to find producer op, value has no uses.");
+ return std::nullopt;
+ }
+ auto userOp = val.getUsers().begin();
+ auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
+ if (!parentLoop) {
+ LLVM_DEBUG(llvm::dbgs() << "Failed to find producer op, not in a loop.");
+ return std::nullopt;
+ }
+ int64_t iterArgIdx;
+ if (auto iterArg = llvm::dyn_cast<BlockArgument>(currentValue)) {
+ auto numInductionVars = parentLoop.getLoopInductionVars()->size();
+ iterArgIdx = iterArg.getArgNumber() - numInductionVars;
+ currentValue = parentLoop.getInits()[iterArgIdx];
+ } else {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Failed to find producer op, value not in init values.");
+ return std::nullopt;
+ }
+ }
+ Operation *producerOp = currentValue.getDefiningOp();
+
+ if (auto matchingOp = dyn_cast<T>(producerOp))
+ return matchingOp;
+
+ if (producerOp->getNumOperands() == 0)
+ return std::nullopt;
+
+ return findProducerOfType<T>(producerOp->getOperand(0));
+}
+
/// Create a layout attribute from the given parameters.
static xegpu::LayoutAttr
createLayoutAttr(MLIRContext *ctx, ArrayRef<int32_t> sgLayout,
@@ -111,6 +155,28 @@ setDescLayout(transform::TransformRewriter &rewriter,
return newDescOp;
}
+DiagnosedSilenceableFailure
+transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+
+ auto targetValues = state.getPayloadValues(getTarget());
+ if (!llvm::hasSingleElement(targetValues)) {
+ return emitDefiniteFailure()
+ << "requires exactly one target value handle (got "
+ << llvm::range_size(targetValues) << ")";
+ }
+
+ auto maybeDescOp =
+ findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
+ if (!maybeDescOp) {
+ return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ }
+
+ results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
+ return DiagnosedSilenceableFailure::success();
+}
+
void transform::SetDescLayoutOp::build(OpBuilder &builder,
OperationState &result, Value target,
ArrayRef<OpFoldResult> mixedSgLayout,
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..d23f2ac16429f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -7,6 +7,7 @@
try:
from ...ir import *
+ from ...dialects import transform
from .._ods_common import _cext as _ods_cext
from .._ods_common import (
MixedValues,
@@ -20,6 +21,26 @@
from typing import Union, Optional
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GetDescOp(GetDescOp):
+ """Specialization for GetDescOp class."""
+
+ def __init__(
+ self,
+ target: Value,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ desc_type = transform.AnyOpType.get()
+ super().__init__(
+ desc_type,
+ target,
+ loc=loc,
+ ip=ip,
+ )
+
+
@_ods_cext.register_operation(_Dialect, replace=True)
class SetDescLayoutOp(SetDescLayoutOp):
"""Specialization for SetDescLayoutOp class."""
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..be8b3155be270 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -1,5 +1,30 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+// CHECK-LABEL: @get_desc_op
+func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ // expected-remark @below {{found desc op}}
+ %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
// CHECK-LABEL: @set_desc_layout
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..f83c807f571e1 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -3,7 +3,7 @@
from mlir.ir import *
from mlir.dialects import transform
from mlir.dialects.transform import xegpu
-from mlir.dialects.transform import structured
+from mlir.dialects.transform import AnyValueType
def run(f):
@@ -16,6 +16,21 @@ def run(f):
return f
+@run
+def getDescOpDefaultIndex():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("xegpu.dpas"),
+ )
+ with InsertionPoint(sequence.body):
+ operand = transform.GetOperandOp(AnyValueType.get(), sequence.bodyTarget, [0])
+ desc_handle = xegpu.GetDescOp(operand)
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: getDescOpDefaultIndex
+ # CHECK: transform.xegpu.get_desc_op %
+
+
@run
def setDescLayoutMinimal():
sequence = transform.SequenceOp(
>From a759aa600ca2c5c9a0485466ebff72db8e795007 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Mon, 10 Nov 2025 12:36:48 +0200
Subject: [PATCH 2/2] address review comments
---
.../XeGPU/TransformOps/XeGPUTransformOps.td | 17 +++---
.../XeGPU/TransformOps/XeGPUTransformOps.cpp | 15 +++--
mlir/test/Dialect/XeGPU/transform-ops.mlir | 57 +++++++++++++++----
3 files changed, 63 insertions(+), 26 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 199bd2024c373..ed277ef7bd554 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -24,12 +24,13 @@ def GetDescOp : Op<Transform_Dialect, "xegpu.get_desc_op", [
let summary = "Get a handle to the descriptor op of a value.";
let description = [{
Traces the producers of the given value until an `xegpu.create_nd_tdesc`
- descriptor op is found. Returns a handle to it.
+ descriptor op is found. Returns a handle to it. Currently traces
+ producers by following only the first operand of producer ops.
}];
- let arguments = (ins TransformValueHandleTypeInterface : $target);
+ let arguments = (ins TransformValueHandleTypeInterface:$target);
- let results = (outs TransformHandleTypeInterface : $descHandle);
+ let results = (outs TransformHandleTypeInterface:$descHandle);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}
@@ -48,16 +49,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
}];
let arguments = (ins
- TransformHandleTypeInterface : $target,
- Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
- Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
- Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+ TransformHandleTypeInterface:$target,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
);
- let results = (outs TransformHandleTypeInterface : $transformed);
+ let results = (outs TransformHandleTypeInterface:$transformed);
let builders = [
OpBuilder<(ins "Value":$target,
"ArrayRef<OpFoldResult>":$mixedSgLayout,
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index a4aaed14dfff6..0683699f467e9 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -13,7 +13,7 @@
#include <optional>
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "xegpu-transforms"
using namespace mlir;
@@ -88,14 +88,13 @@ static std::optional<T> findProducerOfType(Value val) {
if (!currentValue.getDefiningOp()) {
// Value may be a block argument initialized outside a loop.
if (val.getNumUses() == 0) {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to find producer op, value has no uses.");
+ LDBG() << "Failed to find producer op, value has no uses.";
return std::nullopt;
}
auto userOp = val.getUsers().begin();
auto parentLoop = userOp->getParentOfType<LoopLikeOpInterface>();
if (!parentLoop) {
- LLVM_DEBUG(llvm::dbgs() << "Failed to find producer op, not in a loop.");
+ LDBG() << "Failed to find producer op, not in a loop.";
return std::nullopt;
}
int64_t iterArgIdx;
@@ -104,8 +103,7 @@ static std::optional<T> findProducerOfType(Value val) {
iterArgIdx = iterArg.getArgNumber() - numInductionVars;
currentValue = parentLoop.getInits()[iterArgIdx];
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "Failed to find producer op, value not in init values.");
+ LDBG() << "Failed to find producer op, value not in init values.";
return std::nullopt;
}
}
@@ -159,7 +157,6 @@ DiagnosedSilenceableFailure
transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
-
auto targetValues = state.getPayloadValues(getTarget());
if (!llvm::hasSingleElement(targetValues)) {
return emitDefiniteFailure()
@@ -170,7 +167,9 @@ transform::GetDescOp::apply(transform::TransformRewriter &rewriter,
auto maybeDescOp =
findProducerOfType<xegpu::CreateNdDescOp>(*targetValues.begin());
if (!maybeDescOp) {
- return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+ return emitSilenceableFailure(getLoc())
+ << "Could not find a matching descriptor op when walking the "
+ "producer chain of the first operand.";
}
results.set(llvm::cast<OpResult>(getResult()), {*maybeDescOp});
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index be8b3155be270..342de429d2e90 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -1,23 +1,59 @@
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
-// CHECK-LABEL: @get_desc_op
-func.func @get_desc_op(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+// CHECK-LABEL: @get_desc_op_a
+func.func @get_desc_op_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
+ %c0 = arith.constant 0 : index
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ // expected-remark @below {{found desc op}}
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+ %2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
+ transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @get_desc_op_c
+func.func @get_desc_op_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c4096 = arith.constant 4096 : index
%c0 = arith.constant 0 : index
- %0 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
- %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
// expected-remark @below {{found desc op}}
- %2 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
- %3 = xegpu.load_nd %2[%c0, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
- %4 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
- %5 = xegpu.load_nd %4[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
- %6 = xegpu.dpas %1, %3, %5 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+ %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+ %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+ %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+ %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+ %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+ %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+ %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+ scf.yield %7 : vector<256x256xf16>
+ }
return
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["xegpu.dpas"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- %1 = transform.get_operand %0[1] : (!transform.any_op) -> !transform.any_value
+ %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
%2 = transform.xegpu.get_desc_op %1 : (!transform.any_value) -> !transform.any_op
transform.debug.emit_remark_at %2, "found desc op" : !transform.any_op
transform.yield
@@ -25,6 +61,7 @@ module attributes {transform.with_named_sequence} {
}
// -----
+
// CHECK-LABEL: @set_desc_layout
func.func @set_desc_layout(%arg0: memref<4096x4096xf16>) {
// CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
More information about the Mlir-commits
mailing list