[llvm-branch-commits] [mlir] [MLIR][OpenMP] Support target SPMD (PR #127821)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Feb 21 02:11:04 PST 2025
https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/127821
>From 27139f8f6260de93a0e6d6163b9562c7daa451b8 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 19 Feb 2025 14:41:12 +0000
Subject: [PATCH 1/2] [MLIR][OpenMP] Support target SPMD
This patch implements MLIR to LLVM IR translation of host-evaluated loop
bounds, completing initial support for `target teams distribute parallel do
[simd]` and `target teams distribute [simd]`.
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 83 ++++++++++++----
.../Target/LLVMIR/openmp-target-spmd.mlir | 96 +++++++++++++++++++
mlir/test/Target/LLVMIR/openmp-todo.mlir | 24 -----
3 files changed, 159 insertions(+), 44 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index fbea278b2511f..9d07bf7b5d224 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getHint())
op.emitWarning("hint clause discarded");
};
- auto checkHostEval = [](auto op, LogicalResult &result) {
- // Host evaluated clauses are supported, except for loop bounds.
- for (BlockArgument arg :
- cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
- for (Operation *user : arg.getUsers())
- if (isa<omp::LoopNestOp>(user))
- result = op.emitError("not yet implemented: host evaluation of loop "
- "bounds in omp.target operation");
- };
auto checkInReduction = [&todo](auto op, LogicalResult &result) {
if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
op.getInReductionSyms())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkBare(op, result);
checkDevice(op, result);
checkHasDeviceAddr(op, result);
- checkHostEval(op, result);
checkInReduction(op, result);
checkIsDevicePtr(op, result);
checkPrivate(op, result);
@@ -4053,9 +4043,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
-static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
- Value &numTeamsLower, Value &numTeamsUpper,
- Value &threadLimit) {
+static void
+extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+ Value &numTeamsLower, Value &numTeamsUpper,
+ Value &threadLimit,
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *steps = nullptr) {
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
blockArgIface.getHostEvalBlockArgs())) {
@@ -4080,11 +4074,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::LoopNestOp loopOp) {
- // TODO: Extract bounds and step values. Currently, this cannot be
- // reached because translation would have been stopped earlier as a
- // result of `checkImplementationStatus` detecting and reporting
- // this situation.
- llvm_unreachable("unsupported host_eval use");
+ auto processBounds =
+ [&](OperandRange opBounds,
+ llvm::SmallVectorImpl<Value> *outBounds) -> bool {
+ bool found = false;
+ for (auto [i, lb] : llvm::enumerate(opBounds)) {
+ if (lb == blockArg) {
+ found = true;
+ if (outBounds)
+ (*outBounds)[i] = hostEvalVar;
+ }
+ }
+ return found;
+ };
+ bool found =
+ processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
+ found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
+ found;
+ found = processBounds(loopOp.getLoopSteps(), steps) || found;
+ if (!found)
+ llvm_unreachable("unsupported host_eval use");
})
.Default([](Operation *) {
llvm_unreachable("unsupported host_eval use");
@@ -4221,6 +4230,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
combinedMaxThreadsVal = maxThreadsVal;
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
+ attrs.ExecFlags = targetOp.getKernelExecFlags();
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -4238,9 +4248,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
omp::TargetOp targetOp,
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
+ omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
+ targetOp.getInnermostCapturedOmpOp());
+ unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
+
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
+ steps(numLoops);
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- teamsThreadLimit);
+ teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4260,7 +4276,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
- // TODO: Populate attrs.LoopTripCount if it is target SPMD.
+ if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ attrs.LoopTripCount = nullptr;
+
+ // To calculate the trip count, we multiply together the trip counts of
+ // every collapsed canonical loop. We don't need to create the loop nests
+ // here, since we're only interested in the trip count.
+ for (auto [loopLower, loopUpper, loopStep] :
+ llvm::zip_equal(lowerBounds, upperBounds, steps)) {
+ llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
+ llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
+ llvm::Value *step = moduleTranslation.lookupValue(loopStep);
+
+ llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+ llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
+ loc, lowerBound, upperBound, step, /*IsSigned=*/true,
+ loopOp.getLoopInclusive());
+
+ if (!attrs.LoopTripCount) {
+ attrs.LoopTripCount = tripCount;
+ continue;
+ }
+
+ // TODO: Enable UndefinedBehaviorSanitizer to diagnose an overflow here.
+ attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
+ {}, /*HasNUW=*/true);
+ }
+ }
}
static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
new file mode 100644
index 0000000000000..7930554cbe11a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
@@ -0,0 +1,96 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @main(%x : i32) {
+ omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.teams {
+ omp.parallel {
+ omp.distribute {
+ omp.wsloop {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// HOST-LABEL: define void @main
+// HOST: %omp_loop.tripcount = {{.*}}
+// HOST-NEXT: br label %[[ENTRY:.*]]
+// HOST: [[ENTRY]]:
+// HOST-NEXT: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST: [[OFFLOAD_FAILED]]:
+// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[TARGET_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[TEAMS_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[PARALLEL_OUTLINE]]
+// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST: call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+ llvm.func @main(%x : i32) {
+ omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+ omp.teams {
+ omp.parallel {
+ omp.distribute {
+ omp.wsloop {
+ omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+ omp.yield
+ }
+ } {omp.composite}
+ } {omp.composite}
+ omp.terminator
+ } {omp.composite}
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// DEVICE: @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
+// DEVICE: @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE: @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE: define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE: call void @__kmpc_target_deinit()
+
+// DEVICE: define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE: define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// DEVICE: define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE: call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d1c745af9bff5..f907bb3f94a2a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
// -----
-llvm.func @target_host_eval(%x : i32) {
- // expected-error at below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
- // expected-error at below {{LLVM Translation failed for operation: omp.target}}
- omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
- omp.teams {
- omp.parallel {
- omp.distribute {
- omp.wsloop {
- omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
- omp.yield
- }
- } {omp.composite}
- } {omp.composite}
- omp.terminator
- } {omp.composite}
- omp.terminator
- }
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
>From d5cd2ca066cfc27f77ac5d1fcc00050b3aa5307f Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 21 Feb 2025 10:10:41 +0000
Subject: [PATCH 2/2] Address review comment
---
.../LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9d07bf7b5d224..6df061e0efdbe 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4092,8 +4092,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
found;
found = processBounds(loopOp.getLoopSteps(), steps) || found;
- if (!found)
- llvm_unreachable("unsupported host_eval use");
+ (void)found;
+ assert(found && "unsupported host_eval use");
})
.Default([](Operation *) {
llvm_unreachable("unsupported host_eval use");
More information about the llvm-branch-commits
mailing list