[flang-commits] [flang] [flang][openacc] Do not generate duplicate routine op (PR #68348)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Oct 5 13:03:07 PDT 2023


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/68348

This patch updates the lowering of the OpenACC routine directive to avoid creating duplicate acc.routine operations when all the clauses are identical. If the clauses differ, an error is raised.

>From bf4cfe38195b7efd6ed9142cda16b6ac9eb9eb13 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 5 Oct 2023 10:22:27 -0700
Subject: [PATCH] [flang][openacc] Do not generate duplicate routine op

---
 flang/lib/Lower/OpenACC.cpp              | 60 ++++++++++++++++--------
 flang/test/Lower/OpenACC/acc-routine.f90 | 20 ++++++++
 2 files changed, 61 insertions(+), 19 deletions(-)

diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 9670cc01b593b7e..d6cb6bad0f41117 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2977,19 +2977,17 @@ genACC(Fortran::lower::AbstractConverter &converter,
     funcName = funcOp.getName();
   }
 
-  mlir::OpBuilder modBuilder(mod.getBodyRegion());
-  std::stringstream routineOpName;
-  routineOpName << accRoutinePrefix.str() << routineCounter++;
-  auto routineOp = modBuilder.create<mlir::acc::RoutineOp>(
-      loc, routineOpName.str(), funcName, mlir::StringAttr{}, false, false,
-      false, false, false, false, mlir::IntegerAttr{});
+  bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
+       hasNohost = false;
+  std::optional<std::string> bindName = std::nullopt;
+  std::optional<int64_t> gangDim = std::nullopt;
 
   for (const Fortran::parser::AccClause &clause : clauses.v) {
     if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
-      routineOp.setSeqAttr(builder.getUnitAttr());
+      hasSeq = true;
     } else if (const auto *gangClause =
                    std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
-      routineOp.setGangAttr(builder.getUnitAttr());
+      hasGang = true;
       if (gangClause->v) {
         const Fortran::parser::AccGangArgList &x = *gangClause->v;
         for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3000,36 +2998,60 @@ genACC(Fortran::lower::AbstractConverter &converter,
             if (!dimValue)
               mlir::emitError(loc,
                               "dim value must be a constant positive integer");
-            routineOp.setGangDimAttr(
-                builder.getIntegerAttr(builder.getIntegerType(32), *dimValue));
+            gangDim = *dimValue;
           }
         }
       }
     } else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
-      routineOp.setVectorAttr(builder.getUnitAttr());
+      hasVector = true;
     } else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
-      routineOp.setWorkerAttr(builder.getUnitAttr());
+      hasWorker = true;
     } else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
-      routineOp.setNohostAttr(builder.getUnitAttr());
+      hasNohost = true;
     } else if (const auto *bindClause =
                    std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
       if (const auto *name =
               std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
-        routineOp.setBindName(
-            builder.getStringAttr(converter.mangleName(*name->symbol)));
+        bindName = converter.mangleName(*name->symbol);
       } else if (const auto charExpr =
                      std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
                          &bindClause->v.u)) {
-        const std::optional<std::string> bindName =
+        const std::optional<std::string> name =
             Fortran::semantics::GetConstExpr<std::string>(semanticsContext,
                                                           *charExpr);
-        if (!bindName)
-          routineOp.emitError("Could not retrieve the bind name");
-        routineOp.setBindName(builder.getStringAttr(*bindName));
+        if (!name)
+          mlir::emitError(loc, "Could not retrieve the bind name");
+        bindName = *name;
       }
     }
   }
 
+  mlir::OpBuilder modBuilder(mod.getBodyRegion());
+  std::stringstream routineOpName;
+  routineOpName << accRoutinePrefix.str() << routineCounter++;
+
+  for (auto routineOp : mod.getOps<mlir::acc::RoutineOp>()) {
+    if (routineOp.getFuncName().str().compare(funcName) == 0) {
+      // If the routine is already specified with the same clauses, just skip
+      // the operation creation.
+      if (routineOp.getBindName() == bindName &&
+          routineOp.getGang() == hasGang &&
+          routineOp.getWorker() == hasWorker &&
+          routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
+          routineOp.getNohost() == hasNohost &&
+          routineOp.getGangDim() == gangDim)
+        return;
+      mlir::emitError(loc, "Routine already specified with different clauses");
+    }
+  }
+
+  modBuilder.create<mlir::acc::RoutineOp>(
+      loc, routineOpName.str(), funcName,
+      bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
+      hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
+      gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
+              : mlir::IntegerAttr{});
+
   if (funcOp)
     attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
   else
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index f9fc9b1a0b4b75a..7514e0a8819fae9 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -3,6 +3,8 @@
 ! RUN: bbc -fopenacc -emit-fir %s -o - | FileCheck %s
 ! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
 
+
+! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
 ! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
 ! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
 ! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
@@ -76,3 +78,21 @@ function acc_routine10()
 end function
 
 ! CHECK-LABEL: func.func @_QPacc_routine10() -> f32 attributes {acc.routine_info = #acc.routine_info<[@acc_routine_9]>}
+
+subroutine acc_routine11(a)
+  real :: a
+  !$acc routine(acc_routine11) seq
+end subroutine
+
+! CHECK-LABEL: func.func @_QPacc_routine11(%arg0: !fir.ref<f32> {fir.bindc_name = "a"}) attributes {acc.routine_info = #acc.routine_info<[@acc_routine_10]>}
+
+subroutine acc_routine12()
+
+  interface
+  subroutine acc_routine11(a)
+    real :: a
+    !$acc routine(acc_routine11) seq
+  end subroutine
+  end interface
+
+end subroutine



More information about the flang-commits mailing list