[Mlir-commits] [mlir] [flang] [MLIR][OpenMP] Changes to function-filtering pass (PR #71850)
Akash Banerjee
llvmlistbot at llvm.org
Fri Nov 10 08:43:35 PST 2023
https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/71850
>From b36fdd239462690cf3b2f882401626799812b0b7 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 9 Nov 2023 18:22:49 +0000
Subject: [PATCH] [MLIR][OpenMP] Changes to function-filtering pass
Currently, when host functions are deleted in the second stage of function filtering (during MLIR to LLVM IR translation for the target device), invalid calls to these functions can remain. This is because the EarlyOutliningPass, which would otherwise have removed any such calls, has been removed.
This patch alters the function filtering pass in the following ways:
- Any host function that does not contain a target region is completely removed.
- Calls to such host functions are also removed, and uses of their results are replaced with undefined (fir.undefined) values.
- Any host function that contains a target region is instead marked for removal during the second stage.
- Calls to these functions are likewise removed, and uses of their results are replaced with undefined values.
Coauthored by: @skatrak(Sergio.Afonso at amd.com)
---
.../Transforms/OMPFunctionFiltering.cpp | 46 ++++++++++----
flang/test/Driver/OpenMP/target-filtering.f90 | 61 +++++++++++++++++++
flang/test/Lower/OpenMP/FIR/array-bounds.f90 | 4 +-
.../test/Lower/OpenMP/function-filtering.f90 | 8 +--
.../Transforms/omp-function-filtering.mlir | 24 ++++++--
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 2 -
6 files changed, 117 insertions(+), 28 deletions(-)
create mode 100644 flang/test/Driver/OpenMP/target-filtering.f90
diff --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
index 43fa5b7c4de2414..6430e1438621956 100644
--- a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
+++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
@@ -11,6 +11,7 @@
//
//===----------------------------------------------------------------------===//
+#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -32,7 +33,15 @@ class OMPFunctionFilteringPass
public:
OMPFunctionFilteringPass() = default;
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ // Creating fir::UndefOp requires the FIR dialect (fir::FIROpsDialect) to be loaded.
+ registry.insert<fir::FIROpsDialect>();
+ }
+
void runOnOperation() override {
+ MLIRContext *context = &getContext();
+ context->getOrLoadDialect<fir::FIROpsDialect>();
+ OpBuilder opBuilder(context);
auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
if (!op || !op.getIsTargetDevice())
return;
@@ -46,8 +55,6 @@ class OMPFunctionFilteringPass
->walk<WalkOrder::PreOrder>(
[&](omp::TargetOp) { return WalkResult::interrupt(); })
.wasInterrupted();
- if (hasTargetRegion)
- return;
omp::DeclareTargetDeviceType declareType =
omp::DeclareTargetDeviceType::host;
@@ -56,18 +63,31 @@ class OMPFunctionFilteringPass
if (declareTargetOp && declareTargetOp.isDeclareTarget())
declareType = declareTargetOp.getDeclareTargetDeviceType();
- // Filtering a function here means removing its body and explicitly
- // setting its omp.declare_target attribute, so that following
- // translation/lowering/transformation passes will skip processing its
- // contents, but preventing the calls to undefined symbols that could
- // result if the function were deleted. The second stage of function
- // filtering, at the MLIR to LLVM IR translation level, will remove these
- // from the IR thanks to the mismatch between the omp.declare_target
- // attribute and the target device.
+ // Filtering a function here means deleting it if it doesn't contain a
+ // target region. Otherwise, we explicitly set its omp.declare_target
+ // attribute. The second stage of function filtering, at the MLIR to LLVM
+ // IR translation level, will remove functions that contain a target
+ // region from the generated LLVM IR.
if (declareType == omp::DeclareTargetDeviceType::host) {
- funcOp.eraseBody();
- funcOp.setVisibility(SymbolTable::Visibility::Private);
- if (declareTargetOp)
+ SymbolTable::UseRange funcUses = *funcOp.getSymbolUses(op);
+ for (SymbolTable::SymbolUse use : funcUses) {
+ Operation *callOp = use.getUser();
+ // If the call's results have uses, replace them with fir.undefined values.
+ if (!callOp->use_empty()) {
+ SmallVector<Value> undefResults;
+ for (Value res : callOp->getResults()) {
+ opBuilder.setInsertionPoint(callOp);
+ undefResults.emplace_back(
+ opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
+ }
+ callOp->replaceAllUsesWith(undefResults);
+ }
+ // Remove the callOp
+ callOp->erase();
+ }
+ if (!hasTargetRegion)
+ funcOp.erase();
+ else if (declareTargetOp)
declareTargetOp.setDeclareTarget(declareType,
omp::DeclareTargetCaptureClause::to);
}
diff --git a/flang/test/Driver/OpenMP/target-filtering.f90 b/flang/test/Driver/OpenMP/target-filtering.f90
new file mode 100644
index 000000000000000..5db5ade6c119bf5
--- /dev/null
+++ b/flang/test/Driver/OpenMP/target-filtering.f90
@@ -0,0 +1,61 @@
+!RUN: %flang_fc1 -emit-llvm -fopenmp %s -o - | FileCheck %s --check-prefixes HOST,ALL
+!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes DEVICE,ALL
+
+!HOST: define {{.*}}@{{.*}}before{{.*}}(
+!DEVICE-NOT: define {{.*}}@{{.*}}before(
+!DEVICE-NOT: declare {{.*}}@{{.*}}before(
+integer function before(x)
+ integer, intent(in) :: x
+ before = x + 200
+end function
+
+!ALL: define {{.*}}@{{.*}}main{{.*}}(
+program main
+ integer :: x, before, after
+ !$omp target map(tofrom : x)
+ x = 100
+ !$omp end target
+ !HOST: call {{.*}}@{{.*}}before{{.*}}(
+ !DEVICE-NOT: call {{.*}}@{{.*}}before(
+ !HOST: call {{.*}}@{{.*}}after{{.*}}(
+ !DEVICE-NOT: call {{.*}}@{{.*}}after(
+ x = x + before(x) + after(x)
+end program
+
+!HOST: define {{.*}}@{{.*}}after{{.*}}(
+!DEVICE-NOT: define {{.*}}@{{.*}}after(
+!DEVICE-NOT: declare {{.*}}@{{.*}}after(
+integer function after(x)
+ integer, intent(in) :: x
+ after = x + 300
+end function
+
+!ALL: define {{.*}}@{{.*}}before_target{{.*}}(
+subroutine before_target(x)
+ integer, intent(out) :: x
+ !$omp target map(from: x)
+ x = 1
+ !$omp end target
+end subroutine
+
+!ALL: define {{.*}}@{{.*}}middle{{.*}}(
+subroutine middle()
+ integer :: x
+ !$omp target map(from: x)
+ x = 0
+ !$omp end target
+ !HOST: call {{.*}}@{{.*}}before_target{{.*}}(
+ !DEVICE-NOT: call {{.*}}@{{.*}}before_target{{.*}}(
+ !HOST: call {{.*}}@{{.*}}after_target{{.*}}(
+ !DEVICE-NOT: call {{.*}}@{{.*}}after_target{{.*}}(
+ call before_target(x)
+ call after_target(x)
+end subroutine
+
+!ALL: define {{.*}}@{{.*}}after_target{{.*}}(
+subroutine after_target(x)
+ integer, intent(out) :: x
+ !$omp target map(from:x)
+ x = 2
+ !$omp end target
+end subroutine
diff --git a/flang/test/Lower/OpenMP/FIR/array-bounds.f90 b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
index abef31af22ba663..deb2e063f9db714 100644
--- a/flang/test/Lower/OpenMP/FIR/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
@@ -32,7 +32,7 @@ end subroutine read_write_section
module assumed_array_routines
contains
!ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
!ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
!ALL: %[[C0:.*]] = arith.constant 1 : index
!ALL: %[[C1:.*]] = arith.constant 0 : index
@@ -56,7 +56,7 @@ subroutine assumed_shape_array(arr_read_write)
end subroutine assumed_shape_array
!ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_size_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
!ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
!ALL: %[[C0:.*]] = arith.constant 1 : index
!ALL: %[[C1:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/function-filtering.f90 b/flang/test/Lower/OpenMP/function-filtering.f90
index e550348e50692c5..45d8c2e2533d07a 100644
--- a/flang/test/Lower/OpenMP/function-filtering.f90
+++ b/flang/test/Lower/OpenMP/function-filtering.f90
@@ -21,8 +21,7 @@ end function device_fn
! MLIR-HOST: func.func @{{.*}}host_fn(
! MLIR-HOST: return
-! MLIR-DEVICE: func.func private @{{.*}}host_fn(
-! MLIR-DEVICE-NOT: return
+! MLIR-DEVICE-NOT: func.func {{.*}}host_fn(
! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
! LLVM-DEVICE-NOT: {{.*}} @{{.*}}host_fn{{.*}}(
@@ -32,9 +31,8 @@ function host_fn() result(x)
x = 10
end function host_fn
-! MLIR-HOST: func.func @{{.*}}target_subr(
-! MLIR-HOST: return
-! MLIR-DEVICE: return
+! MLIR-ALL: func.func @{{.*}}target_subr(
+! MLIR-ALL: return
! LLVM-HOST: define {{.*}} @{{.*}}target_subr{{.*}}(
! LLVM-ALL: define {{.*}} @__omp_offloading_{{.*}}_{{.*}}_target_subr__{{.*}}(
diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir
index 44777e2cac30c50..46291e9321f76a2 100644
--- a/flang/test/Transforms/omp-function-filtering.mlir
+++ b/flang/test/Transforms/omp-function-filtering.mlir
@@ -4,16 +4,18 @@
// CHECK: return
// CHECK: func.func @nohost
// CHECK: return
-// CHECK: func.func private @host
-// CHECK-NOT: return
-// CHECK: func.func private @none
-// CHECK-NOT: return
+// CHECK-NOT: func.func {{.*}}@host(
+// CHECK-NOT: func.func {{.*}}@none(
// CHECK: func.func @nohost_target
// CHECK: return
// CHECK: func.func @host_target
// CHECK: return
// CHECK: func.func @none_target
// CHECK: return
+// CHECK: func.func @host_target_call
+// CHECK-NOT: call @none_target
+// CHECK: %[[UNDEF:.*]] = fir.undefined i32
+// CHECK: return %[[UNDEF]] : i32
module attributes {omp.is_target_device = true} {
func.func @any() -> ()
attributes {
@@ -55,9 +57,19 @@ module attributes {omp.is_target_device = true} {
omp.target {}
func.return
}
- func.func @none_target() -> () {
+ func.func @none_target() -> i32 {
omp.target {}
- func.return
+ %0 = arith.constant 25 : i32
+ func.return %0 : i32
+ }
+ func.func @host_target_call() -> i32
+ attributes {
+ omp.declare_target =
+ #omp.declaretarget<device_type = (host), capture_clause = (to)>
+ } {
+ omp.target {}
+ %0 = call @none_target() : () -> i32
+ func.return %0 : i32
}
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d9ab785a082835d..4c1dc6603af1599 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2496,8 +2496,6 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
if (declareType == omp::DeclareTargetDeviceType::host) {
llvm::Function *llvmFunc =
moduleTranslation.lookupFunction(funcOp.getName());
- llvmFunc->replaceAllUsesWith(
- llvm::UndefValue::get(llvmFunc->getType()));
llvmFunc->dropAllReferences();
llvmFunc->eraseFromParent();
}
More information about the Mlir-commits
mailing list