[flang-commits] [mlir] [flang] [MLIR][OpenMP] Changes to function-filtering pass (PR #71850)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Mon Nov 13 03:24:28 PST 2023


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/71850

>From 7b2149fb36f30e689ddb1de452ff1df820e69749 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 9 Nov 2023 18:22:49 +0000
Subject: [PATCH] [MLIR][OpenMP] Changes to function-filtering pass

Currently, when the second stage of filtering deletes host functions from the device module during MLIR to LLVM IR translation, we can end up with invalid calls to these functions. This is because the EarlyOutliningPass, which would otherwise have gotten rid of any such calls, has been removed.

This patch aims to alter the function filtering pass in the following way:
	- Any host function without target region code is completely removed.
	- Calls to such host functions are also removed, and the uses of their results are replaced with undef values.
	- Any host function with target region code is marked to be removed during the second stage.
	- Calls to such functions are still removed, and the uses of their results are replaced with undef values.

Co-authored-by: Sergio Afonso
<sergio.afonsofumero at amd.com>
---
 .../flang/Optimizer/Transforms/Passes.td      |  3 +-
 .../Transforms/OMPFunctionFiltering.cpp       | 40 ++++++++----
 flang/test/Driver/OpenMP/target-filtering.f90 | 61 +++++++++++++++++++
 flang/test/Lower/OpenMP/FIR/array-bounds.f90  |  4 +-
 .../test/Lower/OpenMP/function-filtering.f90  |  8 +--
 .../Transforms/omp-function-filtering.mlir    | 24 ++++++--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  2 -
 7 files changed, 113 insertions(+), 29 deletions(-)
 create mode 100644 flang/test/Driver/OpenMP/target-filtering.f90

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 6e23b87b7e276e9..c3768fd2d689c1a 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -330,7 +330,8 @@ def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
                 "for the target device.";
   let constructor = "::fir::createOMPFunctionFilteringPass()";
   let dependentDialects = [
-    "mlir::func::FuncDialect"
+    "mlir::func::FuncDialect",
+    "fir::FIROpsDialect"
   ];
 }
 
diff --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
index 43fa5b7c4de2414..e1ae4b452defab0 100644
--- a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
+++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -33,6 +34,8 @@ class OMPFunctionFilteringPass
   OMPFunctionFilteringPass() = default;
 
   void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OpBuilder opBuilder(context);
     auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
     if (!op || !op.getIsTargetDevice())
       return;
@@ -46,8 +49,6 @@ class OMPFunctionFilteringPass
               ->walk<WalkOrder::PreOrder>(
                   [&](omp::TargetOp) { return WalkResult::interrupt(); })
               .wasInterrupted();
-      if (hasTargetRegion)
-        return;
 
       omp::DeclareTargetDeviceType declareType =
           omp::DeclareTargetDeviceType::host;
@@ -56,18 +57,31 @@ class OMPFunctionFilteringPass
       if (declareTargetOp && declareTargetOp.isDeclareTarget())
         declareType = declareTargetOp.getDeclareTargetDeviceType();
 
