[flang-commits] [flang] [flang][cuda] Extends matching distance computation (PR #91810)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Fri May 10 14:49:45 PDT 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/91810

>From c91685310de6a4d4805e2851b2968ba3d9bba512 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 10 May 2024 10:39:46 -0700
Subject: [PATCH 1/2] [flang][cuda] Extends matching distance computation

---
 flang/include/flang/Common/Fortran-features.h |  4 +-
 flang/include/flang/Common/Fortran.h          |  4 +-
 flang/lib/Common/Fortran.cpp                  | 37 +++++++++----
 flang/lib/Semantics/check-call.cpp            |  2 +-
 flang/lib/Semantics/expression.cpp            | 46 +++++++++++++---
 flang/test/Semantics/cuf14.cuf                | 55 +++++++++++++++++++
 flang/test/Semantics/cuf15.cuf                | 55 +++++++++++++++++++
 flang/tools/bbc/bbc.cpp                       | 10 ++++
 8 files changed, 192 insertions(+), 21 deletions(-)
 create mode 100644 flang/test/Semantics/cuf14.cuf
 create mode 100644 flang/test/Semantics/cuf15.cuf

diff --git a/flang/include/flang/Common/Fortran-features.h b/flang/include/flang/Common/Fortran-features.h
index 07ed7f43c1e73..f930490716fdc 100644
--- a/flang/include/flang/Common/Fortran-features.h
+++ b/flang/include/flang/Common/Fortran-features.h
@@ -49,7 +49,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
     IndistinguishableSpecifics, SubroutineAndFunctionSpecifics,
     EmptySequenceType, NonSequenceCrayPointee, BranchIntoConstruct,
     BadBranchTarget, ConvertedArgument, HollerithPolymorphic, ListDirectedSize,
-    NonBindCInteroperability)
+    NonBindCInteroperability, GpuManaged, GpuUnified)
 
 // Portability and suspicious usage warnings
 ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
