[llvm-branch-commits] [flang] [flang][OpenMP] Make object identity more precise (PR #94495)

Krzysztof Parzyszek via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jun 5 09:09:19 PDT 2024


https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/94495

Derived type components may use the same `Symbol` regardless of which parent object they are a part of (for example, `a%c` and `b%c`). Because of that, comparing symbol addresses alone is not sufficient to determine object identity.

Make the designator a part of `IdTy`. When comparing identities, if the symbols are equal (and non-null), compare the designators as well.

>From 1c266fa056044e2baa6cd4262fdf9a4cd74fd14d Mon Sep 17 00:00:00 2001
From: Krzysztof Parzyszek <Krzysztof.Parzyszek at amd.com>
Date: Wed, 5 Jun 2024 10:09:09 -0500
Subject: [PATCH] [flang][OpenMP] Make object identity more precise

Derived type components may use a given `Symbol` regardless of what
parent objects they are a part of. Because of that, simply using a
symbol address is not sufficient to determine object identity.

Make the designator a part of the IdTy. To compare identities, when
symbols are equal (and non-null), compare the designators.
---
 flang/lib/Lower/OpenMP/Clauses.h              | 50 ++++++++++++++++---
 flang/lib/Lower/OpenMP/Utils.cpp              |  2 +-
 flang/test/Lower/OpenMP/map-component-ref.f90 | 41 +++++++++++----
 3 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/Clauses.h b/flang/lib/Lower/OpenMP/Clauses.h
index f7cd0ea83ad12..98fb5dcf7722e 100644
--- a/flang/lib/Lower/OpenMP/Clauses.h
+++ b/flang/lib/Lower/OpenMP/Clauses.h
@@ -36,30 +36,64 @@ struct TypeTy : public evaluate::SomeType {
   bool operator==(const TypeTy &t) const { return true; }
 };
 
-using IdTy = semantics::Symbol *;
+template <typename ExprTy>
+struct IdTyTemplate {
+  // "symbol" is always non-null for id's of actual objects.
+  Fortran::semantics::Symbol *symbol;
+  std::optional<ExprTy> designator;
+
+  bool operator==(const IdTyTemplate &other) const {
+    // If symbols are different, then the objects are different.
+    if (symbol != other.symbol)
+      return false;
+    if (symbol == nullptr)
+      return true;
+    // Equal symbols don't necessarily indicate identical objects,
+    // for example, a derived object component may use a single symbol,
+    // which will refer to different objects for different designators,
+    // e.g. a%c and b%c.
+    return designator == other.designator;
+  }
+
+  operator bool() const { return symbol != nullptr; }
+};
+
 using ExprTy = SomeExpr;
 
 template <typename T>
 using List = tomp::ListT<T>;
 } // namespace Fortran::lower::omp
 
