[flang-commits] [mlir] [flang] [MLIR][OpenMP] Changes to function-filtering pass (PR #71850)

Akash Banerjee via flang-commits flang-commits at lists.llvm.org
Tue Nov 14 03:52:12 PST 2023


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/71850

>From 71458a4be43cc91649b3f6aa7879e138816281dd Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Thu, 9 Nov 2023 18:22:49 +0000
Subject: [PATCH] [MLIR][OpenMP] Changes to function-filtering pass

Currently, when deleting the device functions in the second stage of filtering during MLIR to LLVM translation, we can end up with invalid calls to these functions. This is because of the removal of the EarlyOutliningPass, which would otherwise have gotten rid of any such calls.

This patch aims to alter the function filtering pass in the following way:
	- Any host function is completely removed.
	- Calls to the host function are also removed and their uses replaced with Undef values.
	- Any host function with target region code is marked to be removed during the second stage.
	- Calls to such functions are still removed and their uses replaced with Undef values.

Co-authored-by: Sergio Afonso <sergio.afonsofumero at amd.com>
---
 .../flang/Optimizer/Transforms/Passes.td      |  3 +-
 .../Transforms/OMPFunctionFiltering.cpp       | 40 +++++++++++++------
 flang/test/Lower/OpenMP/FIR/array-bounds.f90  |  4 +-
 .../test/Lower/OpenMP/function-filtering.f90  |  8 ++--
 .../Transforms/omp-function-filtering.mlir    | 24 ++++++++---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  2 -
 6 files changed, 52 insertions(+), 29 deletions(-)

diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 6e23b87b7e276e9..c3768fd2d689c1a 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -330,7 +330,8 @@ def OMPFunctionFiltering : Pass<"omp-function-filtering"> {
                 "for the target device.";
   let constructor = "::fir::createOMPFunctionFilteringPass()";
   let dependentDialects = [
-    "mlir::func::FuncDialect"
+    "mlir::func::FuncDialect",
+    "fir::FIROpsDialect"
   ];
 }
 
diff --git a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
index 43fa5b7c4de2414..e1ae4b452defab0 100644
--- a/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
+++ b/flang/lib/Optimizer/Transforms/OMPFunctionFiltering.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Dialect/FIRDialect.h"
 #include "flang/Optimizer/Transforms/Passes.h"
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -33,6 +34,8 @@ class OMPFunctionFilteringPass
   OMPFunctionFilteringPass() = default;
 
   void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    OpBuilder opBuilder(context);
     auto op = dyn_cast<omp::OffloadModuleInterface>(getOperation());
     if (!op || !op.getIsTargetDevice())
       return;
@@ -46,8 +49,6 @@ class OMPFunctionFilteringPass
               ->walk<WalkOrder::PreOrder>(
                   [&](omp::TargetOp) { return WalkResult::interrupt(); })
               .wasInterrupted();
-      if (hasTargetRegion)
-        return;
 
       omp::DeclareTargetDeviceType declareType =
           omp::DeclareTargetDeviceType::host;
@@ -56,18 +57,31 @@ class OMPFunctionFilteringPass
       if (declareTargetOp && declareTargetOp.isDeclareTarget())
         declareType = declareTargetOp.getDeclareTargetDeviceType();
 
-      // Filtering a function here means removing its body and explicitly
-      // setting its omp.declare_target attribute, so that following
-      // translation/lowering/transformation passes will skip processing its
-      // contents, but preventing the calls to undefined symbols that could
-      // result if the function were deleted. The second stage of function
-      // filtering, at the MLIR to LLVM IR translation level, will remove these
-      // from the IR thanks to the mismatch between the omp.declare_target
-      // attribute and the target device.
+      // Filtering a function here means deleting it if it doesn't contain a
+      // target region. Else we explicitly set the omp.declare_target
+      // attribute. The second stage of function filtering at the MLIR to LLVM
+      // IR translation level will remove functions that contain the target
+      // region from the generated llvm IR.
       if (declareType == omp::DeclareTargetDeviceType::host) {
-        funcOp.eraseBody();
-        funcOp.setVisibility(SymbolTable::Visibility::Private);
-        if (declareTargetOp)
+        SymbolTable::UseRange funcUses = *funcOp.getSymbolUses(op);
+        for (SymbolTable::SymbolUse use : funcUses) {
+          Operation *callOp = use.getUser();
+          // If the callOp has users then replace them with Undef values.
+          if (!callOp->use_empty()) {
+            SmallVector<Value> undefResults;
+            for (Value res : callOp->getResults()) {
+              opBuilder.setInsertionPoint(callOp);
+              undefResults.emplace_back(
+                  opBuilder.create<fir::UndefOp>(res.getLoc(), res.getType()));
+            }
+            callOp->replaceAllUsesWith(undefResults);
+          }
+          // Remove the callOp
+          callOp->erase();
+        }
+        if (!hasTargetRegion)
+          funcOp.erase();
+        else if (declareTargetOp)
           declareTargetOp.setDeclareTarget(declareType,
                                            omp::DeclareTargetCaptureClause::to);
       }
diff --git a/flang/test/Lower/OpenMP/FIR/array-bounds.f90 b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
index 7e6ac02aefe6057..02b5ebcee022675 100644
--- a/flang/test/Lower/OpenMP/FIR/array-bounds.f90
+++ b/flang/test/Lower/OpenMP/FIR/array-bounds.f90
@@ -32,7 +32,7 @@ end subroutine read_write_section
 module assumed_array_routines
 contains
 !ALL-LABEL: func.func @_QMassumed_array_routinesPassumed_shape_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
 !ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_shape_arrayEi"}
 !ALL: %[[C0:.*]] = arith.constant 1 : index
 !ALL: %[[C1:.*]] = arith.constant 0 : index
@@ -56,7 +56,7 @@ subroutine assumed_shape_array(arr_read_write)
         end subroutine assumed_shape_array
 
 !ALL-LABEL:   func.func @_QMassumed_array_routinesPassumed_size_array(
-!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"}) {
+!ALL-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<?xi32>> {fir.bindc_name = "arr_read_write"})
 !ALL: %[[ALLOCA:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMassumed_array_routinesFassumed_size_arrayEi"}
 !ALL: %[[C0:.*]] = arith.constant 1 : index
 !ALL: %[[C1:.*]] = arith.constant 1 : index
diff --git a/flang/test/Lower/OpenMP/function-filtering.f90 b/flang/test/Lower/OpenMP/function-filtering.f90
index e550348e50692c5..45d8c2e2533d07a 100644
--- a/flang/test/Lower/OpenMP/function-filtering.f90
+++ b/flang/test/Lower/OpenMP/function-filtering.f90
@@ -21,8 +21,7 @@ end function device_fn
 
 ! MLIR-HOST: func.func @{{.*}}host_fn(
 ! MLIR-HOST: return
-! MLIR-DEVICE: func.func private @{{.*}}host_fn(
-! MLIR-DEVICE-NOT: return
+! MLIR-DEVICE-NOT: func.func {{.*}}host_fn(
 
 ! LLVM-HOST: define {{.*}} @{{.*}}host_fn{{.*}}(
 ! LLVM-DEVICE-NOT: {{.*}} @{{.*}}host_fn{{.*}}(
@@ -32,9 +31,8 @@ function host_fn() result(x)
   x = 10
 end function host_fn
 
-! MLIR-HOST: func.func @{{.*}}target_subr(
-! MLIR-HOST: return
-! MLIR-DEVICE: return
+! MLIR-ALL: func.func @{{.*}}target_subr(
+! MLIR-ALL: return
 
 ! LLVM-HOST: define {{.*}} @{{.*}}target_subr{{.*}}(
 ! LLVM-ALL: define {{.*}} @__omp_offloading_{{.*}}_{{.*}}_target_subr__{{.*}}(
diff --git a/flang/test/Transforms/omp-function-filtering.mlir b/flang/test/Transforms/omp-function-filtering.mlir
index 44777e2cac30c50..46291e9321f76a2 100644
--- a/flang/test/Transforms/omp-function-filtering.mlir
+++ b/flang/test/Transforms/omp-function-filtering.mlir
@@ -4,16 +4,18 @@
 // CHECK: return
 // CHECK: func.func @nohost
 // CHECK: return
-// CHECK: func.func private @host
-// CHECK-NOT: return
-// CHECK: func.func private @none
-// CHECK-NOT: return
+// CHECK-NOT: func.func {{.*}}}} @host
+// CHECK-NOT: func.func {{.*}}}} @none
 // CHECK: func.func @nohost_target
 // CHECK: return
 // CHECK: func.func @host_target
 // CHECK: return
 // CHECK: func.func @none_target
 // CHECK: return
+// CHECK: func.func @host_target_call
+// CHECK-NOT: call @none_target
+// CHECK: %[[UNDEF:.*]] = fir.undefined i32
+// CHECK: return %[[UNDEF]] : i32
 module attributes {omp.is_target_device = true} {
   func.func @any() -> ()
       attributes {
@@ -55,9 +57,19 @@ module attributes {omp.is_target_device = true} {
     omp.target {}
     func.return
   }
-  func.func @none_target() -> () {
+  func.func @none_target() -> i32 {
     omp.target {}
-    func.return
+    %0 = arith.constant 25 : i32
+    func.return %0 : i32
+  }
+  func.func @host_target_call() -> i32
+      attributes {
+        omp.declare_target =
+          #omp.declaretarget<device_type = (host), capture_clause = (to)>
+      } {
+    omp.target {}
+    %0 = call @none_target() : () -> i32
+    func.return %0 : i32
   }
 }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 5481f9d8070700a..ae13a745196c42d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2496,8 +2496,6 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
       if (declareType == omp::DeclareTargetDeviceType::host) {
         llvm::Function *llvmFunc =
             moduleTranslation.lookupFunction(funcOp.getName());
-        llvmFunc->replaceAllUsesWith(
-            llvm::UndefValue::get(llvmFunc->getType()));
         llvmFunc->dropAllReferences();
         llvmFunc->eraseFromParent();
       }



More information about the flang-commits mailing list