-      // Filtering a function here means removing its body and explicitly
-      // setting its omp.declare_target attribute, so that following
-      // translation/lowering/transformation passes will skip processing its
-      // contents, but preventing the calls to undefined symbols that could
-      // result if the function were deleted. The second stage of function
-      // filtering, at the MLIR to LLVM IR translation level, will remove these
-      // from the IR thanks to the mismatch between the omp.declare_target
-      // attribute and the target device.
+      // Filtering a function here means removing all calls to it, replacing
+      // the results of those calls with undef values, and deleting it if it
+      // doesn't contain a target region. Otherwise, its omp.declare_target
+      // attribute (if any) is set explicitly, so that the second stage of
+      // filtering at MLIR to LLVM IR translation removes it from the LLVM IR.
       if (declareType == omp::DeclareTargetDeviceType::host) {
-        funcOp.eraseBody();
-        funcOp.setVisibility(SymbolTable::Visibility::Private);
-        if (declareTargetOp)
+        SymbolTable::UseRange funcUses = *funcOp.getSymbolUses(op);
+        for (SymbolTable::SymbolUse use : funcUses) {
+          Operation *callOp = use.getUser();
+          // If the callOp has users then replace them with Undef values.
+          if (!callOp->use_empty()) {
+            SmallVector<Value> undefResults;
+            for (Value res : callOp->getResults()) {
+              opBuilder.setInsertionPoint(callOp);
+              undefResults.emplace_back(
+                  opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
+            }
+            callOp->replaceAllUsesWith(undefResults);
+          }
+          // Remove the callOp
+          callOp->erase();
+        }
+        if (!hasTargetRegion)
+          funcOp.erase();
+        else if (declareTargetOp)
           declareTargetOp.setDeclareTarget(declareType,
                                            omp::DeclareTargetCaptureClause::to);
       }
diff --git a/flang/test/Driver/OpenMP/target-filtering.f90 b/flang/test/Driver/OpenMP/target-filtering.f90
new file mode 100644
index 000000000000000..5db5ade6c119bf5
--- /dev/null
+++ b/flang/test/Driver/OpenMP/target-filtering.f90
@@ -0,0 +1,61 @@
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s --check-prefixes HOST,ALL
+!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes DEVICE,ALL
+
+!HOST: define {{.*}}@{{.*}}before{{.*}}(
+!DEVICE-NOT: define {{.*}}@{{.*}}before(
+!DEVICE-NOT: declare {{.*}}@{{.*}}before(
+integer function before(x)
+   integer, intent(in) :: x
+   before = x + 200
+end function
+
+!ALL: define {{.*}}@{{.*}}main{{.*}}(
+program main
+   integer :: x, before, after
+   !$omp target map(tofrom : x)
+      x = 100
+   !$omp end target
+   !HOST: call {{.*}}@{{.*}}before{{.*}}(
+   !DEVICE-NOT: call {{.*}}@{{.*}}before(
+   !HOST: call {{.*}}@{{.*}}after{{.*}}(
+   !DEVICE-NOT: call {{.*}}@{{.*}}after(
+   x = x + before(x) + after(x)
+end program
+
+!HOST: define {{.*}}@{{.*}}after{{.*}}(
+!DEVICE-NOT: define {{.*}}@{{.*}}after(
+!DEVICE-NOT: declare {{.*}}@{{.*}}after(
+integer function after(x)
+   integer, intent(in) :: x
+   after = x + 300
+end function
+
+!ALL: define {{.*}}@{{.*}}before_target{{.*}}(
+subroutine before_target(x)
+   integer, intent(out) :: x
+   !$omp target map(from: x)
+      x = 1
+   !$omp end target
+end subroutine
+
+!ALL: define {{.*}}@{{.*}}middle{{.*}}(
+subroutine middle()
+   integer :: x
+   !$omp target map(from: x)
+      x = 0
+   !$omp end target
+   !HOST: call {{.*}}@{{.*}}before_target{{.*}}(
+   !DEVICE-NOT: call {{.*}}@{{.*}}before_target{{.*}}(
+   !HOST: call {{.*}}@{{.*}}after_target{{.*}}(
+   !DEVICE-NOT: call {{.*}}@{{.*}}after_target{{.*}}(
+   call before_target(x)
+   call after_target(x)
+end subroutine
+
+!ALL: define {{.*}}@{{.*}}after_target{{.*}}(
+subroutine after_target(x)
+   integer, intent(out) :: x
+   !$omp target map(from:x)
+      x = 2
+   !$omp end target
+end subroutine
diff --git a/flang/test/Lower/OpenMP/FIR/array-bounds.f90 b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
index 7e6ac02aefe6057..02b5ebcee022675 100644
--- a/flang/test/Lower/OpenMP/FIR/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
@@ -32,7 +32,7 @@ end subroutine read_write_section
 module assumed_array_routines
 contains
 !ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
 !ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
 !ALL: %[[C0:.*]] = arith.constant 1 : index
 !ALL: %[[C1:.*]] = arith.constant 0 : index
@@ -56,7 +56,7 @@ subroutine assumed_shape_array(arr_read_write)
         end subroutine assumed_shape_array
 
 !ALL-LABEL:   func.func @_QMassumed_array_routinesPassumed_size_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
 !ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
 !ALL: %[[C0:.*]] = arith.constant 1 : index
 !ALL: %[[C1:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/function-filtering.f90 b/flang/test/Lower/OpenMP/function-filtering.f90
index e550348e50692c5..45d8c2e2533d07a 100644
--- a/flang/test/Lower/OpenMP/function-filtering.f90
+++ b/flang/test/Lower/OpenMP/function-filtering.f90
@@ -21,8 +21,7 @@ end function device_fn
 
 ! MLIR-HOST: func.func @{{.*}}host_fn(
 ! MLIR-HOST: return
-! MLIR-DEVICE: func.func private @{{.*}}host_fn(
-! MLIR-DEVICE-NOT: return
+! MLIR-DEVICE-NOT: func.func {{.*}}host_fn(
 
 ! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
 ! LLVM-DEVICE-NOT: {{.*}} @{{.*}}host_fn{{.*}}(
@@ -32,9 +31,8 @@ function host_fn() result(x)
   x = 10
 end function host_fn
 
-! MLIR-HOST: func.func @{{.*}}target_subr(
-! MLIR-HOST: return
-! MLIR-DEVICE: return
+! MLIR-ALL: func.func @{{.*}}target_subr(
+! MLIR-ALL: return
 
 ! LLVM-HOST: define {{.*}} @{{.*}}target_subr{{.*}}(
 ! LLVM-ALL: define {{.*}} @__omp_offloading_{{.*}}_{{.*}}_target_subr__{{.*}}(
diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir
index 44777e2cac30c50..46291e9321f76a2 100644
--- a/flang/test/Transforms/omp-function-filtering.mlir
+++ b/flang/test/Transforms/omp-function-filtering.mlir
@@ -4,16 +4,18 @@
 // CHECK: return
 // CHECK: func.func @nohost
 // CHECK: return
-// CHECK: func.func private @host
-// CHECK-NOT: return
-// CHECK: func.func private @none
-// CHECK-NOT: return
+// CHECK-NOT: func.func {{.*}}@host
+// CHECK-NOT: func.func {{.*}}@none
 // CHECK: func.func @nohost_target
 // CHECK: return
 // CHECK: func.func @host_target
 // CHECK: return
 // CHECK: func.func @none_target
 // CHECK: return
+// CHECK: func.func @host_target_call
+// CHECK-NOT: call @none_target
+// CHECK: %[[UNDEF:.*]] = fir.undefined i32
+// CHECK: return %[[UNDEF]] : i32
 module attributes {omp.is_target_device = true} {
   func.func @any() -> ()
       attributes {
@@ -55,9 +57,19 @@ module attributes {omp.is_target_device = true} {
     omp.target {}
     func.return
   }
-  func.func @none_target() -> () {
+  func.func @none_target() -> i32 {
     omp.target {}
-    func.return
+    %0 = arith.constant 25 : i32
+    func.return %0 : i32
+  }
+  func.func @host_target_call() -> i32
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    %0 = call @none_target() : () -> i32
+    func.return %0 : i32
   }
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d9ab785a082835d..4c1dc6603af1599 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2496,8 +2496,6 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
       if (declareType == omp::DeclareTargetDeviceType::host) {
         llvm::Function *llvmFunc =
             moduleTranslation.lookupFunction(funcOp.getName());
-        llvmFunc->replaceAllUsesWith(
-            llvm::UndefValue::get(llvmFunc->getType()));
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
       }



More information about the flang-commits mailing list