[Mlir-commits] [mlir] 300750d - [MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (#166865)
llvmlistbot at llvm.org
Tue Nov 11 03:57:58 PST 2025
Author: Tuomas Kärnä
Date: 2025-11-11T11:57:54Z
New Revision: 300750d4bea3fc2a17de13aa26f71aa10f2f5d2f
URL: https://github.com/llvm/llvm-project/commit/300750d4bea3fc2a17de13aa26f71aa10f2f5d2f
DIFF: https://github.com/llvm/llvm-project/commit/300750d4bea3fc2a17de13aa26f71aa10f2f5d2f.diff
LOG: [MLIR][XeGPU][TransformOps] Add set_gpu_launch_threads op (#166865)
Adds `transform.xegpu.set_gpu_launch_threads`, a transform op that overrides the thread dimensions of a given `gpu.launch` operation in place.
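For illustration, a minimal transform script driving the new op could look like the sketch below (adapted from the tests added in this commit; the handle names and the 8x4x1 thread shape are placeholders):

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
      // Match the payload gpu.launch op and override its thread block size.
      %launch = transform.structured.match ops{["gpu.launch"]} in %root
        : (!transform.any_op) -> !transform.any_op
      transform.xegpu.set_gpu_launch_threads %launch threads = [8, 4, 1] : !transform.any_op
      transform.yield
    }
  }

Run under the transform interpreter, this rewrites the matched `gpu.launch` in place to use the requested thread dimensions.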
Added:
Modified:
mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
mlir/python/mlir/dialects/transform/xegpu.py
mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
mlir/test/Dialect/XeGPU/transform-ops.mlir
mlir/test/python/dialects/transform_xegpu_ext.py
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
index 34f333e556deb..f5e4afad535e5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
@@ -161,4 +161,43 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
}];
}
+def SetGPULaunchThreadsOp
+ : Op<Transform_Dialect, "xegpu.set_gpu_launch_threads", [
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+ TransformOpInterface
+ ]> {
+
+ let summary = "Set number of threads for a given gpu.launch operation";
+ let description = [{
+    Overrides the x, y, z thread operands of a given `gpu.launch` operation in-place.
+ }];
+
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ Variadic<TransformAnyParamTypeOrAnyHandle>:$threads,
+ DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_threads
+ );
+ let results = (outs);
+ let builders = [
+ OpBuilder<(ins "Value":$target, "ArrayRef<OpFoldResult>":$mixedThreads)>,
+ ];
+
+ let assemblyFormat = [{
+ $target
+ `threads` `=` custom<DynamicIndexList>($threads, $static_threads)
+ attr-dict `:` qualified(type(operands))
+ }];
+
+ let extraClassDeclaration = [{
+ ::mlir::DiagnosedSilenceableFailure apply(
+ ::mlir::transform::TransformRewriter &rewriter,
+ ::mlir::transform::TransformResults &transformResults,
+ ::mlir::transform::TransformState &state);
+
+ ::llvm::SmallVector<::mlir::OpFoldResult> getMixedThreads() {
+ Builder b(getContext());
+ return getMixedValues(getStaticThreads(), getThreads(), b);
+ }
+ }];
+}
+
#endif // XEGPU_TRANSFORM_OPS
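As the `custom<DynamicIndexList>` directive above suggests, each thread value can be supplied either as a static integer or as a dynamic value such as a transform param. A sketch of the mixed form, assuming `%launch` is a handle to a `gpu.launch` op (it mirrors the transform-ops.mlir test further below):

  %th_y = transform.param.constant 4 : i64 -> !transform.param<i64>
  transform.xegpu.set_gpu_launch_threads %launch threads = [8, %th_y, 1]
      : !transform.any_op, !transform.param<i64>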
diff --git a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
index 5fdd8534e4e51..7a7a8c9066f09 100644
--- a/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
@@ -341,6 +342,69 @@ void transform::SetOpLayoutAttrOp::getEffects(
modifiesPayload(effects);
}
+void transform::SetGPULaunchThreadsOp::build(
+ OpBuilder &builder, OperationState &ostate, Value target,
+ ArrayRef<OpFoldResult> mixedThreads) {
+ SmallVector<int64_t> staticThreads;
+ SmallVector<Value> dynamicThreads;
+ dispatchIndexOpFoldResults(mixedThreads, dynamicThreads, staticThreads);
+ build(builder, ostate, target.getType(),
+ /*target=*/target,
+ /*threads=*/dynamicThreads,
+ /*static_threads=*/staticThreads);
+}
+
+DiagnosedSilenceableFailure
+transform::SetGPULaunchThreadsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ auto targetOps = state.getPayloadOps(getTarget());
+ if (!llvm::hasSingleElement(targetOps)) {
+ return emitDefiniteFailure() << "Requires exactly one targetOp handle (got "
+ << llvm::range_size(targetOps) << ")";
+ }
+ Operation *target = *targetOps.begin();
+
+ auto launchOp = dyn_cast<gpu::LaunchOp>(target);
+ if (!launchOp) {
+ auto diag = emitSilenceableFailure(getLoc())
+ << "Expected a gpu.launch op, but got: " << target->getName();
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
+
+ SmallVector<int32_t> threads;
+ DiagnosedSilenceableFailure status =
+ convertMixedValuesToInt(state, (*this), threads, getMixedThreads());
+ if (!status.succeeded())
+ return status;
+
+ if (threads.size() != 3) {
+ return emitSilenceableFailure(getLoc())
+ << "Expected threads argument to consist of three values (got "
+ << threads.size() << ")";
+ }
+
+ rewriter.setInsertionPoint(launchOp);
+ auto createConstValue = [&](int value) {
+ return arith::ConstantIndexOp::create(rewriter, launchOp.getLoc(), value);
+ };
+
+ // Replace threads in-place.
+ launchOp.getBlockSizeXMutable().assign(createConstValue(threads[0]));
+ launchOp.getBlockSizeYMutable().assign(createConstValue(threads[1]));
+ launchOp.getBlockSizeZMutable().assign(createConstValue(threads[2]));
+
+ return DiagnosedSilenceableFailure::success();
+}
+
+void transform::SetGPULaunchThreadsOp::getEffects(
+ ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ onlyReadsHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getThreadsMutable(), effects);
+ modifiesPayload(effects);
+}
+
namespace {
class XeGPUTransformDialectExtension
: public transform::TransformDialectExtension<
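The payload effect of `apply` is a pure in-place operand update: three `arith.constant ... : index` ops are created right before the launch and assigned to its block-size operands. A sketch of the resulting launch for `threads = [8, 4, 1]` (value names are illustrative; the grid sizes come from the original op):

  %c8 = arith.constant 8 : index
  %c4 = arith.constant 4 : index
  %c1_0 = arith.constant 1 : index
  gpu.launch blocks(%bx, %by, %bz) in (%gx = %c16, %gy = %c16, %gz = %c1)
             threads(%tx, %ty, %tz) in (%sx = %c8, %sy = %c4, %sz = %c1_0) {
    gpu.terminator
  }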
diff --git a/mlir/python/mlir/dialects/transform/xegpu.py b/mlir/python/mlir/dialects/transform/xegpu.py
index ce8015d8f557b..309883cfc4518 100644
--- a/mlir/python/mlir/dialects/transform/xegpu.py
+++ b/mlir/python/mlir/dialects/transform/xegpu.py
@@ -132,3 +132,39 @@ def __init__(
loc=loc,
ip=ip,
)
+
+
+class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
+ """Specialization for SetGPULaunchThreadsOp class."""
+
+ def __init__(
+ self,
+ launch_op: Union[Operation, Value],
+ threads: MixedValues,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ (
+ dynamic_threads,
+ static_threads,
+ _,
+ ) = _dispatch_dynamic_index_list(threads)
+
+ super().__init__(
+ _get_op_result_or_value(launch_op),
+ dynamic_threads,
+ static_threads=static_threads,
+ loc=loc,
+ ip=ip,
+ )
+
+
+def set_gpu_launch_threads(
+ launch_op: Union[Operation, Value],
+ threads: MixedValues,
+ *,
+ loc=None,
+ ip=None,
+) -> SetGPULaunchThreadsOp:
+ return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
index 726b6748452ae..24f500658f740 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
@@ -71,3 +71,56 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_handle(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index // expected-note {{target op}}
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{Expected a gpu.launch op, but got: arith.constant}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_many_handles(%arg0: memref<4096x4096xf16>) {
+ %c32 = arith.constant 32 : index
+ %c64 = arith.constant 64 : index
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{Requires exactly one targetOp handle (got 2)}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @set_gpu_launch_threads_bad_threads(%arg0: memref<4096x4096xf16>) {
+ %c1 = arith.constant 1 : index
+ %c16 = arith.constant 16 : index
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // expected-error @below {{Expected threads argument to consist of three values (got 2)}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4] : !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/transform-ops.mlir b/mlir/test/Dialect/XeGPU/transform-ops.mlir
index bd6a79244ed30..7f2fbe4271a43 100644
--- a/mlir/test/Dialect/XeGPU/transform-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/transform-ops.mlir
@@ -230,6 +230,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
// -----
// CHECK-LABEL: @set_op_layout_attr_operand1
@@ -252,3 +253,58 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads
+func.func @set_gpu_launch_threads(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C16:.+]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+ // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ }
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, 4, 1] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @set_gpu_launch_threads_param
+func.func @set_gpu_launch_threads_param(%arg0: memref<4096x4096xf16>) {
+ // CHECK: %[[C1:.+]] = arith.constant 1 : index
+ %c1 = arith.constant 1 : index
+ // CHECK: %[[C16:.+]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[C8:.+]] = arith.constant 8 : index
+ // CHECK: %[[C4:.+]] = arith.constant 4 : index
+ // CHECK: %[[C1_0:.+]] = arith.constant 1 : index
+ // CHECK: gpu.launch blocks(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C16]], %{{.*}} = %[[C16]], %{{.*}} = %[[C1]])
+ // CHECK-SAME: threads(%{{.*}}, %{{.*}}, %{{.*}}) in (%{{.*}} = %[[C8]], %{{.*}} = %[[C4]], %{{.*}} = %[[C1_0]])
+ gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg9 = %c16, %arg10 = %c16, %arg11 = %c1) threads(%arg6, %arg7, %arg8) in (%arg12 = %c1, %arg13 = %c1, %arg14 = %c1) {
+ gpu.terminator
+ }
+ return
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["gpu.launch"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ // CHECK: transform.xegpu.set_gpu_launch_threads %{{.*}}
+ %th1 = transform.param.constant 4 : i64 -> !transform.param<i64>
+ transform.xegpu.set_gpu_launch_threads %0 threads = [8, %th1, 1] : !transform.any_op, !transform.param<i64>
+ transform.yield
+ }
+}
diff --git a/mlir/test/python/dialects/transform_xegpu_ext.py b/mlir/test/python/dialects/transform_xegpu_ext.py
index 0b587d2020aa6..dc91f5e982579 100644
--- a/mlir/test/python/dialects/transform_xegpu_ext.py
+++ b/mlir/test/python/dialects/transform_xegpu_ext.py
@@ -113,3 +113,18 @@ def setOpLayoutAttrResult():
# CHECK: sg_layout = [6, 4]
# CHECK: sg_data = [32, 16]
# CHECK: inst_data = [8, 16]
+
+
+@run
+def setGPULaunchThreadsOp():
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.OperationType.get("gpu.launch"),
+ )
+ with InsertionPoint(sequence.body):
+ xegpu.set_gpu_launch_threads(sequence.bodyTarget, threads=[8, 4, 1])
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: setGPULaunchThreadsOp
+ # CHECK: transform.xegpu.set_gpu_launch_threads
+ # CHECK: threads = [8, 4, 1]