+// Specialization of the ObjectT template
 namespace tomp::type {
 template <>
-struct ObjectT<Fortran::lower::omp::IdTy, Fortran::lower::omp::ExprTy> {
-  using IdTy = Fortran::lower::omp::IdTy;
+struct ObjectT<Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>,
+               Fortran::lower::omp::ExprTy> {
+  using IdTy = Fortran::lower::omp::IdTyTemplate<Fortran::lower::omp::ExprTy>;
   using ExprTy = Fortran::lower::omp::ExprTy;
 
-  IdTy id() const { return symbol; }
-  Fortran::semantics::Symbol *sym() const { return symbol; }
-  const std::optional<ExprTy> &ref() const { return designator; }
+  IdTy id() const { return identity; }
+  Fortran::semantics::Symbol *sym() const { return identity.symbol; }
+  const std::optional<ExprTy> &ref() const { return identity.designator; }
 
-  IdTy symbol;
-  std::optional<ExprTy> designator;
+  IdTy identity;
 };
 } // namespace tomp::type
 
 namespace Fortran::lower::omp {
+using IdTy = IdTyTemplate<ExprTy>;
+}
 
+namespace std {
+template <>
+struct hash<Fortran::lower::omp::IdTy> {
+  size_t operator()(const Fortran::lower::omp::IdTy &id) const {
+    return static_cast<size_t>(reinterpret_cast<uintptr_t>(id.symbol));
+  }
+};
+} // namespace std
+
+namespace Fortran::lower::omp {
 using Object = tomp::ObjectT<IdTy, ExprTy>;
 using ObjectList = tomp::ObjectListT<IdTy, ExprTy>;
 
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index eff915f569f27..da94352a84a7c 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -188,7 +188,7 @@ void addChildIndexAndMapToParent(
     std::map<const semantics::Symbol *,
              llvm::SmallVector<OmpMapMemberIndicesData>> &parentMemberIndices,
     mlir::omp::MapInfoOp &mapOp, semantics::SemanticsContext &semaCtx) {
-  std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.designator);
+  std::optional<evaluate::DataRef> dataRef = ExtractDataRef(object.ref());
   assert(dataRef.has_value() &&
          "DataRef could not be extracted during mapping of derived type "
          "cannot proceed");
diff --git a/flang/test/Lower/OpenMP/map-component-ref.f90 b/flang/test/Lower/OpenMP/map-component-ref.f90
index 2c582667f38d3..21b56ab303acd 100644
--- a/flang/test/Lower/OpenMP/map-component-ref.f90
+++ b/flang/test/Lower/OpenMP/map-component-ref.f90
@@ -1,21 +1,22 @@
 ! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
 ! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
 
-! CHECK: %[[V0:[0-9]+]] = fir.alloca !fir.type<_QFfooTt0{a0:i32,a1:i32}> {bindc_name = "a", uniq_name = "_QFfooEa"}
-! CHECK: %[[V1:[0-9]+]]:2 = hlfir.declare %[[V0]] {uniq_name = "_QFfooEa"} : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>)
-! CHECK: %[[V2:[0-9]+]] = hlfir.designate %[[V1]]#0{"a1"}   : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
+! CHECK-LABEL: func.func @_QPfoo1
+! CHECK: %[[V0:[0-9]+]] = fir.alloca !fir.type<_QFfoo1Tt0{a0:i32,a1:i32}> {bindc_name = "a", uniq_name = "_QFfoo1Ea"}
+! CHECK: %[[V1:[0-9]+]]:2 = hlfir.declare %[[V0]] {uniq_name = "_QFfoo1Ea"} : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>)
+! CHECK: %[[V2:[0-9]+]] = hlfir.designate %[[V1]]#0{"a1"}   : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
 ! CHECK: %[[V3:[0-9]+]] = omp.map.info var_ptr(%[[V2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "a%a1"}
-! CHECK: %[[V4:[0-9]+]] = omp.map.info var_ptr(%[[V1]]#1 : !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>, !fir.type<_QFfooTt0{a0:i32,a1:i32}>) map_clauses(tofrom) capture(ByRef) members(%[[V3]] : [1] : !fir.ref<i32>) -> !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>> {name = "a", partial_map = true}
-! CHECK: omp.target map_entries(%[[V3]] -> %arg0, %[[V4]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) {
-! CHECK: ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>):
-! CHECK:   %[[V5:[0-9]+]]:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEa"} : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>)
+! CHECK: %[[V4:[0-9]+]] = omp.map.info var_ptr(%[[V1]]#1 : !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>, !fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>) map_clauses(tofrom) capture(ByRef) members(%[[V3]] : [1] : !fir.ref<i32>) -> !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>> {name = "a", partial_map = true}
+! CHECK: omp.target map_entries(%[[V3]] -> %arg0, %[[V4]] -> %arg1 : !fir.ref<i32>, !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) {
+! CHECK: ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>):
+! CHECK:   %[[V5:[0-9]+]]:2 = hlfir.declare %arg1 {uniq_name = "_QFfoo1Ea"} : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>, !fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>)
 ! CHECK:   %c0_i32 = arith.constant 0 : i32
-! CHECK:   %[[V6:[0-9]+]] = hlfir.designate %[[V5]]#0{"a1"}   : (!fir.ref<!fir.type<_QFfooTt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
+! CHECK:   %[[V6:[0-9]+]] = hlfir.designate %[[V5]]#0{"a1"}   : (!fir.ref<!fir.type<_QFfoo1Tt0{a0:i32,a1:i32}>>) -> !fir.ref<i32>
 ! CHECK:   hlfir.assign %c0_i32 to %[[V6]] : i32, !fir.ref<i32>
 ! CHECK:   omp.terminator
 ! CHECK: }
 
-subroutine foo()
+subroutine foo1()
   implicit none
 
   type t0
@@ -29,3 +30,25 @@ subroutine foo()
   !$omp end target
 end
 
+
+! CHECK-LABEL: func.func @_QPfoo2
+! CHECK-DAG: omp.map.info var_ptr(%{{[0-9]+}} : {{.*}} map_clauses(to) capture(ByRef) bounds(%{{[0-9]+}}) -> {{.*}} {name = "t%b(1_8)%a(1)"}
+! CHECK-DAG: omp.map.info var_ptr(%{{[0-9]+}} : {{.*}} map_clauses(from) capture(ByRef) bounds(%{{[0-9]+}}) -> {{.*}} {name = "u%b(1_8)%a(1)"}
+subroutine foo2()
+  implicit none
+
+  type t0
+    integer :: a(10)
+  end type
+
+  type t1
+    type(t0) :: b(10)
+  end type
+
+  type(t1) :: t, u
+
+!$omp target map(to: t%b(1)%a(1)) map(from: u%b(1)%a(1))
+  t%b(1)%a(1) = u%b(1)%a(1)
+!$omp end target
+
+end



More information about the llvm-branch-commits mailing list