[flang-commits] [flang] [mlir] [MLIR][OpenMP] Changes to function-filtering pass (PR #71850)

via flang-commits flang-commits at lists.llvm.org
Thu Nov 9 11:14:04 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-driver
@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-mlir-llvm

Author: Akash Banerjee (TIFitis)

<details>
<summary>Changes</summary>

Currently, when deleting the device functions in the second stage of filtering during MLIR to LLVM translation, we can end up with invalid calls to these functions. This is because of the removal of the EarlyOutliningPass, which would have otherwise gotten rid of any such calls.

This patch aims to alter the function filtering pass in the following way:
	- Any host function is completely removed.
	- Calls to the host function are also removed and their uses replaced with Undef values.
	- Any host function with target region code is marked to be removed during the second stage.
	- Calls to such functions are still removed and their uses replaced with Undef values.

Coauthored by: skatrak(Sergio.Afonso@<!-- -->amd.com)

---
Full diff: https://github.com/llvm/llvm-project/pull/71850.diff


6 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp (+25-12) 
- (added) flang/test/Driver/OpenMP/target-filtering.f90 (+37) 
- (modified) flang/test/Lower/OpenMP/FIR/array-bounds.f90 (+2-2) 
- (modified) flang/test/Lower/OpenMP/function-filtering.f90 (+1-2) 
- (modified) flang/test/Transforms/omp-function-filtering.mlir (+2-4) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (-2) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
index 43fa5b7c4de2414..9e235c2d250fe93 100644
--- a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
+++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
@@ -33,6 +33,8 @@ class OMPFunctionFilteringPass
   OMPFunctionFilteringPass() = default;
 
   void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OpBuilder opBuilder(context);
     auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
     if (!op || !op.getIsTargetDevice())
       return;
@@ -46,8 +48,6 @@ class OMPFunctionFilteringPass
               ->walk<WalkOrder::PreOrder>(
                   [&](omp::TargetOp) { return WalkResult::interrupt(); })
               .wasInterrupted();
-      if (hasTargetRegion)
-        return;
 
       omp::DeclareTargetDeviceType declareType =
           omp::DeclareTargetDeviceType::host;
@@ -56,17 +56,30 @@ class OMPFunctionFilteringPass
       if (declareTargetOp && declareTargetOp.isDeclareTarget())
         declareType = declareTargetOp.getDeclareTargetDeviceType();
 
-      // Filtering a function here means removing its body and explicitly
-      // setting its omp.declare_target attribute, so that following
-      // translation/lowering/transformation passes will skip processing its
-      // contents, but preventing the calls to undefined symbols that could
-      // result if the function were deleted. The second stage of function
-      // filtering, at the MLIR to LLVM IR translation level, will remove these
-      // from the IR thanks to the mismatch between the omp.declare_target
-      // attribute and the target device.
+      // Filtering a function here means deleting it if it doesn't containt a
+      // target region. Else we explicitly setting its omp.declare_target
+      // attribute. The second stage of function filtering at the MLIR to LLVM
+      // IR translation level will remove functions that contain the target
+      // region from the generated llvm IR.
       if (declareType == omp::DeclareTargetDeviceType::host) {
-        funcOp.eraseBody();
-        funcOp.setVisibility(SymbolTable::Visibility::Private);
+        auto funcUses = *funcOp.getSymbolUses(op);
+        for (auto use : funcUses) {
+          auto *callOp = use.getUser();
+          // If the callOp has users then replace them with Undef values.
+          if (!callOp->use_empty()) {
+            SmallVector<Value> undefResults;
+            for (auto res : callOp->getResults()) {
+              opBuilder.setInsertionPoint(callOp);
+              undefResults.emplace_back(
+                  opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
+            }
+            callOp->replaceAllUsesWith(undefResults);
+          }
+          // Remove the callOp
+          callOp->erase();
+        }
+        if (!hasTargetRegion)
+          funcOp.erase();
         if (declareTargetOp)
           declareTargetOp.setDeclareTarget(declareType,
                                            omp::DeclareTargetCaptureClause::to);
diff --git a/flang/test/Driver/OpenMP/target-filtering.f90 b/flang/test/Driver/OpenMP/target-filtering.f90
new file mode 100644
index 000000000000000..6e1bd34975c2728
--- /dev/null
+++ b/flang/test/Driver/OpenMP/target-filtering.f90
@@ -0,0 +1,37 @@
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s --check-prefixes HOST
+!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes DEVICE
+
+!HOST: define {{.*}}@{{.*}}before{{.*}}(
+!DEVICE-NOT: define {{.*}}@before{{.*}}(
+!DEVICE-NOT: declare {{.*}}@before{{.*}}
+integer function before(x)
+integer, intent(in) :: x
+before = x + 200
+end function
+
+!HOST: define {{.*}}@{{.*}}main{{.*}}(
+!DEVICE: define {{.*}}@{{.*}}main{{.*}}(
+program main
+integer :: x, before, after
+!$omp target map(tofrom : x)
+   x = 100
+!$omp end target
+!HOST: call {{.*}}@{{.*}}before{{.*}}(
+!DEVICE-NOT: call {{.*}}@before{{.*}}(
+!HOST: call {{.*}}@{{.*}}after{{.*}}(
+!DEVICE-NOT: call {{.*}}@after{{.*}}(
+x = x + before(x) + after(x)
+WRITE(*,*) "Test:", x
+!$omp target map(tofrom : x)
+   x = 50
+!$omp end target
+WRITE(*,*) "Test:", x
+end program
+
+!HOST: define {{.*}}@{{.*}}after{{.*}}(
+!DEVICE-NOT: define {{.*}}@after{{.*}}(
+!DEVICE-NOT: declare {{.*}}@after{{.*}}
+integer function after(x)
+integer, intent(in) :: x
+after = x + 300
+end function
\ No newline at end of file
diff --git a/flang/test/Lower/OpenMP/FIR/array-bounds.f90 b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
index abef31af22ba663..deb2e063f9db714 100644
--- a/flang/test/Lower/OpenMP/FIR/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
@@ -32,7 +32,7 @@ end subroutine read_write_section
 module assumed_array_routines
 contains
 !ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
 !ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
 !ALL: %[[C0:.*]] = arith.constant 1 : index
 !ALL: %[[C1:.*]] = arith.constant 0 : index
@@ -56,7 +56,7 @@ subroutine assumed_shape_array(arr_read_write)
         end subroutine assumed_shape_array
 
 !ALL-LABEL:   func.func @_QMassumed_array_routinesPassumed_size_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
 !ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
 !ALL: %[[C0:.*]] = arith.constant 1 : index
 !ALL: %[[C1:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/function-filtering.f90 b/flang/test/Lower/OpenMP/function-filtering.f90
index e550348e50692c5..11472bc1de957e3 100644
--- a/flang/test/Lower/OpenMP/function-filtering.f90
+++ b/flang/test/Lower/OpenMP/function-filtering.f90
@@ -21,8 +21,7 @@ end function device_fn
 
 ! MLIR-HOST: func.func @{{.*}}host_fn(
 ! MLIR-HOST: return
-! MLIR-DEVICE: func.func private @{{.*}}host_fn(
-! MLIR-DEVICE-NOT: return
+! MLIR-DEVICE-NOT: func.func {{.*}}host_fn(
 
 ! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
 ! LLVM-DEVICE-NOT: {{.*}} @{{.*}}host_fn{{.*}}(
diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir
index 44777e2cac30c50..b831968cb7dc6fd 100644
--- a/flang/test/Transforms/omp-function-filtering.mlir
+++ b/flang/test/Transforms/omp-function-filtering.mlir
@@ -4,10 +4,8 @@
 // CHECK: return
 // CHECK: func.func @nohost
 // CHECK: return
-// CHECK: func.func private @host
-// CHECK-NOT: return
-// CHECK: func.func private @none
-// CHECK-NOT: return
+// CHECK-NOT: func.func {{.*}}}} @host
+// CHECK-NOT: func.func {{.*}}}} @none
 // CHECK: func.func @nohost_target
 // CHECK: return
 // CHECK: func.func @host_target
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d9ab785a082835d..4c1dc6603af1599 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2496,8 +2496,6 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
       if (declareType == omp::DeclareTargetDeviceType::host) {
         llvm::Function *llvmFunc =
             moduleTranslation.lookupFunction(funcOp.getName());
-        llvmFunc->replaceAllUsesWith(
-            llvm::UndefValue::get(llvmFunc->getType()));
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
       }

``````````

</details>


https://github.com/llvm/llvm-project/pull/71850


More information about the flang-commits mailing list