[Mlir-commits] [llvm] [mlir] [OpenMP][OMPIRBuilder] Collect users of a value before replacing them in target outlined function (PR #139064)
llvmlistbot at llvm.org
Thu May 8 04:07:27 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-llvm
Author: Kareem Ergawy (ergawy)
<details>
<summary>Changes</summary>
This PR fixes a crash that currently happens given the following input:
```fortran
subroutine caller()
real :: x, y, z
integer :: i
!$omp target
x = i
call callee(x,x)
!$omp end target
endsubroutine caller
subroutine callee(x1,x2)
real :: x1, x2
endsubroutine callee
```
The crash happens because the `OMPIRBuilder` takes the following sequence of steps:
1. ....
2. An outlined function for the target region is created. At first, the
outlined function still refers to the SSA values of the original function
that contains the target region.
3. The builder then iterates over the users of each SSA value used in the
target region and replaces those uses with the corresponding argument of
the outlined function.
4. If the same instruction references an SSA value more than once (say, m
times), `replaceUsesOfWith` replaces all m uses in that instruction at once,
removing all m of them from the value's use list.
5. The next m-1 iterations still visit that same instruction, so the last
m-1 actual users of the value are never reached.
Hence, we collect all users first, before modifying any of them.
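The skipped-user effect in steps 4 and 5 can be reproduced outside LLVM with a toy use list. The sketch below is a hypothetical, simplified model, not LLVM's `Value`/`Use` API. Each value keeps a flat vector with one entry per operand slot that uses it. `replaceUsesOfWith` rewrites every matching operand of an instruction in one call, as `Instruction::replaceUsesOfWith` does. LLVM's real use list is an intrusive linked list, so the exact way the walk goes wrong differs. The effect the PR describes is the same, though: a later user of the value is never visited.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy IR, NOT LLVM's API: values are integer ids, and each instruction
// stores the ids of its operands.
struct Instr {
  std::string name;
  std::vector<int> operands;
};

struct ToyFunction {
  std::vector<Instr> instrs;
  // useList[v] has one entry per operand slot that uses value v; the entry
  // is the index of the instruction (the "user") owning that slot.
  std::vector<std::vector<int>> useList;

  // Mimics Instruction::replaceUsesOfWith: rewrites *every* operand of
  // instrs[user] equal to `from`, moving each use to `to`'s use list.
  void replaceUsesOfWith(int user, int from, int to) {
    for (int &op : instrs[user].operands) {
      if (op != from)
        continue;
      op = to;
      std::vector<int> &fromUses = useList[from];
      fromUses.erase(std::find(fromUses.begin(), fromUses.end(), user));
      useList[to].push_back(user);
    }
  }
};

// Hypothetical value ids for the example.
constexpr int kX = 0;      // plays the role of Input (%x_arg)
constexpr int kXCopy = 1;  // plays the role of InputCopy
constexpr int kVal = 2;    // the float being stored

// Roughly the shape of the test: `call @callee_(%x, %x)` and `store %val, %x`.
ToyFunction makeExample() {
  ToyFunction f;
  f.instrs = {{"call", {kX, kX}}, {"store", {kVal, kX}}};
  f.useList = {{0, 0, 1}, {}, {1}};
  return f;
}

// Buggy pattern: walk the *live* use list while the loop body shrinks it.
void rewriteWhileIterating(ToyFunction &f, int from, int to) {
  const std::vector<int> &uses = f.useList[from];
  for (std::size_t i = 0; i < uses.size(); ++i)
    f.replaceUsesOfWith(uses[i], from, to);
}
```

Calling `rewriteWhileIterating(f, kX, kXCopy)` on `makeExample()` rewrites both operands of the `call` in the first iteration. That removes two entries, so the `store` entry shifts to index 0, which the loop has already passed. The `store` keeps referring to `kX`. This is the toy analogue of the outlined function still pointing at an SSA value from the host function.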
---
Full diff: https://github.com/llvm/llvm-project/pull/139064.diff
2 Files Affected:
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+5-1)
- (added) mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir (+37)
``````````diff
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 7c1b64677a011..3dfde6f5e236d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -7035,8 +7035,12 @@ static Expected<Function *> createOutlinedFunction(
if (auto *Const = dyn_cast<Constant>(Input))
convertUsersOfConstantsToInstructions(Const, Func, false);
+ // Collect users before iterating over them to avoid invalidating the
+ // iteration in case a user uses Input more than once (e.g. a call
+ // instruction).
+ SetVector<User *> Users(Input->users().begin(), Input->users().end());
// Collect all the instructions
- for (User *User : make_early_inc_range(Input->users()))
+ for (User *User : make_early_inc_range(Users))
if (auto *Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
Instr->replaceUsesOfWith(Input, InputCopy);
diff --git a/mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir b/mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir
new file mode 100644
index 0000000000000..91b4129f6ff9d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omp-target-call-with-repeated-parameter.mlir
@@ -0,0 +1,37 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @caller_() {
+ %c1 = llvm.mlir.constant(1 : i64) : i64
+ %x_host = llvm.alloca %c1 x f32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+ %i_host = llvm.alloca %c1 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+ %x_map = omp.map.info var_ptr(%x_host : !llvm.ptr, f32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "x"}
+ %i_map = omp.map.info var_ptr(%i_host : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"}
+ omp.target map_entries(%x_map -> %x_arg, %i_map -> %i_arg : !llvm.ptr, !llvm.ptr) {
+ %1 = llvm.load %i_arg : !llvm.ptr -> i32
+ %2 = llvm.sitofp %1 : i32 to f32
+ llvm.store %2, %x_arg : f32, !llvm.ptr
+ // The call instruction uses %x_arg more than once. Hence, modifying users
+ // while iterating over them invalidates the iteration, which is what this
+ // test exercises.
+ llvm.call @callee_(%x_arg, %x_arg) : (!llvm.ptr, !llvm.ptr) -> ()
+ omp.terminator
+ }
+ llvm.return
+}
+
+llvm.func @callee_(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
+ llvm.return
+}
+
+
+// CHECK: define internal void @__omp_offloading_{{.*}}_caller__{{.*}}(ptr %[[X_PARAM:.*]], ptr %[[I_PARAM:.*]]) {
+
+// CHECK: %[[I_VAL:.*]] = load i32, ptr %[[I_PARAM]], align 4
+// CHECK: %[[I_VAL_FL:.*]] = sitofp i32 %[[I_VAL]] to float
+// CHECK: store float %[[I_VAL_FL]], ptr %[[X_PARAM]], align 4
+// CHECK: call void @callee_(ptr %[[X_PARAM]], ptr %[[X_PARAM]])
+// CHECK: br label %[[REGION_CONT:.*]]
+
+// CHECK: [[REGION_CONT]]:
+// CHECK: ret void
+// CHECK: }
``````````
</details>
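The fix in the diff copies the users into an `llvm::SetVector` before rewriting anything. The sketch below uses a hypothetical toy model, not LLVM's API, to show why that is enough. The loop now walks a separate snapshot, so shrinking the value's use list cannot make it skip the `store`. `collectUniqueUsers` is an illustrative stand-in for `SetVector`, which keeps first-insertion order and drops duplicates. The real loop's `dyn_cast<Instruction>` and `getFunction() == Func` checks are omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy IR, NOT LLVM's API: values are integer ids.
struct Instr {
  std::string name;
  std::vector<int> operands;  // operand value ids
};

struct ToyFunction {
  std::vector<Instr> instrs;
  // useList[v]: one entry (the user's instruction index) per use of v.
  std::vector<std::vector<int>> useList;

  // Mimics Instruction::replaceUsesOfWith: rewrites every operand equal to
  // `from` and moves each of those uses to `to`'s use list.
  void replaceUsesOfWith(int user, int from, int to) {
    for (int &op : instrs[user].operands) {
      if (op != from)
        continue;
      op = to;
      std::vector<int> &fromUses = useList[from];
      fromUses.erase(std::find(fromUses.begin(), fromUses.end(), user));
      useList[to].push_back(user);
    }
  }
};

// Hypothetical value ids for the example.
constexpr int kX = 0;      // plays the role of Input (%x_arg)
constexpr int kXCopy = 1;  // plays the role of InputCopy
constexpr int kVal = 2;    // the float being stored

// `call @callee_(%x, %x)` followed by `store %val, %x`.
ToyFunction makeExample() {
  ToyFunction f;
  f.instrs = {{"call", {kX, kX}}, {"store", {kVal, kX}}};
  f.useList = {{0, 0, 1}, {}, {1}};
  return f;
}

// Stand-in for llvm::SetVector: first-seen order, duplicates dropped, so a
// user that holds several uses of the value appears only once.
std::vector<int> collectUniqueUsers(const std::vector<int> &uses) {
  std::vector<int> users;
  for (int u : uses)
    if (std::find(users.begin(), users.end(), u) == users.end())
      users.push_back(u);
  return users;
}

// Fixed pattern: snapshot the users first, then rewrite. Mutating the use
// list no longer changes which users the loop visits.
void rewriteFromSnapshot(ToyFunction &f, int from, int to) {
  for (int user : collectUniqueUsers(f.useList[from]))
    f.replaceUsesOfWith(user, from, to);
}
```

After `rewriteFromSnapshot(f, kX, kXCopy)`, `kX` has no remaining uses, and both the `call` and the `store` refer to `kXCopy`. In this model, deduplication is a convenience rather than the fix itself. A second `replaceUsesOfWith` on the already rewritten `call` would find nothing to replace, so the set mainly ensures each instruction is visited once.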
https://github.com/llvm/llvm-project/pull/139064