[Mlir-commits] [mlir] debdfc0 - [Flang][OpenMP][MLIR] Filter emitted code depending on declare target and device

Sergio Afonso llvmlistbot at llvm.org
Mon Jul 17 01:08:17 PDT 2023


Author: Sergio Afonso
Date: 2023-07-17T09:07:54+01:00
New Revision: debdfc0ae21befa4124c227a916fc69913a4f146

URL: https://github.com/llvm/llvm-project/commit/debdfc0ae21befa4124c227a916fc69913a4f146
DIFF: https://github.com/llvm/llvm-project/commit/debdfc0ae21befa4124c227a916fc69913a4f146.diff

LOG: [Flang][OpenMP][MLIR] Filter emitted code depending on declare target and device

This patch adds support for selecting which functions are lowered to LLVM IR
from MLIR depending on declare target information and whether host or device
code is being generated.

The approach proposed by this patch is to perform the filtering in two stages:
  - An MLIR transformation pass, which is added to the Flang translation flow
    after the `OMPEarlyOutliningPass`. The functions that are kept are those
    that match the OpenMP processor (host or device) the compiler invocation
    is targeting, according to the presence of the `-fopenmp-is-target-device`
    compiler option and declare target information. All functions containing an
    `omp.target` are also kept, regardless of the declare target information of
    the function, due to the need for keeping target regions visible for both
    host and device compilation.
  - A filtering step during translation to LLVM IR, which is performed for those
    functions that were kept because of the presence of a target region inside.
    If the targeted OpenMP processor does not match the declare target
    information of the function, then it is removed from the LLVM IR after its
    contents have been processed and translated. Since they should only contain
    an omp.target operation which, in turn, should have been outlined into
    another LLVM IR function, the wrapper can be deleted at that point.

Depends on D150328 and D150329.

Differential Revision: https://reviews.llvm.org/D147641

Added: 
    flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
    flang/test/Lower/OpenMP/function-filtering.f90
    flang/test/Transforms/omp-function-filtering.mlir

Modified: 
    flang/include/flang/Optimizer/Transforms/Passes.h
    flang/include/flang/Optimizer/Transforms/Passes.td
    flang/lib/Frontend/FrontendActions.cpp
    flang/lib/Optimizer/Transforms/CMakeLists.txt
    flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90
    flang/test/Lower/OpenMP/omp-declare-target-program-var.f90
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
    mlir/test/Target/LLVMIR/openmp-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.h b/flang/include/flang/Optimizer/Transforms/Passes.h
index 8d150462f1f2de..3272cb3f609dc7 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.h
+++ b/flang/include/flang/Optimizer/Transforms/Passes.h
@@ -73,8 +73,11 @@ std::unique_ptr<mlir::Pass> createAlgebraicSimplificationPass();
 std::unique_ptr<mlir::Pass>
 createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
 std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
+
 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
 createOMPEarlyOutliningPass();
+std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
+
 // declarative passes
 #define GEN_PASS_REGISTRATION
 #include "flang/Optimizer/Transforms/Passes.h.inc"

diff  --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index e6ecdede294842..40a08c9959abbb 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -311,4 +311,13 @@ def OMPEarlyOutliningPass
   let dependentDialects = ["mlir::omp::OpenMPDialect"];
 }
 
+def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
+  let summary = "Filters out functions intended for the host when compiling "
+                "for the device and vice versa.";
+  let constructor = "::fir::createOMPFunctionFilteringPass()";
+  let dependentDialects = [
+    "mlir::func::FuncDialect"
+  ];
+}
+
 #endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

diff  --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 173c25789f720c..c03c3fd8e9c129 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -312,6 +312,7 @@ bool CodeGenAction::beginSourceFileAction() {
 
     if (isDevice)
       pm.addPass(fir::createOMPEarlyOutliningPass());
+    pm.addPass(fir::createOMPFunctionFilteringPass());
   }
 
   pm.enableVerifier(/*verifyPasses=*/true);

