[llvm] [mlir] [mlir][openmp] - Fix crash in OpenMPIRBuilder when converting to LLVMIR (PR #84611)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 8 23:44:52 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-openmp

Author: Pranav Bhandarkar (bhandarkar-pranav)

<details>
<summary>Changes</summary>

This patch fixes an issue in the conversion of the `omp` MLIR dialect to LLVMIR using the OpenMPIRBuilder seen during offloading.
The `finalize` method in OpenMPIRBuilder outlines regions previously marked for outlining (for example, when processing `omp.task` ops). To do so, it uses a data structure that holds the `BasicBlock`s that need to be outlined.
However, if these regions belong to host code in a device module, they will already have been removed when the `omp.declare_target` attribute is converted. As a result, the `finalize` method ends up accessing bad data and crashes.

Fixes https://github.com/llvm/llvm-project/issues/84606

---
Full diff: https://github.com/llvm/llvm-project/pull/84611.diff


4 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+7) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+12) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+2) 
- (added) mlir/test/Target/LLVMIR/openmp-task-target-device.mlir (+27) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 5bbaa8c208b8cd..54908d01bc006c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -465,6 +465,10 @@ class OpenMPIRBuilder {
 
   void setConfig(OpenMPIRBuilderConfig C) { Config = C; }
 
+  /// Remove all references or state about Func that OpenMPIRBuilder may
+  /// be keeping
+  void dropFunction(Function *Func);
+
   /// Finalize the underlying module, e.g., by outlining regions.
   /// \param Fn                    The function to be finalized. If not used,
   ///                              all functions are finalized.
@@ -1518,6 +1522,9 @@ class OpenMPIRBuilder {
   /// Add a new region that will be outlined later.
   void addOutlineInfo(OutlineInfo &&OI) { OutlineInfos.emplace_back(OI); }
 
+  /// Remove outlining information if it refers to a certain function
+  void removeFuncFromOutlineInfo(llvm::Function *Func);
+
   /// An ordered map of auto-generated variables to their unique names.
   /// It stores variables with the following names: 1) ".gomp_critical_user_" +
   /// <critical_section_name> + ".var" for "omp critical" directives; 2)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d65ed8c11d86cc..e634e221a96de7 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -488,6 +488,14 @@ void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
 // OpenMPIRBuilder
 //===----------------------------------------------------------------------===//
 
+void OpenMPIRBuilder::removeFuncFromOutlineInfo(llvm::Function *Func) {
+  OutlineInfos.erase(std::remove_if(OutlineInfos.begin(), OutlineInfos.end(),
+                                    [&](const OutlineInfo &OI) {
+                                      return OI.getFunction() == Func;
+                                    }),
+                     OutlineInfos.end());
+}
+
 void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                           IRBuilderBase &Builder,
                                           SmallVector<Value *> &ArgsVector) {
@@ -794,6 +802,10 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     createOffloadEntriesAndInfoMetadata(ErrorReportFn);
 }
 
+void OpenMPIRBuilder::dropFunction(Function *Func) {
+  removeFuncFromOutlineInfo(Func);
+}
+
 OpenMPIRBuilder::~OpenMPIRBuilder() {
   assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index bef227f2c583ed..ab836251e15eb0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2853,6 +2853,8 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
       if (declareType == omp::DeclareTargetDeviceType::host) {
         llvm::Function *llvmFunc =
             moduleTranslation.lookupFunction(funcOp.getName());
+
+        moduleTranslation.getOpenMPBuilder()->dropFunction(llvmFunc);
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
       }
diff --git a/mlir/test/Target/LLVMIR/openmp-task-target-device.mlir b/mlir/test/Target/LLVMIR/openmp-task-target-device.mlir
new file mode 100644
index 00000000000000..c167604bcd12c0
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-task-target-device.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// This tests the fix for https://github.com/llvm/llvm-project/issues/84606
+// We are only interested in ensuring that the -mlir-to-llmvir pass doesn't crash.
+// CHECK: {{.*}} = add i32 {{.*}}, 5
+module attributes {omp.is_target_device = true } {
+  llvm.func @_QQmain() attributes {fir.bindc_name = "main", omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>} {
+    %0 = llvm.mlir.constant(0 : i32) : i32
+    %1 = llvm.mlir.constant(1 : i64) : i64
+    %2 = llvm.alloca %1 x i32 {bindc_name = "a"} : (i64) -> !llvm.ptr<5>
+    %3 = llvm.addrspacecast %2 : !llvm.ptr<5> to !llvm.ptr
+    omp.task {
+      llvm.store %0, %3 : i32, !llvm.ptr
+      omp.terminator
+    }
+    %4 = omp.map_info var_ptr(%3 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "a"}
+    omp.target map_entries(%4 -> %arg0 : !llvm.ptr) {
+    ^bb0(%arg0: !llvm.ptr):
+      %5 = llvm.mlir.constant(5 : i32) : i32
+      %6 = llvm.load %arg0  : !llvm.ptr -> i32
+      %7 = llvm.add %6, %5  : i32
+      llvm.store %7, %arg0  : i32, !llvm.ptr
+      omp.terminator
+    }
+    llvm.return
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/84611


More information about the llvm-commits mailing list