@@ -81,6 +81,8 @@ class LanguageFeatureControl {
     disable_.set(LanguageFeature::OpenACC);
     disable_.set(LanguageFeature::OpenMP);
     disable_.set(LanguageFeature::CUDA); // !@cuf
+    disable_.set(LanguageFeature::GpuManaged);
+    disable_.set(LanguageFeature::GpuUnified);
     disable_.set(LanguageFeature::ImplicitNoneTypeNever);
     disable_.set(LanguageFeature::ImplicitNoneTypeAlways);
     disable_.set(LanguageFeature::DefaultSave);
diff --git a/flang/include/flang/Common/Fortran.h b/flang/include/flang/Common/Fortran.h
index 3b965fe60c2f0..0701e3e8b64cc 100644
--- a/flang/include/flang/Common/Fortran.h
+++ b/flang/include/flang/Common/Fortran.h
@@ -19,6 +19,7 @@
 #include <string>
 
 namespace Fortran::common {
+class LanguageFeatureControl;
 
 // Fortran has five kinds of intrinsic data types, plus the derived types.
 ENUM_CLASS(TypeCategory, Integer, Real, Complex, Character, Logical, Derived)
@@ -115,7 +116,8 @@ static constexpr IgnoreTKRSet ignoreTKRAll{IgnoreTKR::Type, IgnoreTKR::Kind,
 std::string AsFortran(IgnoreTKRSet);
 
 bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr>,
-    std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule);
+    std::optional<CUDADataAttr>, IgnoreTKRSet, bool allowUnifiedMatchingRule,
+    const LanguageFeatureControl *features = nullptr);
 
 static constexpr char blankCommonObjectName[] = "__BLNK__";
 
diff --git a/flang/lib/Common/Fortran.cpp b/flang/lib/Common/Fortran.cpp
index 170ce8c225092..83ee68e3a62c0 100644
--- a/flang/lib/Common/Fortran.cpp
+++ b/flang/lib/Common/Fortran.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Common/Fortran.h"
+#include "flang/Common/Fortran-features.h"
 
 namespace Fortran::common {
 
@@ -102,7 +103,13 @@ std::string AsFortran(IgnoreTKRSet tkr) {
 /// dummy argument attribute while `y` represents the actual argument attribute.
 bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
     std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
-    bool allowUnifiedMatchingRule) {
+    bool allowUnifiedMatchingRule, const LanguageFeatureControl *features) {
+  bool isGpuManaged = features
+      ? features->IsEnabled(common::LanguageFeature::GpuManaged)
+      : false;
+  bool isGpuUnified = features
+      ? features->IsEnabled(common::LanguageFeature::GpuUnified)
+      : false;
   if (!x && !y) {
     return true;
   } else if (x && y && *x == *y) {
@@ -120,19 +127,27 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
     return true;
   } else if (allowUnifiedMatchingRule) {
     if (!x) { // Dummy argument has no attribute -> host
-      if (y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
+      if ((y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
+          (!y && (isGpuUnified || isGpuManaged))) {
         return true;
       }
     } else {
-      if (*x == CUDADataAttr::Device && y &&
-          (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) {
-        return true;
-      } else if (*x == CUDADataAttr::Managed && y &&
-          *y == CUDADataAttr::Unified) {
-        return true;
-      } else if (*x == CUDADataAttr::Unified && y &&
-          *y == CUDADataAttr::Managed) {
-        return true;
+      if (*x == CUDADataAttr::Device) {
+        if ((y &&
+                (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
+            (!y && (isGpuUnified || isGpuManaged))) {
+          return true;
+        }
+      } else if (*x == CUDADataAttr::Managed) {
+        if ((y && *y == CUDADataAttr::Unified) ||
+            (!y && (isGpuUnified || isGpuManaged))) {
+          return true;
+        }
+      } else if (*x == CUDADataAttr::Unified) {
+        if ((y && *y == CUDADataAttr::Managed) ||
+            (!y && (isGpuUnified || isGpuManaged))) {
+          return true;
+        }
       }
     }
     return false;
diff --git a/flang/lib/Semantics/check-call.cpp b/flang/lib/Semantics/check-call.cpp
index 94afcbb68b349..8f51ef5ebeba3 100644
--- a/flang/lib/Semantics/check-call.cpp
+++ b/flang/lib/Semantics/check-call.cpp
@@ -914,7 +914,7 @@ static void CheckExplicitDataArg(const characteristics::DummyDataObject &dummy,
     }
     if (!common::AreCompatibleCUDADataAttrs(dummyDataAttr, actualDataAttr,
             dummy.ignoreTKR,
-            /*allowUnifiedMatchingRule=*/true)) {
+            /*allowUnifiedMatchingRule=*/true, &context.languageFeatures())) {
       auto toStr{[](std::optional<common::CUDADataAttr> x) {
         return x ? "ATTRIBUTES("s +
                 parser::ToUpperCaseLetters(common::EnumToString(*x)) + ")"s
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index c503ea3f0246f..b87370d9b3338 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2501,8 +2501,13 @@ static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 
 // Compute the matching distance as described in section 3.2.3 of the CUDA
 // Fortran references.
-static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
+static int GetMatchingDistance(const common::LanguageFeatureControl &features,
+    const characteristics::DummyArgument &dummy,
     const std::optional<ActualArgument> &actual) {
+  bool isGpuManaged = features.IsEnabled(common::LanguageFeature::GpuManaged);
+  bool isGpuUnified = features.IsEnabled(common::LanguageFeature::GpuUnified);
+  // assert((isGpuManaged != isGpuUnified) && "expect only one enabled.");
+
   std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
   if (actual) {
     if (auto *expr{actual->UnwrapExpr()}) {
@@ -2529,6 +2534,9 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
 
   if (!dummyDataAttr) {
     if (!actualDataAttr) {
+      if (isGpuUnified || isGpuManaged) {
+        return 3;
+      }
       return 0;
     } else if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
@@ -2538,6 +2546,9 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
     if (!actualDataAttr) {
+      if (isGpuUnified || isGpuManaged) {
+        return 2;
+      }
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Device) {
       return 0;
@@ -2546,7 +2557,16 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
       return 2;
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
-    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+    if (!actualDataAttr) {
+      if (isGpuUnified) {
+        return 1;
+      }
+      if (isGpuManaged) {
+        return 0;
+      }
+      return cudaInfMatchingValue;
+    }
+    if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 0;
@@ -2554,7 +2574,16 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
       return 1;
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
-    if (!actualDataAttr || *actualDataAttr == common::CUDADataAttr::Device) {
+    if (!actualDataAttr) {
+      if (isGpuUnified) {
+        return 0;
+      }
+      if (isGpuManaged) {
+        return 1;
+      }
+      return cudaInfMatchingValue;
+    }
+    if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
     } else if (*actualDataAttr == common::CUDADataAttr::Managed) {
       return 1;
@@ -2566,6 +2595,7 @@ static int GetMatchingDistance(const characteristics::DummyArgument &dummy,
 }
 
 static int ComputeCudaMatchingDistance(
+    const common::LanguageFeatureControl &features,
     const characteristics::Procedure &procedure,
     const ActualArguments &actuals) {
   const auto &dummies{procedure.dummyArguments};
@@ -2574,7 +2604,7 @@ static int ComputeCudaMatchingDistance(
   for (std::size_t i{0}; i < dummies.size(); ++i) {
     const characteristics::DummyArgument &dummy{dummies[i]};
     const std::optional<ActualArgument> &actual{actuals[i]};
-    int d{GetMatchingDistance(dummy, actual)};
+    int d{GetMatchingDistance(features, dummy, actual)};
     if (d == cudaInfMatchingValue)
       return d;
     distance += d;
@@ -2666,7 +2696,9 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
             CheckCompatibleArguments(*procedure, localActuals)) {
           if ((procedure->IsElemental() && elemental) ||
               (!procedure->IsElemental() && nonElemental)) {
-            int d{ComputeCudaMatchingDistance(*procedure, localActuals)};
+            int d{ComputeCudaMatchingDistance(
+                context_.languageFeatures(), *procedure, localActuals)};
+            llvm::errs() << "matching distance: " << d << "\n";
             if (d != crtMatchingDistance) {
               if (d > crtMatchingDistance) {
                 continue;
@@ -2688,8 +2720,8 @@ std::pair<const Symbol *, bool> ExpressionAnalyzer::ResolveGeneric(
           } else {
             elemental = &specific;
           }
-          crtMatchingDistance =
-              ComputeCudaMatchingDistance(*procedure, localActuals);
+          crtMatchingDistance = ComputeCudaMatchingDistance(
+              context_.languageFeatures(), *procedure, localActuals);
         }
       }
     }
diff --git a/flang/test/Semantics/cuf14.cuf b/flang/test/Semantics/cuf14.cuf
new file mode 100644
index 0000000000000..29c9ecf90677f
--- /dev/null
+++ b/flang/test/Semantics/cuf14.cuf
@@ -0,0 +1,55 @@
+! RUN: bbc -emit-hlfir -fcuda -gpu=unified %s -o - | FileCheck %s
+
+module matching
+  interface host_and_device
+    module procedure sub_host
+    module procedure sub_device
+  end interface
+
+  interface all
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_managed
+    module procedure sub_unified
+  end interface
+
+  interface all_without_unified
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_managed
+  end interface
+
+contains
+  subroutine sub_host(a)
+    integer :: a(:)
+  end
+
+  subroutine sub_device(a)
+    integer, device :: a(:)
+  end
+
+  subroutine sub_managed(a)
+    integer, managed :: a(:)
+  end
+
+  subroutine sub_unified(a)
+    integer, unified :: a(:)
+  end
+end module
+
+program m
+  use matching
+
+  integer, allocatable :: actual_host(:)
+
+  allocate(actual_host(10))
+
+  call host_and_device(actual_host)     ! Should resolve to sub_device
+  call all(actual_host)                 ! Should resolve to sub_unified
+  call all_without_unified(actual_host) ! Should resolve to sub_managed
+end
+
+! CHECK: fir.call @_QMmatchingPsub_device
+! CHECK: fir.call @_QMmatchingPsub_unified
+! CHECK: fir.call @_QMmatchingPsub_managed
+
diff --git a/flang/test/Semantics/cuf15.cuf b/flang/test/Semantics/cuf15.cuf
new file mode 100644
index 0000000000000..030dd6ff8ffe8
--- /dev/null
+++ b/flang/test/Semantics/cuf15.cuf
@@ -0,0 +1,55 @@
+! RUN: bbc -emit-hlfir -fcuda -gpu=managed %s -o - | FileCheck %s
+
+module matching
+  interface host_and_device
+    module procedure sub_host
+    module procedure sub_device
+  end interface
+
+  interface all
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_managed
+    module procedure sub_unified
+  end interface
+
+  interface all_without_managed
+    module procedure sub_host
+    module procedure sub_device
+    module procedure sub_unified
+  end interface
+
+contains
+  subroutine sub_host(a)
+    integer :: a(:)
+  end
+
+  subroutine sub_device(a)
+    integer, device :: a(:)
+  end
+
+  subroutine sub_managed(a)
+    integer, managed :: a(:)
+  end
+
+  subroutine sub_unified(a)
+    integer, unified :: a(:)
+  end
+end module
+
+program m
+  use matching
+
+  integer, allocatable :: actual_host(:)
+
+  allocate(actual_host(10))
+
+  call host_and_device(actual_host)     ! Should resolve to sub_device
+  call all(actual_host)                 ! Should resolve to sub_managed
+  call all_without_managed(actual_host) ! Should resolve to sub_unified
+end
+
+! CHECK: fir.call @_QMmatchingPsub_device
+! CHECK: fir.call @_QMmatchingPsub_managed
+! CHECK: fir.call @_QMmatchingPsub_unified
+
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index ee2ff8562e9ff..085b988c6af5f 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -204,6 +204,10 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
                                       llvm::cl::desc("enable CUDA Fortran"),
                                       llvm::cl::init(false));
 
+static llvm::cl::opt<std::string>
+    enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
+                  llvm::cl::init(""));
+
 static llvm::cl::opt<bool> fixedForm("ffixed-form",
                                      llvm::cl::desc("enable fixed form"),
                                      llvm::cl::init(false));
@@ -495,6 +499,12 @@ int main(int argc, char **argv) {
     options.features.Enable(Fortran::common::LanguageFeature::CUDA);
   }
 
+  if (enableGPUMode == "managed") {
+    options.features.Enable(Fortran::common::LanguageFeature::GpuManaged);
+  } else if (enableGPUMode == "unified") {
+    options.features.Enable(Fortran::common::LanguageFeature::GpuUnified);
+  }
+
   if (fixedForm) {
     options.isFixedForm = fixedForm;
   }

>From 6cf2d7d714977248696f6b7d3f40516d53013a17 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Fri, 10 May 2024 14:49:29 -0700
Subject: [PATCH 2/2] Address comments

---
 flang/include/flang/Common/Fortran-features.h |  6 ++---
 flang/lib/Common/Fortran.cpp                  | 16 ++++++------
 flang/lib/Semantics/expression.cpp            | 26 +++++--------------
 flang/tools/bbc/bbc.cpp                       |  4 +--
 4 files changed, 20 insertions(+), 32 deletions(-)

diff --git a/flang/include/flang/Common/Fortran-features.h b/flang/include/flang/Common/Fortran-features.h
index f930490716fdc..f57fcdc895adc 100644
--- a/flang/include/flang/Common/Fortran-features.h
+++ b/flang/include/flang/Common/Fortran-features.h
@@ -49,7 +49,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
     IndistinguishableSpecifics, SubroutineAndFunctionSpecifics,
     EmptySequenceType, NonSequenceCrayPointee, BranchIntoConstruct,
     BadBranchTarget, ConvertedArgument, HollerithPolymorphic, ListDirectedSize,
-    NonBindCInteroperability, GpuManaged, GpuUnified)
+    NonBindCInteroperability, CudaManaged, CudaUnified)
 
 // Portability and suspicious usage warnings
 ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
@@ -81,8 +81,8 @@ class LanguageFeatureControl {
     disable_.set(LanguageFeature::OpenACC);
     disable_.set(LanguageFeature::OpenMP);
     disable_.set(LanguageFeature::CUDA); // !@cuf
-    disable_.set(LanguageFeature::GpuManaged);
-    disable_.set(LanguageFeature::GpuUnified);
+    disable_.set(LanguageFeature::CudaManaged);
+    disable_.set(LanguageFeature::CudaUnified);
     disable_.set(LanguageFeature::ImplicitNoneTypeNever);
     disable_.set(LanguageFeature::ImplicitNoneTypeAlways);
     disable_.set(LanguageFeature::DefaultSave);
diff --git a/flang/lib/Common/Fortran.cpp b/flang/lib/Common/Fortran.cpp
index 83ee68e3a62c0..f914155bdbe5b 100644
--- a/flang/lib/Common/Fortran.cpp
+++ b/flang/lib/Common/Fortran.cpp
@@ -104,11 +104,11 @@ std::string AsFortran(IgnoreTKRSet tkr) {
 bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
     std::optional<CUDADataAttr> y, IgnoreTKRSet ignoreTKR,
     bool allowUnifiedMatchingRule, const LanguageFeatureControl *features) {
-  bool isGpuManaged = features
-      ? features->IsEnabled(common::LanguageFeature::GpuManaged)
+  bool isCudaManaged = features
+      ? features->IsEnabled(common::LanguageFeature::CudaManaged)
       : false;
-  bool isGpuUnified = features
-      ? features->IsEnabled(common::LanguageFeature::GpuUnified)
+  bool isCudaUnified = features
+      ? features->IsEnabled(common::LanguageFeature::CudaUnified)
       : false;
   if (!x && !y) {
     return true;
@@ -128,24 +128,24 @@ bool AreCompatibleCUDADataAttrs(std::optional<CUDADataAttr> x,
   } else if (allowUnifiedMatchingRule) {
     if (!x) { // Dummy argument has no attribute -> host
       if ((y && (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
-          (!y && (isGpuUnified || isGpuManaged))) {
+          (!y && (isCudaUnified || isCudaManaged))) {
         return true;
       }
     } else {
       if (*x == CUDADataAttr::Device) {
         if ((y &&
                 (*y == CUDADataAttr::Managed || *y == CUDADataAttr::Unified)) ||
-            (!y && (isGpuUnified || isGpuManaged))) {
+            (!y && (isCudaUnified || isCudaManaged))) {
           return true;
         }
       } else if (*x == CUDADataAttr::Managed) {
         if ((y && *y == CUDADataAttr::Unified) ||
-            (!y && (isGpuUnified || isGpuManaged))) {
+            (!y && (isCudaUnified || isCudaManaged))) {
           return true;
         }
       } else if (*x == CUDADataAttr::Unified) {
         if ((y && *y == CUDADataAttr::Managed) ||
-            (!y && (isGpuUnified || isGpuManaged))) {
+            (!y && (isCudaUnified || isCudaManaged))) {
           return true;
         }
       }
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b87370d9b3338..06e38da6626a9 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -2504,9 +2504,9 @@ static constexpr int cudaInfMatchingValue{std::numeric_limits<int>::max()};
 static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     const characteristics::DummyArgument &dummy,
     const std::optional<ActualArgument> &actual) {
-  bool isGpuManaged = features.IsEnabled(common::LanguageFeature::GpuManaged);
-  bool isGpuUnified = features.IsEnabled(common::LanguageFeature::GpuUnified);
-  // assert((isGpuManaged != isGpuUnified) && "expect only one enabled.");
+  bool isCudaManaged{features.IsEnabled(common::LanguageFeature::CudaManaged)};
+  bool isCudaUnified{features.IsEnabled(common::LanguageFeature::CudaUnified)};
+  CHECK(!(isCudaUnified && isCudaManaged) && "expect only one enabled.");
 
   std::optional<common::CUDADataAttr> actualDataAttr, dummyDataAttr;
   if (actual) {
@@ -2534,7 +2534,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
 
   if (!dummyDataAttr) {
     if (!actualDataAttr) {
-      if (isGpuUnified || isGpuManaged) {
+      if (isCudaUnified || isCudaManaged) {
         return 3;
       }
       return 0;
@@ -2546,7 +2546,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Device) {
     if (!actualDataAttr) {
-      if (isGpuUnified || isGpuManaged) {
+      if (isCudaUnified || isCudaManaged) {
         return 2;
       }
       return cudaInfMatchingValue;
@@ -2558,13 +2558,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Managed) {
     if (!actualDataAttr) {
-      if (isGpuUnified) {
-        return 1;
-      }
-      if (isGpuManaged) {
-        return 0;
-      }
-      return cudaInfMatchingValue;
+      return isCudaUnified ? 1 : isCudaManaged ? 0 : cudaInfMatchingValue;
     }
     if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
@@ -2575,13 +2569,7 @@ static int GetMatchingDistance(const common::LanguageFeatureControl &features,
     }
   } else if (*dummyDataAttr == common::CUDADataAttr::Unified) {
     if (!actualDataAttr) {
-      if (isGpuUnified) {
-        return 0;
-      }
-      if (isGpuManaged) {
-        return 1;
-      }
-      return cudaInfMatchingValue;
+      return isCudaUnified ? 0 : isCudaManaged ? 1 : cudaInfMatchingValue;
     }
     if (*actualDataAttr == common::CUDADataAttr::Device) {
       return cudaInfMatchingValue;
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index 085b988c6af5f..f7092d35eeb57 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -500,9 +500,9 @@ int main(int argc, char **argv) {
   }
 
   if (enableGPUMode == "managed") {
-    options.features.Enable(Fortran::common::LanguageFeature::GpuManaged);
+    options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
   } else if (enableGPUMode == "unified") {
-    options.features.Enable(Fortran::common::LanguageFeature::GpuUnified);
+    options.features.Enable(Fortran::common::LanguageFeature::CudaUnified);
   }
 
   if (fixedForm) {



More information about the flang-commits mailing list