diff  --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index bd4aee363906e9..18085422b1c46b 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -17,6 +17,7 @@ add_flang_library(FIRTransforms
   PolymorphicOpConversion.cpp
   LoopVersioning.cpp
   OMPEarlyOutlining.cpp
+  OMPFunctionFiltering.cpp
 
   DEPENDS
   FIRDialect

diff  --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
new file mode 100644
index 00000000000000..7784c90e930000
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
@@ -0,0 +1,73 @@
+//===- OMPFunctionFiltering.cpp -------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to filter out functions intended for the host
+// when compiling for the device and vice versa.
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace fir {
+#define GEN_PASS_DEF_OMPFUNCTIONFILTERING
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace fir;
+using namespace mlir;
+
+namespace {
+class OMPFunctionFilteringPass
+    : public fir::impl::OMPFunctionFilteringBase<OMPFunctionFilteringPass> {
+public:
+  OMPFunctionFilteringPass() = default;
+  // Erase func.func ops that are not meant for the processor being targeted.
+  void runOnOperation() override {
+    auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
+    if (!op)
+      return;
+    // Whether the module is being compiled for the offload device.
+    bool isDeviceCompilation = op.getIsTargetDevice();
+    op->walk<WalkOrder::PostOrder>([&](func::FuncOp funcOp) {
+      // Keep functions containing target regions: they must reach both host and
+      // device translation so regular and reverse offloading work. Mismatching
+      // ones are dropped later, during translation to LLVM IR.
+      bool hasTargetRegion =
+          funcOp
+              ->walk<WalkOrder::PreOrder>(
+                  [&](omp::TargetOp) { return WalkResult::interrupt(); })
+              .wasInterrupted();
+      if (hasTargetRegion)
+        return;
+      // Functions without omp.declare_target are treated as host-only.
+      omp::DeclareTargetDeviceType declareType =
+          omp::DeclareTargetDeviceType::host;
+      auto declareTargetOp =
+          dyn_cast<omp::DeclareTargetInterface>(funcOp.getOperation());
+      if (declareTargetOp && declareTargetOp.isDeclareTarget())
+        declareType = declareTargetOp.getDeclareTargetDeviceType();
+      // NOTE(review): symbol uses of an erased function are not checked.
+      if ((isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::host) ||
+          (!isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::nohost))
+        funcOp->erase();
+    });
+  }
+};
+} // namespace
+// Creates the pass; referenced as the constructor in Passes.td.
+std::unique_ptr<Pass> fir::createOMPFunctionFilteringPass() {
+  return std::make_unique<OMPFunctionFilteringPass>();
+}

diff  --git a/flang/test/Lower/OpenMP/function-filtering.f90 b/flang/test/Lower/OpenMP/function-filtering.f90
new file mode 100644
index 00000000000000..4386cb43c144f6
--- /dev/null
+++ b/flang/test/Lower/OpenMP/function-filtering.f90
@@ -0,0 +1,44 @@
+! RUN: %flang_fc1 -fopenmp -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-HOST,LLVM-ALL %s
+! RUN: %flang_fc1 -fopenmp -emit-mlir %s -o - | FileCheck --check-prefix=MLIR-HOST %s
+! RUN: %flang_fc1 -fopenmp -fopenmp-is-target-device -emit-llvm %s -o - | FileCheck --check-prefixes=LLVM-DEVICE,LLVM-ALL %s
+! RUN: %flang_fc1 -fopenmp -fopenmp-is-target-device -emit-mlir %s -o - | FileCheck --check-prefix=MLIR-DEVICE %s
+
+! Check that the correct LLVM IR functions are kept for the host and device
+! after running the whole set of translation and transformation passes from
+! Fortran.
+
+! MLIR-HOST-NOT: func.func @{{.*}}device_fn(
+! MLIR-DEVICE: func.func @{{.*}}device_fn(
+! LLVM-HOST-NOT: define {{.*}} @{{.*}}device_fn{{.*}}(
+! LLVM-DEVICE: define {{.*}} @{{.*}}device_fn{{.*}}(
+function device_fn() result(x)
+  !$omp declare target to(device_fn) device_type(nohost)
+  integer :: x
+  x = 10
+end function device_fn
+
+! MLIR-HOST: func.func @{{.*}}host_fn(
+! MLIR-DEVICE-NOT: func.func @{{.*}}host_fn(
+! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
+! LLVM-DEVICE-NOT: define {{.*}} @{{.*}}host_fn{{.*}}(
+function host_fn() result(x)
+  !$omp declare target to(host_fn) device_type(host)
+  integer :: x
+  x = 10
+end function host_fn
+
+! MLIR-HOST: func.func @{{.*}}target_subr(
+! MLIR-HOST-NOT: func.func @{{.*}}target_subr_omp_outline_0(
+! MLIR-DEVICE-NOT: func.func @{{.*}}target_subr(
+! MLIR-DEVICE: func.func @{{.*}}target_subr_omp_outline_0(
+
+! LLVM-ALL-NOT: define {{.*}} @{{.*}}target_subr_omp_outline_0{{.*}}(
+! LLVM-HOST: define {{.*}} @{{.*}}target_subr{{.*}}(
+! LLVM-DEVICE-NOT: define {{.*}} @{{.*}}target_subr{{.*}}(
+! LLVM-ALL: define {{.*}} @__omp_offloading_{{.*}}_{{.*}}_target_subr__{{.*}}(
+subroutine target_subr(x)
+  integer, intent(out) :: x
+  !$omp target map(from:x)
+    x = 10
+  !$omp end target
+end subroutine target_subr

diff  --git a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90 b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90
index 6e197c59b211bd..26741c6795a6ce 100644
--- a/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-target-func-and-subr.f90
@@ -1,51 +1,52 @@
-!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes ALL,HOST
+!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-device %s -o - | FileCheck %s --check-prefixes ALL,DEVICE
 
 ! Check specification valid forms of declare target with functions 
 ! utilising device_type and to clauses as well as the default 
 ! zero clause declare target
 
-! CHECK-LABEL: func.func @_QPfunc_t_device()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+! DEVICE-LABEL: func.func @_QPfunc_t_device()
+! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_T_DEVICE() RESULT(I)
 !$omp declare target to(FUNC_T_DEVICE) device_type(nohost)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_T_DEVICE
 
-! CHECK-LABEL: func.func @_QPfunc_t_host()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
+! HOST-LABEL: func.func @_QPfunc_t_host()
+! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_T_HOST() RESULT(I)
 !$omp declare target to(FUNC_T_HOST) device_type(host)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_T_HOST
 
-! CHECK-LABEL: func.func @_QPfunc_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_T_ANY() RESULT(I)
 !$omp declare target to(FUNC_T_ANY) device_type(any)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_T_ANY
 
-! CHECK-LABEL: func.func @_QPfunc_default_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_default_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_DEFAULT_T_ANY() RESULT(I)
 !$omp declare target to(FUNC_DEFAULT_T_ANY)
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_DEFAULT_T_ANY
 
-! CHECK-LABEL: func.func @_QPfunc_default_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_default_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_DEFAULT_ANY() RESULT(I)
 !$omp declare target
     INTEGER :: I
     I = 1
 END FUNCTION FUNC_DEFAULT_ANY
 
-! CHECK-LABEL: func.func @_QPfunc_default_extendedlist()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPfunc_default_extendedlist()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 FUNCTION FUNC_DEFAULT_EXTENDEDLIST() RESULT(I)
 !$omp declare target(FUNC_DEFAULT_EXTENDEDLIST)
     INTEGER :: I
@@ -58,46 +59,46 @@ END FUNCTION FUNC_DEFAULT_EXTENDEDLIST
 ! utilising device_type and to clauses as well as the default 
 ! zero clause declare target
 
-! CHECK-LABEL: func.func @_QPsubr_t_device()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+! DEVICE-LABEL: func.func @_QPsubr_t_device()
+! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_T_DEVICE()
 !$omp declare target to(SUBR_T_DEVICE) device_type(nohost)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_t_host()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
+! HOST-LABEL: func.func @_QPsubr_t_host()
+! HOST-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (host), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_T_HOST()
 !$omp declare target to(SUBR_T_HOST) device_type(host)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_T_ANY()
 !$omp declare target to(SUBR_T_ANY) device_type(any)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_default_t_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_default_t_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_DEFAULT_T_ANY()
 !$omp declare target to(SUBR_DEFAULT_T_ANY)
 END
 
-! CHECK-LABEL: func.func @_QPsubr_default_any()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_default_any()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_DEFAULT_ANY()
 !$omp declare target
 END
 
-! CHECK-LABEL: func.func @_QPsubr_default_extendedlist()
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
+! ALL-LABEL: func.func @_QPsubr_default_extendedlist()
+! ALL-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>{{.*}}
 SUBROUTINE SUBR_DEFAULT_EXTENDEDLIST()
 !$omp declare target(SUBR_DEFAULT_EXTENDEDLIST)
 END
 
 !! -----
 
-! CHECK-LABEL: func.func @_QPrecursive_declare_target
-! CHECK-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
+! DEVICE-LABEL: func.func @_QPrecursive_declare_target
+! DEVICE-SAME: {{.*}}attributes {omp.declare_target = #omp.declaretarget<device_type = (nohost), capture_clause = (to)>{{.*}}
 RECURSIVE FUNCTION RECURSIVE_DECLARE_TARGET(INCREMENT) RESULT(K)
 !$omp declare target to(RECURSIVE_DECLARE_TARGET) device_type(nohost)
     INTEGER :: INCREMENT, K

diff  --git a/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90 b/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90
index ef39a985cbc3a5..0da76f6d9ad2ce 100644
--- a/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90
+++ b/flang/test/Lower/OpenMP/omp-declare-target-program-var.f90
@@ -1,12 +1,12 @@
-!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s 
-!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes=HOST,ALL
+!RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefix=ALL
 
 PROGRAM main
-    ! CHECK-DAG: %0 = fir.alloca f32 {bindc_name = "i", uniq_name = "_QFEi"}
+    ! HOST-DAG: %0 = fir.alloca f32 {bindc_name = "i", uniq_name = "_QFEi"}
     REAL :: I
-    ! CHECK-DAG: fir.global internal @_QFEi {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : f32 {
-    ! CHECK-DAG: %0 = fir.undefined f32
-    ! CHECK-DAG: fir.has_value %0 : f32
-    ! CHECK-DAG: }
+    ! ALL-DAG: fir.global internal @_QFEi {omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>} : f32 {
+    ! ALL-DAG: %0 = fir.undefined f32
+    ! ALL-DAG: fir.has_value %0 : f32
+    ! ALL-DAG: }
     !$omp declare target(I)
 END

diff  --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir
new file mode 100644
index 00000000000000..ccb11caf7c81f7
--- /dev/null
+++ b/flang/test/Transforms/omp-function-filtering.mlir
@@ -0,0 +1,111 @@
+// RUN: fir-opt -split-input-file --omp-function-filtering %s | FileCheck %s
+
+// CHECK:     func.func @any
+// CHECK:     func.func @nohost
+// CHECK-NOT: func.func @host
+// CHECK-NOT: func.func @none
+// CHECK:     func.func @nohost_target
+// CHECK:     func.func @host_target
+// CHECK:     func.func @none_target
+module attributes {omp.is_target_device = true} {
+  func.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @none() -> () {
+    func.return
+  }
+  func.func @nohost_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @host_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @none_target() -> () {
+    omp.target {}
+    func.return
+  }
+}
+
+// -----
+
+// CHECK:     func.func @any
+// CHECK-NOT: func.func @nohost
+// CHECK:     func.func @host
+// CHECK:     func.func @none
+// CHECK:     func.func @nohost_target
+// CHECK:     func.func @host_target
+// CHECK:     func.func @none_target
+module attributes {omp.is_target_device = false} {
+  func.func @any() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (any), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @nohost() -> ()
+      attributes {
+          omp.declare_target =
+            #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    func.return
+  }
+  func.func @none() -> () {
+    func.return
+  }
+  func.func @nohost_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @host_target() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    func.return
+  }
+  func.func @none_target() -> () {
+    omp.target {}
+    func.return
+  }
+}

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 49df49d85748c5..efcb9180f6098f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 #include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
 #include "mlir/IR/IRMapping.h"
@@ -1667,6 +1668,38 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   return bodyGenStatus;
 }
 
+static LogicalResult
+convertDeclareTargetAttr(Operation *op,
+                         omp::DeclareTargetAttr declareTargetAttr,
+                         LLVM::ModuleTranslation &moduleTranslation) {
+  // Amend omp.declare_target by deleting the IR of the outlined functions
+  // created for target regions. They cannot be filtered out from MLIR earlier
+  // because the omp.target operation inside must be translated to LLVM, but the
+  // wrapper functions themselves must not remain at the end of the process.
+  // We know that functions where omp.declare_target does not match
+  // omp.is_target_device at this stage can only be wrapper functions because
+  // those that aren't are removed earlier by an MLIR transformation pass.
+  if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
+    if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
+            op->getParentOfType<ModuleOp>().getOperation())) {
+      bool isDeviceCompilation = offloadMod.getIsTargetDevice();
+      omp::DeclareTargetDeviceType declareType =
+          declareTargetAttr.getDeviceType().getValue();
+      // NOTE(review): lookupFunction result and remaining uses are unchecked.
+      if ((isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::host) ||
+          (!isDeviceCompilation &&
+           declareType == omp::DeclareTargetDeviceType::nohost)) {
+        llvm::Function *llvmFunc =
+            moduleTranslation.lookupFunction(funcOp.getName());
+        llvmFunc->dropAllReferences();
+        llvmFunc->eraseFromParent();
+      }
+    }
+  }
+  return success();
+}
+
 namespace {
 
 /// Implementation of the dialect interface that converts operations belonging
@@ -1694,7 +1727,6 @@ class OpenMPDialectLLVMIRTranslationInterface
 LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
     Operation *op, NamedAttribute attribute,
     LLVM::ModuleTranslation &moduleTranslation) const {
-
   return llvm::TypeSwitch<Attribute, LogicalResult>(attribute.getValue())
       .Case([&](mlir::omp::FlagsAttr rtlAttr) {
         return convertFlagsAttr(op, rtlAttr, moduleTranslation);
@@ -1706,6 +1738,10 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
                                     versionAttr.getVersion());
         return success();
       })
+      .Case([&](mlir::omp::DeclareTargetAttr declareTargetAttr) {
+        return convertDeclareTargetAttr(op, declareTargetAttr,
+                                        moduleTranslation);
+      })
       .Default([&](Attribute attr) {
         // fall through for omp attributes that do not require lowering and/or
         // have no concrete definition and thus no type to define a case on

diff  --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
index a121538f1eebab..bee77bbea2cded 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
@@ -2,7 +2,7 @@
 // name stored in the omp.outline_parent_name attribute.
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
-module attributes {omp.is_device = true} {
+module attributes {omp.is_target_device = true} {
   llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr<i32>, %arg1: !llvm.ptr<i32>) attributes {omp.outline_parent_name = "writeindex_"} {
     omp.target   map((from -> %arg0 : !llvm.ptr<i32>), (implicit -> %arg1: !llvm.ptr<i32>)) {
       %0 = llvm.mlir.constant(20 : i32) : i32

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 89a4578459f10b..15eb0b353a3e8b 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2543,3 +2543,47 @@ module attributes {omp.flags = #omp.flags<debug_kind = 0, assume_teams_oversubsc
 // CHECK: @__omp_rtl_assume_no_thread_state = weak_odr hidden constant i32 1
 // CHECK: @__omp_rtl_assume_no_nested_parallelism = weak_odr hidden constant i32 0
 module attributes {omp.flags = #omp.flags<assume_teams_oversubscription = true, assume_no_thread_state = true>} {}
+
+// -----
+
+module attributes {omp.is_target_device = false} {
+  // CHECK-NOT: @filter_host_nohost
+  llvm.func @filter_host_nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+
+  // CHECK: @filter_host_host
+  llvm.func @filter_host_host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+}
+
+// -----
+
+module attributes {omp.is_target_device = true} {
+  // CHECK: @filter_device_nohost
+  llvm.func @filter_device_nohost() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (nohost), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+
+  // CHECK-NOT: @filter_device_host
+  llvm.func @filter_device_host() -> ()
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    llvm.return
+  }
+}


        


More information about the Mlir-commits mailing list