[llvm] a8d8af3 - [OpenMP][OMPIRBuilder] Collect users of a value before replacing them in target outlined function (#139064)
via llvm-commits
llvm-commits at lists.llvm.org
Wed May 28 08:40:37 PDT 2025
Author: Kareem Ergawy
Date: 2025-05-28T17:40:34+02:00
New Revision: a8d8af3bfa13bf3173e24097ee4017cf7648c5a6
URL: https://github.com/llvm/llvm-project/commit/a8d8af3bfa13bf3173e24097ee4017cf7648c5a6
DIFF: https://github.com/llvm/llvm-project/commit/a8d8af3bfa13bf3173e24097ee4017cf7648c5a6.diff
LOG: [OpenMP][OMPIRBuilder] Collect users of a value before replacing them in target outlined function (#139064)
This PR fixes a crash that currently happens given the following input:
```fortran
subroutine caller()
real :: x
integer :: i
!$omp target
x = i
call callee(x,x)
!$omp end target
endsubroutine caller
subroutine callee(x1,x2)
real :: x1, x2
endsubroutine callee
```
The crash happens because the following sequence of events is taken by
the `OMPIRBuilder`:
1. ....
2. An outlined function for the target region is created. At first the
outlined function still refers to the SSA values from the original
function of the target region.
3. The builder then iterates over the users of SSA values used in the
target region to replace them with the corresponding function arguments
of the outlined function.
4. If the same instruction references the SSA value more than once (say
m times), that instruction appears m times in the value's `users()` list,
since `users()` yields one entry per use. Visiting it the first time
replaces all m uses of the SSA value in that instruction at once.
5. The next m-1 iterations still iterate over the same instruction, which
causes the last m-1 actual users of the value to be dropped.
Hence, we collect all users first, before modifying any of them.
Added:
mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir
Modified:
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2c4d3d8fb0a50..ca3d8438654dc 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7089,8 +7089,12 @@ static Expected<Function *> createOutlinedFunction(
if (auto *Const = dyn_cast<Constant>(Input))
convertUsersOfConstantsToInstructions(Const, Func, false);
+ // Collect users before iterating over them to avoid invalidating the
+ // iteration in case a user uses Input more than once (e.g. a call
+ // instruction).
+ SetVector<User *> Users(Input->users().begin(), Input->users().end());
// Collect all the instructions
- for (User *User : make_early_inc_range(Input->users()))
+ for (User *User : make_early_inc_range(Users))
if (auto *Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, InputCopy);
diff --git a/mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir b/mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir
new file mode 100644
index 0000000000000..ebaecc8cf203b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @caller_() {
+ %c1 = llvm.mlir.constant(1 : i64) : i64
+ %x_host = llvm.alloca %c1 x f32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %i_host = llvm.alloca %c1 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+ %x_map = omp.map.info var_ptr(%x_host : !llvm.ptr, f32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "x"}
+ %i_map = omp.map.info var_ptr(%i_host : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"}
+ omp.target map_entries(%x_map -> %x_arg, %i_map -> %i_arg : !llvm.ptr, !llvm.ptr) {
+ %1 = llvm.load %i_arg : !llvm.ptr -> i32
+ %2 = llvm.sitofp %1 : i32 to f32
+ llvm.store %2, %x_arg : f32, !llvm.ptr
+ // The call instruction uses %x_arg more than once. Hence modifying users
+ // while iterating them invalidates the iteration. Which is what is tested
+ // by this test.
+ llvm.call @callee_(%x_arg, %x_arg) : (!llvm.ptr, !llvm.ptr) -> ()
+ omp.terminator
+ }
+ llvm.return
+}
+
+llvm.func @callee_(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+ llvm.return
+}
+
+
+// CHECK: define internal void @__omp_offloading_{{.*}}_caller__{{.*}}(ptr %[[X_PARAM:.*]], ptr %[[I_PARAM:.*]]) {
+
+// CHECK: %[[I_VAL:.*]] = load i32, ptr %[[I_PARAM]], align 4
+// CHECK: %[[I_VAL_FL:.*]] = sitofp i32 %[[I_VAL]] to float
+// CHECK: store float %[[I_VAL_FL]], ptr %[[X_PARAM]], align 4
+// CHECK: call void @callee_(ptr %[[X_PARAM]], ptr %[[X_PARAM]])
+// CHECK: br label %[[REGION_CONT:.*]]
+
+// CHECK: [[REGION_CONT]]:
+// CHECK: ret void
+// CHECK: }
More information about the llvm-commits
mailing list