[Mlir-commits] [mlir] [mlir][gpu] Introduce the `gpu.conditional_execution` op (PR #78013)
Fabian Mora
llvmlistbot at llvm.org
Mon Jan 15 04:28:00 PST 2024
https://github.com/fabianmcg updated https://github.com/llvm/llvm-project/pull/78013
>From 0f9c861a8c0d60510f2942f7e5a469ee8722174f Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 15 Jan 2024 02:35:04 +0000
Subject: [PATCH 1/2] [mlir][gpu] Introduce the gpu.conditional_execution op
This patch adds the gpu.conditional_execution operation. This operation allows
selecting host or device code depending on the execution context.
For example:
func.func @conditional_execution(%dev: index, %host: index) {
%0 = gpu.conditional_execution device {
gpu.yield %dev : index
} host {
gpu.yield %host : index
} -> index
return
}
// mlir-opt --gpu-resolve-conditional-execution
func.func @conditional_execution(%dev: index, %host: index) {
%0 = scf.execute_region -> index {
scf.yield %host : index
}
return
}
This is a helpful operation combined with gpu.launch, as the kernel outlining
pass copies full symbols when outlining. Before this patch, functions called
from inside a launch op couldn't easily contain GPU operations: if the function
contained GPU ops, it had to be removed from the host module.
---
mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 40 ++++++++
.../mlir/Dialect/GPU/Transforms/Passes.h | 4 +
.../mlir/Dialect/GPU/Transforms/Passes.td | 31 ++++++
mlir/lib/Dialect/GPU/CMakeLists.txt | 1 +
mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 40 ++++++++
.../ResolveConditionalExecution.cpp | 95 +++++++++++++++++++
mlir/test/Dialect/GPU/invalid.mlir | 21 ++++
mlir/test/Dialect/GPU/ops.mlir | 15 +++
.../GPU/resolve-conditional-execution.mlir | 78 +++++++++++++++
9 files changed, 325 insertions(+)
create mode 100644 mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp
create mode 100644 mlir/test/Dialect/GPU/resolve-conditional-execution.mlir
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 71f6a2bc5fa2f8..591ce25c9d8e8a 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2727,4 +2727,44 @@ def GPU_SetCsrPointersOp : GPU_Op<"set_csr_pointers", [GPU_AsyncOpInterface]> {
}];
}
+def GPU_ConditionalExecutionOp : GPU_Op<"conditional_execution", [
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>
+ ]> {
+ let summary = "Executes a region of code based on the surrounding context.";
+ let description = [{
+ The `conditional_execution` operation executes a region of host or device
+ code depending on the surrounding execution context of the operation. If
+ the operation is inside a GPU module or launch operation, it executes the
+ device region; otherwise, it runs the host region.
+
+ This operation can yield a variadic set of results. If the operation yields
+ results, then both regions have to be present. However, if there are no
+ results, then it's valid to implement only one of the regions.
+
+ Examples:
+ ```mlir
+ // Conditional execution with results.
+ %res = gpu.conditional_execution device {
+ ...
+ gpu.yield %val : i32
+ } host {
+ ...
+ gpu.yield %val : i32
+ } -> i32
+ // Conditional execution with no results and only the host region.
+ gpu.conditional_execution host {
+ ...
+ gpu.yield
+ }
+ ```
+ }];
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region AnyRegion:$hostRegion, AnyRegion:$deviceRegion);
+ let assemblyFormat = [{
+ (`device` $deviceRegion^)? (`host` $hostRegion^)? attr-dict
+ (`->` type($results)^)?
+ }];
+ let hasVerifier = 1;
+}
+
#endif // GPU_OPS
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 5885facd07541e..62c06cc604aef3 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -63,6 +63,10 @@ void populateGpuShufflePatterns(RewritePatternSet &patterns);
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
void populateGpuAllReducePatterns(RewritePatternSet &patterns);
+/// Collect a set of patterns to rewrite conditional-execution ops within the
+/// GPU dialect.
+void populateGpuConditionalExecutionPatterns(RewritePatternSet &patterns);
+
/// Collect a set of patterns to break down subgroup_reduce ops into smaller
/// ones supported by the target of `size <= maxShuffleBitwidth`, where `size`
/// is the subgroup_reduce value bitwidth.
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3e0f6a3022f935..c694af71296de6 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -250,4 +250,35 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
];
}
+def GpuResolveConditionalExecutionPass :
+ Pass<"gpu-resolve-conditional-execution", ""> {
+ let summary = "Resolve all conditional execution operations";
+ let description = [{
+ This pass searches for all `gpu.conditional_execution` operations and
+ inlines the appropriate region depending on the execution context. If the
+ operation is inside any of the [`gpu.module`, `gpu.func`, `gpu.launch`]
+ operations, then the pass inlines the device region; otherwise, it
+ inlines the host region.
+ Example:
+ ```
+ func.func @conditional_execution(%dev: index, %host: index) {
+ %0 = gpu.conditional_execution device {
+ gpu.yield %dev : index
+ } host {
+ gpu.yield %host : index
+ } -> index
+ return
+ }
+ // mlir-opt --gpu-resolve-conditional-execution
+ func.func @conditional_execution(%dev: index, %host: index) {
+ %0 = scf.execute_region -> index {
+ scf.yield %host : index
+ }
+ return
+ }
+ ```
+ }];
+ let dependentDialects = ["scf::SCFDialect"];
+}
+
#endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index e5776e157b612c..9692bda34269db 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -58,6 +58,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ModuleToBinary.cpp
Transforms/NVVMAttachTarget.cpp
Transforms/ParallelLoopMapper.cpp
+ Transforms/ResolveConditionalExecution.cpp
Transforms/ROCDLAttachTarget.cpp
Transforms/SerializeToBlob.cpp
Transforms/SerializeToCubin.cpp
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 020900934c9f72..ef8f3f80a2f553 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2204,6 +2204,46 @@ LogicalResult gpu::DynamicSharedMemoryOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// ConditionalExecutionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConditionalExecutionOp::verify() {
+ Region &devRegion = getDeviceRegion();
+ Region &hostRegion = getHostRegion();
+ if (devRegion.empty() && hostRegion.empty())
+ return emitError("both regions can't be empty");
+ if (getResults().size() > 0 && (devRegion.empty() || hostRegion.empty()))
+ return emitError(
+ "when there are results both regions have to be specified");
+ if ((!devRegion.empty() &&
+ !mlir::isa<YieldOp>(devRegion.back().getTerminator())) ||
+ (!hostRegion.empty() &&
+ !mlir::isa<YieldOp>(hostRegion.back().getTerminator()))) {
+ return emitError(
+ "conditional execution regions must terminate with gpu.yield");
+ }
+ return success();
+}
+
+void ConditionalExecutionOp::getSuccessorRegions(
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+ // Both sub-regions always return to the parent.
+ if (!point.isParent()) {
+ regions.push_back(RegionSuccessor(getResults()));
+ return;
+ }
+
+ Region &devRegion = getDeviceRegion();
+ Region &hostRegion = getHostRegion();
+
+ // Don't consider the regions if they are empty.
+ regions.push_back(devRegion.empty() ? RegionSuccessor()
+ : RegionSuccessor(&devRegion));
+ regions.push_back(hostRegion.empty() ? RegionSuccessor()
+ : RegionSuccessor(&hostRegion));
+}
+
//===----------------------------------------------------------------------===//
// GPU target options
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp b/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp
new file mode 100644
index 00000000000000..6861a66435ba12
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp
@@ -0,0 +1,95 @@
+//===- ResolveConditionalExecution.cpp - Resolve conditional exec ops ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the `gpu-resolve-conditional-execution` pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPURESOLVECONDITIONALEXECUTIONPASS
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+namespace {
+class GpuResolveConditionalExecutionPass
+ : public impl::GpuResolveConditionalExecutionPassBase<
+ GpuResolveConditionalExecutionPass> {
+public:
+ using Base::Base;
+ void runOnOperation() final;
+};
+} // namespace
+
+void GpuResolveConditionalExecutionPass::runOnOperation() {
+ RewritePatternSet patterns(&getContext());
+ mlir::populateGpuConditionalExecutionPatterns(patterns);
+ if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+}
+
+namespace {
+struct GpuConditionalExecutionOpRewriter
+ : public OpRewritePattern<ConditionalExecutionOp> {
+ using OpRewritePattern<ConditionalExecutionOp>::OpRewritePattern;
+ // Check whether the operation is inside a device execution context.
+ bool isDevice(Operation *op) const {
+ while ((op = op->getParentOp()))
+ if (isa<GPUFuncOp, LaunchOp, GPUModuleOp>(op))
+ return true;
+ return false;
+ }
+ LogicalResult matchAndRewrite(ConditionalExecutionOp op,
+ PatternRewriter &rewriter) const override {
+ bool isDev = isDevice(op);
+ // Remove the op if the device region is empty and we are in a device
+ // context.
+ if (isDev && op.getDeviceRegion().empty()) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+ // Remove the op if the host region is empty and we are in a host context.
+ if (!isDev && op.getHostRegion().empty()) {
+ rewriter.eraseOp(op);
+ return success();
+ }
+ // Replace `ConditionalExecutionOp` with a `scf::ExecuteRegionOp`.
+ auto execRegionOp = rewriter.create<scf::ExecuteRegionOp>(
+ op.getLoc(), op.getResults().getTypes());
+ if (isDev)
+ rewriter.inlineRegionBefore(op.getDeviceRegion(),
+ execRegionOp.getRegion(),
+ execRegionOp.getRegion().begin());
+ else
+ rewriter.inlineRegionBefore(op.getHostRegion(), execRegionOp.getRegion(),
+ execRegionOp.getRegion().begin());
+ rewriter.eraseOp(op);
+ // This is safe because `ConditionalExecutionOp` always terminates with
+ // `gpu::YieldOp`
+ auto yieldOp =
+ dyn_cast<YieldOp>(execRegionOp.getRegion().back().getTerminator());
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, yieldOp.getValues());
+ return success();
+ }
+};
+} // namespace
+
+void mlir::populateGpuConditionalExecutionPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<GpuConditionalExecutionOpRewriter>(patterns.getContext());
+}
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
index 4d3a898fdd1565..920cca98296eb7 100644
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -818,3 +818,24 @@ func.func @main(%arg0 : index) {
return
}
+// -----
+
+func.func @conditional_execution(%sz : index) {
+ // @expected-error@+1 {{when there are results both regions have to be specified}}
+ %val = gpu.conditional_execution device {
+ gpu.yield %sz: index
+ } -> index
+ return
+}
+
+// -----
+
+func.func @conditional_execution(%sz : index) {
+ // @expected-error@+1 {{'gpu.conditional_execution' op region control flow edge from Region #0 to parent results: source has 0 operands, but target successor needs 1}}
+ %val = gpu.conditional_execution device {
+ gpu.yield %sz: index
+ } host {
+ gpu.yield
+ } -> index
+ return
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index 5e60d91e475795..cccaa39c22834a 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -423,3 +423,18 @@ gpu.module @module_with_two_target [#nvvm.target, #rocdl.target<chip = "gfx90a">
gpu.return
}
}
+
+func.func @conditional_execution(%sz : index) {
+ %val = gpu.conditional_execution device {
+ gpu.yield %sz: index
+ } host {
+ gpu.yield %sz: index
+ } -> index
+ gpu.conditional_execution device {
+ gpu.yield
+ }
+ gpu.conditional_execution host {
+ gpu.yield
+ }
+ return
+}
diff --git a/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir b/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir
new file mode 100644
index 00000000000000..5c7420db374a55
--- /dev/null
+++ b/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s --gpu-resolve-conditional-execution -split-input-file | FileCheck %s
+
+// CHECK-LABEL:func.func @conditional_execution_host
+// CHECK: (%[[DEV:.*]]: index, %[[HOST:.*]]: index)
+func.func @conditional_execution_host(%dev : index, %host : index) {
+ // CHECK: %{{.*}} = scf.execute_region -> index {
+ // CHECK-NEXT: scf.yield %[[HOST]] : index
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ // Test that it returns %host.
+ %v = gpu.conditional_execution device {
+ gpu.yield %dev: index
+ } host {
+ gpu.yield %host: index
+ } -> index
+ return
+}
+
+// -----
+
+// CHECK-LABEL:func.func @conditional_execution_host
+func.func @conditional_execution_host(%memref: memref<f32>) {
+ // CHECK-NEXT: return
+ // CHECK-NEXT: }
+ // Test that the operation gets erased.
+ gpu.conditional_execution device {
+ %c1 = arith.constant 1.0 : f32
+ memref.store %c1, %memref[] : memref<f32>
+ gpu.yield
+ }
+ return
+}
+
+// -----
+
+gpu.module @conditional_execution_dev {
+// CHECK-LABEL:gpu.func @kernel
+// CHECK: (%[[DEV:.*]]: index, %[[HOST:.*]]: index)
+ gpu.func @kernel(%dev : index, %host : index) kernel {
+ // CHECK: %{{.*}} = scf.execute_region -> index {
+ // CHECK-NEXT: scf.yield %[[DEV]] : index
+ // CHECK-NEXT: }
+ // CHECK-NEXT: return
+ // Test that it returns %dev.
+ %v = gpu.conditional_execution device {
+ gpu.yield %dev: index
+ } host {
+ gpu.yield %host: index
+ } -> index
+ gpu.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL:func.func @conditional_execution_dev
+// CHECK: (%[[MEMREF:.*]]: memref<f32>, %[[DEV:.*]]: f32, %[[HOST:.*]]: f32)
+func.func @conditional_execution_dev(%memref: memref<f32>, %fdev: f32, %fhost: f32) {
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%sbx = %c1, %sby = %c1, %sbz = %c1)
+ threads(%tx, %ty, %tz) in (%stx = %c1, %sty = %c1, %stz = %c1) {
+ // CHECK: scf.execute_region {
+ // CHECK-NEXT: memref.store %[[DEV]], %[[MEMREF]][] : memref<f32>
+ // CHECK-NEXT: scf.yield
+ // CHECK-NEXT: }
+ // CHECK-NEXT: gpu.terminator
+ // Test that it uses %fdev.
+ gpu.conditional_execution device {
+ memref.store %fdev, %memref[] : memref<f32>
+ gpu.yield
+ } host {
+ memref.store %fhost, %memref[] : memref<f32>
+ gpu.yield
+ }
+ gpu.terminator
+ }
+ return
+}
>From 67d48c985a5b800460cb293e7a5e4fc72a127e3b Mon Sep 17 00:00:00 2001
From: Fabian Mora <fmora.dev at gmail.com>
Date: Mon, 15 Jan 2024 12:27:29 +0000
Subject: [PATCH 2/2] Fix rewrite bug and test case with results and outlining
---
.../ResolveConditionalExecution.cpp | 11 +++--
.../GPU/resolve-conditional-execution.mlir | 46 +++++++++++++++++++
2 files changed, 54 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp b/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp
index 6861a66435ba12..ff93eb458cc03f 100644
--- a/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/ResolveConditionalExecution.cpp
@@ -77,9 +77,14 @@ struct GpuConditionalExecutionOpRewriter
else
rewriter.inlineRegionBefore(op.getHostRegion(), execRegionOp.getRegion(),
execRegionOp.getRegion().begin());
- rewriter.eraseOp(op);
- // This is safe because `ConditionalExecutionOp` always terminates with
- // `gpu::YieldOp`
+ // Update the calling site.
+ if (op.getResults().empty())
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, execRegionOp);
+
+ // This should be safe because `ConditionalExecutionOp` always terminates
+ // with `gpu::YieldOp`.
auto yieldOp =
dyn_cast<YieldOp>(execRegionOp.getRegion().back().getTerminator());
rewriter.setInsertionPoint(yieldOp);
diff --git a/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir b/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir
index 5c7420db374a55..ca52086d9cf151 100644
--- a/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir
+++ b/mlir/test/Dialect/GPU/resolve-conditional-execution.mlir
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s --gpu-resolve-conditional-execution -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --gpu-kernel-outlining --gpu-resolve-conditional-execution -split-input-file | FileCheck --check-prefix=LAUNCH %s
// CHECK-LABEL:func.func @conditional_execution_host
// CHECK: (%[[DEV:.*]]: index, %[[HOST:.*]]: index)
@@ -76,3 +77,48 @@ func.func @conditional_execution_dev(%memref: memref<f32>, %fdev: f32, %fhost: f
}
return
}
+
+// -----
+
+// LAUNCH-LABEL: func.func @thread_id() -> index
+// LAUNCH: %[[HOST_ID:.*]] = arith.constant 0 : index
+// LAUNCH-NEXT: %[[HOST_RES:.*]] = scf.execute_region -> index {
+// LAUNCH-NEXT: scf.yield %[[HOST_ID]] : index
+// LAUNCH-NEXT: }
+// LAUNCH-NEXT: return %[[HOST_RES]] : index
+func.func @thread_id() -> index {
+ %val = gpu.conditional_execution device {
+ %id = gpu.thread_id x
+ gpu.yield %id: index
+ } host {
+ %id = arith.constant 0 : index
+ gpu.yield %id: index
+ } -> index
+ return %val : index
+}
+// LAUNCH-LABEL: func.func @launch()
+// LAUNCH: gpu.launch_func
+// LAUNCH-NEXT: %{{.*}} = call @thread_id() : () -> index
+// LAUNCH-NEXT: return
+func.func @launch() {
+ %c1 = arith.constant 1 : index
+ gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1,
+ %grid_z = %c1)
+ threads(%tx, %ty, %tz) in (%block_x = %c1, %block_y = %c1,
+ %block_z = %c1) {
+ %id = func.call @thread_id() : () -> index
+ gpu.terminator
+ }
+ %id = func.call @thread_id() : () -> index
+ return
+}
+// LAUNCH: gpu.module @[[LAUNCH_ID:.*]] {
+// LAUNCH: gpu.func @[[LAUNCH_ID]]
+// LAUNCH: %{{.*}} = func.call @thread_id() : () -> index
+// LAUNCH-NEXT: gpu.return
+// LAUNCH-LABEL: func.func @thread_id() -> index
+// LAUNCH-NEXT: %[[DEV_RES:.*]] = scf.execute_region -> index {
+// LAUNCH-NEXT: %[[DEV_ID:.*]] = gpu.thread_id x
+// LAUNCH-NEXT: scf.yield %[[DEV_ID]] : index
+// LAUNCH-NEXT: }
+// LAUNCH-NEXT: return %[[DEV_RES]] : index
More information about the Mlir-commits
mailing list