[Mlir-commits] [mlir] [MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (PR #166865)

Tuomas Kärnä llvmlistbot at llvm.org
Mon Nov 10 04:08:00 PST 2025


https://github.com/tkarna updated https://github.com/llvm/llvm-project/pull/166865

>From 5fb499a0fee081c17a82341e18da341fbf7d8dc6 Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Fri, 31 Oct 2025 21:27:06 +0200
Subject: [PATCH 1/3] [mlir][xegpu][transformops] add set_gpu_launch_threads op

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 37 +++++++++++
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  | 64 +++++++++++++++++++
 mlir/python/mlir/dialects/transform/xegpu.py  | 27 ++++++++
 .../Dialect/XeGPU/transform-ops-invalid.mlir  | 53 +++++++++++++++
 mlir/test/Dialect/XeGPU/transform-ops.mlir    | 55 ++++++++++++++++
 .../python/dialects/transform_xegpu_ext.py    | 15 +++++
 6 files changed, 251 insertions(+)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index b985d5450be0e..0b138d13d8aee 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -78,4 +78,41 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 }
 
+def SetGPULaunchThreadsOp
+    : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      TransformOpInterface
+    ]> {
+
+  let summary = "Set number of threads for a given gpu.launch operation";
+  let description = "Set number of threads for a given `gpu.launch` operation.";
+
+  let arguments = (ins TransformHandleTypeInterface : $target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
+                   DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
+                   );
+  let results = (outs);
+  let builders = [
+    OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
+  ];
+
+  let assemblyFormat = [{
+    $target
+    `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
+    attr-dict `:` qualified(type(operands))
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
+      Builder b(getContext());
+      return getMixedValues(getStaticThreads(), getThreads(), b);
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 8943ba09d9c34..73c3776a75276 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -193,6 +194,69 @@ void transform::SetDescLayoutOp::getEffects(
   modifiesPayload(effects);
 }
 
+void transform::SetGPULaunchThreadsOp::build(
+    OpBuilder &builder, OperationState &ostate, Value target,
+    ArrayRef<OpFoldResult> mixedThreads) {
+  SmallVector<int64_t> staticThreads;
+  SmallVector<Value> dynamicThreads;
+  dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+  build(builder, ostate, target.getType(),
+        /*target=*/target,
+        /*threads=*/dynamicThreads,
+        /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+
+  auto targetOps = state.getPayloadOps(getTarget());
+  if (!llvm::hasSingleElement(targetOps)) {
+    return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+                                 << llvm::range_size(targetOps) << ")";
+  }
+  Operation *target = *targetOps.begin();
+
+  auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+  if (!launchOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Expected a gpu.launch op, but got: " << target->getName();
+    diag.attachNote(target->getLoc()) << "target op";
+    return diag;
+  }
+
+  SmallVector<int32_t> threads;
+  DiagnosedSilenceableFailure status =
+      convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+  if (!status.succeeded())
+    return status;
+
+  if (threads.size() != 3) {
+    return emitSilenceableFailure(getLoc())
+           << "Expected threads to be a 3D vector";
+  }
+
+  rewriter.setInsertionPoint(launchOp);
+  auto createConstValue = [&](int value) {
+    return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+  };
+
+  // Replace threads in-place.
+  launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+  launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+  launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getThreadsMutable(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index 2918bf592880a..dce0e0e97c01f 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -64,3 +64,30 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
+    """Specialization for SetGPULaunchThreadsOp class."""
+
+    def __init__(
+        self,
+        launch_op: Union[Operation, Value],
+        threads: MixedValues,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        (
+            dynamic_threads,
+            static_threads,
+            _,
+        ) = _dispatch_dynamic_index_list(threads)
+
+        super().__init__(
+            _get_op_result_or_value(launch_op),
+            dynamic_threads,
+            static_threads=static_threads,
+            loc=loc,
+            ip=ip,
+        )
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 303584518f9f4..be6943c5b8db9 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -13,3 +13,56 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index // expected-note {{target op}}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected a gpu.launch op, but got: arith.constant}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c64 = arith.constant 64 : index
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Requires exactly one targetOp handle (got 2)}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  } {SCFToGPU_visited}
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error@below {{Expected threads to be a 3D vector}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index 23e1cd946b4cd..cb8d1e9afcc6e 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -56,3 +56,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads
+func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C16:.+]] = arith.constant 16 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: %[[C8:.+]] = arith.constant 8 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  } {SCFToGPU_visited}
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads_param
+func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
+  // CHECK: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[C16:.+]] = arith.constant 16 : index
+  %c16 = arith.constant 16 : index
+  // CHECK: %[[C8:.+]] = arith.constant 8 : index
+  // CHECK: %[[C4:.+]] = arith.constant 4 : index
+  // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+  // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+  // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+    gpu.terminator
+  } {SCFToGPU_visited}
+  return
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+    %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+    transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
+    transform.yield
+  }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 1c8a2bcc6a2fb..ff8506c749f0f 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -49,3 +49,18 @@ def setDescLayoutInstData():
     # CHECK: sg_layout = [6, 4]
     # CHECK: sg_data = [32, 16]
     # CHECK: inst_data = [8, 16]
+
+
+@run
+def setGPULaunchThreadsOp():
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate,
+        [],
+        transform.OperationType.get("gpu.lauch"),
+    )
+    with InsertionPoint(sequence.body):
+        xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: setGPULaunchThreadsOp
+    # CHECK: transform.xegpu.set_gpu_launch_threads
+    # CHECK: threads = [8, 4, 1]

>From 4c08b46dde8a4b2e2ba422e70c622291ead1261e Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Mon, 10 Nov 2025 13:38:19 +0200
Subject: [PATCH 2/3] address review comments

---
 .../XeGPU/TransformOps/XeGPUTransformOps.td   | 20 ++++++++++---------
 .../XeGPU/TransformOps/XeGPUTransformOps.cpp  |  4 ++--
 .../python/dialects/transform_xegpu_ext.py    |  2 +-
 3 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 0b138d13d8aee..5f747c4f1f708 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -31,16 +31,16 @@ def SetDescLayoutOp : Op<Transform_Dialect, "xegpu.set_desc_layout", [
   }];
 
   let arguments = (ins
-                   TransformHandleTypeInterface : $target,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_layout,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $sg_data,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $inst_data,
+                   TransformHandleTypeInterface:$target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_layout,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$sg_data,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$inst_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_layout,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sg_data,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_inst_data
                    );
 
-  let results = (outs TransformHandleTypeInterface : $transformed);
+  let results = (outs TransformHandleTypeInterface:$transformed);
   let builders = [
     OpBuilder<(ins "Value":$target,
                    "ArrayRef<OpFoldResult>":$mixedSgLayout,
@@ -84,11 +84,13 @@ def SetGPULaunchThreadsOp
       TransformOpInterface
     ]> {
 
-  let summary = "Set number of threads for a given gpu.launch operation";
-  let description = "Set number of threads for a given `gpu.launch` operation.";
+  let summary = "Set number of threads for a given gpu.launch operation"; let
+  description = [{
+    Overrides the x,y,z threads operands of a given `gpu.launch` operation in-place.
+  }];
 
-  let arguments = (ins TransformHandleTypeInterface : $target,
-                   Variadic<TransformAnyParamTypeOrAnyHandle> : $threads,
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
                    DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
                    );
   let results = (outs);
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 73c3776a75276..ed067148409c3 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -210,7 +210,6 @@ DiagnosedSilenceableFailure
 transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
                                         transform::TransformResults &results,
                                         transform::TransformState &state) {
-
   auto targetOps = state.getPayloadOps(getTarget());
   if (!llvm::hasSingleElement(targetOps)) {
     return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
@@ -234,7 +233,8 @@ transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
 
   if (threads.size() != 3) {
     return emitSilenceableFailure(getLoc())
-           << "Expected threads to be a 3D vector";
+           << "Expected threads argument to consist of three values (got "
+           << threads.size() << ")";
   }
 
   rewriter.setInsertionPoint(launchOp);
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index ff8506c749f0f..23b284766dbf0 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -56,7 +56,7 @@ def setGPULaunchThreadsOp():
     sequence = transform.SequenceOp(
         transform.FailurePropagationMode.Propagate,
         [],
-        transform.OperationType.get("gpu.lauch"),
+        transform.OperationType.get("gpu.launch"),
     )
     with InsertionPoint(sequence.body):
         xegpu.SetGPULaunchThreadsOp(sequence.bodyTarget, threads=[8, 4, 1])

>From 505c8a966c5b8ed10a4c93e01b4bba0b4ec4f18b Mon Sep 17 00:00:00 2001
From: Tuomas Karna <tuomas.karna at intel.com>
Date: Mon, 10 Nov 2025 14:07:38 +0200
Subject: [PATCH 3/3] fix test

---
 mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index be6943c5b8db9..12bd9dd656c61 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -61,7 +61,7 @@ func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    // expected-error@below {{Expected threads to be a 3D vector}}
+    // expected-error@below {{Expected threads argument to consist of three values (got 2)}}
     transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
     transform.yield
   }



More information about the Mlir-commits mailing list