[flang-commits] [flang] 7b473df - [flang][acc] Implement type categorization for FIR types (#126964)

via flang-commits flang-commits at lists.llvm.org
Wed Feb 12 21:10:03 PST 2025


Author: Razvan Lupusoru
Date: 2025-02-12T21:09:59-08:00
New Revision: 7b473dfe84c17319930d4019ab3f6ca0cfc03416

URL: https://github.com/llvm/llvm-project/commit/7b473dfe84c17319930d4019ab3f6ca0cfc03416
DIFF: https://github.com/llvm/llvm-project/commit/7b473dfe84c17319930d4019ab3f6ca0cfc03416.diff

LOG: [flang][acc] Implement type categorization for FIR types (#126964)

The OpenACC type interfaces have been updated to require that a type
self-identify the type category it belongs to. Ensure that FIR types
provide this self-identification.

In addition to implementing the new API, the PointerLikeType interface
attachment was moved to the FIROpenACCSupport library, where MappableType
is already attached, so that all type interfaces and their implementations
now live in the same place.

Added: 
    flang/test/Fir/OpenACC/openacc-type-categories.f90

Modified: 
    flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h
    flang/include/flang/Tools/PointerModels.h
    flang/lib/Frontend/FrontendActions.cpp
    flang/lib/Optimizer/Dialect/FIRType.cpp
    flang/lib/Optimizer/OpenACC/CMakeLists.txt
    flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
    flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp
    flang/test/Fir/OpenACC/openacc-mappable.fir
    flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h b/flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h
index c1bea32a22361..3e343f347e4ae 100644
--- a/flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h
@@ -18,6 +18,19 @@
 
 namespace fir::acc {
 
+template <typename T> // T: a FIR pointer-like type exposing getElementType().
+struct OpenACCPointerLikeModel
+    : public mlir::acc::PointerLikeType::ExternalModel<
+          OpenACCPointerLikeModel<T>, T> {
+  mlir::Type getElementType(mlir::Type pointer) const {
+    return mlir::cast<T>(pointer).getElementType();
+  }
+  mlir::acc::VariableTypeCategory
+  getPointeeTypeCategory(mlir::Type pointer,
+                         mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+                         mlir::Type varType) const;
+};
+
 template <typename T>
 struct OpenACCMappableModel
     : public mlir::acc::MappableType::ExternalModel<OpenACCMappableModel<T>,
@@ -36,6 +49,9 @@ struct OpenACCMappableModel
   llvm::SmallVector<mlir::Value>
   generateAccBounds(mlir::Type type, mlir::Value var,
                     mlir::OpBuilder &builder) const;
+
+  mlir::acc::VariableTypeCategory getTypeCategory(mlir::Type type,
+                                                  mlir::Value var) const;
 };
 
 } // namespace fir::acc

diff  --git a/flang/include/flang/Tools/PointerModels.h b/flang/include/flang/Tools/PointerModels.h
index c3c0977d6e54a..0d22ed3ca7f4f 100644
--- a/flang/include/flang/Tools/PointerModels.h
+++ b/flang/include/flang/Tools/PointerModels.h
@@ -9,7 +9,6 @@
 #ifndef FORTRAN_TOOLS_POINTER_MODELS_H
 #define FORTRAN_TOOLS_POINTER_MODELS_H
 
-#include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 
 /// models for FIR pointer like types that already provide a `getElementType`
@@ -24,13 +23,4 @@ struct OpenMPPointerLikeModel
   }
 };
 
-template <typename T>
-struct OpenACCPointerLikeModel
-    : public mlir::acc::PointerLikeType::ExternalModel<
-          OpenACCPointerLikeModel<T>, T> {
-  mlir::Type getElementType(mlir::Type pointer) const {
-    return mlir::cast<T>(pointer).getElementType();
-  }
-};
-
 #endif // FORTRAN_TOOLS_POINTER_MODELS_H

diff  --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index b7674bd093f68..622848eac2dd2 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -261,12 +261,12 @@ bool CodeGenAction::beginSourceFileAction() {
   }
 
   // Load the MLIR dialects required by Flang
-  mlir::DialectRegistry registry;
-  mlirCtx = std::make_unique<mlir::MLIRContext>(registry);
-  fir::support::registerNonCodegenDialects(registry);
-  fir::support::loadNonCodegenDialects(*mlirCtx);
+  mlirCtx = std::make_unique<mlir::MLIRContext>();
   fir::support::loadDialects(*mlirCtx);
   fir::support::registerLLVMTranslation(*mlirCtx);
+  mlir::DialectRegistry registry;
+  fir::acc::registerOpenACCExtensions(registry);
+  mlirCtx->appendDialectRegistry(registry);
 
   const llvm::TargetMachine &targetMachine = ci.getTargetMachine();
 

diff  --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp
index 49f0e53fa113d..719cb1b9d75aa 100644
--- a/flang/lib/Optimizer/Dialect/FIRType.cpp
+++ b/flang/lib/Optimizer/Dialect/FIRType.cpp
@@ -1370,23 +1370,12 @@ void FIROpsDialect::registerTypes() {
            TypeDescType, fir::VectorType, fir::DummyScopeType>();
   fir::ReferenceType::attachInterface<
       OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
-  fir::ReferenceType::attachInterface<
-      OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext());
-
   fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
       *getContext());
-  fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>(
-      *getContext());
-
   fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
       *getContext());
-  fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
-      *getContext());
-
   fir::LLVMPointerType::attachInterface<
       OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
-  fir::LLVMPointerType::attachInterface<
-      OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
 }
 
 std::optional<std::pair<uint64_t, unsigned short>>

diff  --git a/flang/lib/Optimizer/OpenACC/CMakeLists.txt b/flang/lib/Optimizer/OpenACC/CMakeLists.txt
index 04d351ac265d6..1bfae603fd80d 100644
--- a/flang/lib/Optimizer/OpenACC/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenACC/CMakeLists.txt
@@ -6,6 +6,7 @@ add_flang_library(FIROpenACCSupport
 
   DEPENDS
   FIRBuilder
+  FIRCodeGen
   FIRDialect
   FIRDialectSupport
   FIRSupport
@@ -14,6 +15,7 @@ add_flang_library(FIROpenACCSupport
 
   LINK_LIBS
   FIRBuilder
+  FIRCodeGen
   FIRDialect
   FIRDialectSupport
   FIRSupport

diff  --git a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
index 94ab31de1763d..0ebc62e7f2fd6 100644
--- a/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp
@@ -15,6 +15,7 @@
 #include "flang/Optimizer/Builder/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
 #include "flang/Optimizer/Builder/HLFIRTools.h"
+#include "flang/Optimizer/CodeGen/CGOps.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
@@ -24,6 +25,7 @@
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 namespace fir::acc {
 
@@ -224,4 +226,145 @@ OpenACCMappableModel<fir::BaseBoxType>::generateAccBounds(
   return {};
 }
 
+static bool isScalarLike(mlir::Type type) { // trivial value or FIR reference
+  return fir::isa_trivial(type) || fir::isa_ref_type(type);
+}
+
+static bool isArrayLike(mlir::Type type) { // fir.array (SequenceType)
+  return mlir::isa<fir::SequenceType>(type);
+}
+
+static bool isCompositeLike(mlir::Type type) { // record, class, or tuple
+  return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type);
+}
+
+template <> // A fir.array (SequenceType) is always categorized as "array".
+mlir::acc::VariableTypeCategory
+OpenACCMappableModel<fir::SequenceType>::getTypeCategory(
+    mlir::Type type, mlir::Value var) const {
+  return mlir::acc::VariableTypeCategory::array;
+}
+
+template <>
+mlir::acc::VariableTypeCategory
+OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
+                                                        mlir::Value var) const {
+
+  mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
+
+  // If the type enclosed by the box is a mappable type, then have it
+  // provide the type category.
+  if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
+    return mappableTy.getTypeCategory(var);
+
+  // For all arrays, regardless of whether they are allocatable, pointer,
+  // assumed, etc., we categorize them as "array".
+  if (isArrayLike(eleTy))
+    return mlir::acc::VariableTypeCategory::array;
+
+  // We got here because we have neither an array nor a mappable type. At this
+  // point, we know we have a type that fits the "aggregate" definition since it
+  // is a type with a descriptor. Try to refine it by checking if it matches the
+  // "composite" definition.
+  if (isCompositeLike(eleTy))
+    return mlir::acc::VariableTypeCategory::composite;
+
+  // Even a scalar element type is categorized as "nonscalar" here, simply
+  // because it is wrapped in a box. Any other remaining type would have been
+  // non-scalar anyway.
+  return mlir::acc::VariableTypeCategory::nonscalar;
+}
+
+static mlir::Value
+getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
+  // If there is no defining op - the unwrapped reference is the base one.
+  mlir::Operation *op = varPtr.getDefiningOp();
+  if (!op)
+    return varPtr;
+
+  // Look to find if this value originates from an interior pointer
+  // calculation op. The base may be a descriptor (e.g. !fir.box) rather than
+  // a pointer-like value, so it is returned as a plain mlir::Value.
+  return llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
+      .Case<hlfir::DesignateOp>([&](auto op) {
+        // Get the base object.
+        return op.getMemref();
+      })
+      .Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp>([&](auto op) {
+        // Get the base array on which the coordinate is being applied.
+        return op.getMemref();
+      })
+      .Case<fir::CoordinateOp>([&](auto op) {
+        // For coordinate operation which is applied on derived type
+        // object, get the base object.
+        return op.getRef();
+      })
+      .Default([&](mlir::Operation *) { return varPtr; });
+}
+
+static mlir::acc::VariableTypeCategory
+categorizePointee(mlir::Type pointer,
+                  mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+                  mlir::Type varType) {
+  // FIR uses operations to compute interior pointers: an array element and a
+  // composite f32 field are both !fir.ref<f32>. Such a reference must not be
+  // treated as a scalar, so unwrap the interior pointer calculation first.
+  mlir::Value baseRef = getBaseRef(varPtr);
+  // A descriptor base (e.g. a boxed array being indexed) categorizes itself.
+  if (auto baseTy = mlir::dyn_cast<mlir::acc::MappableType>(baseRef.getType()))
+    return baseTy.getTypeCategory(baseRef);
+  mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(baseRef.getType());
+  if (!eleTy) // Defensive: unrecognized base, use varPtr's own pointee type.
+    eleTy = varPtr.getType().getElementType();
+  if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
+    return mappableTy.getTypeCategory(varPtr);
+  if (isScalarLike(eleTy))
+    return mlir::acc::VariableTypeCategory::scalar;
+  if (isArrayLike(eleTy))
+    return mlir::acc::VariableTypeCategory::array;
+  if (isCompositeLike(eleTy))
+    return mlir::acc::VariableTypeCategory::composite;
+  if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy))
+    return mlir::acc::VariableTypeCategory::nonscalar;
+  // "Pointers" - from a raw-address point of view - are considered scalars;
+  // this may already be covered by isScalarLike (fir::isa_ref_type).
+  if (mlir::isa<fir::LLVMPointerType>(eleTy))
+    return mlir::acc::VariableTypeCategory::scalar;
+
+  // Without further checking, this type cannot be categorized.
+  return mlir::acc::VariableTypeCategory::uncategorized;
+}
+
+template <> // All FIR pointer-like types below share categorizePointee.
+mlir::acc::VariableTypeCategory
+OpenACCPointerLikeModel<fir::ReferenceType>::getPointeeTypeCategory(
+    mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+    mlir::Type varType) const {
+  return categorizePointee(pointer, varPtr, varType);
+}
+
+template <>
+mlir::acc::VariableTypeCategory
+OpenACCPointerLikeModel<fir::PointerType>::getPointeeTypeCategory(
+    mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+    mlir::Type varType) const {
+  return categorizePointee(pointer, varPtr, varType);
+}
+
+template <>
+mlir::acc::VariableTypeCategory
+OpenACCPointerLikeModel<fir::HeapType>::getPointeeTypeCategory(
+    mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+    mlir::Type varType) const {
+  return categorizePointee(pointer, varPtr, varType);
+}
+
+template <>
+mlir::acc::VariableTypeCategory
+OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory(
+    mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+    mlir::Type varType) const {
+  return categorizePointee(pointer, varPtr, varType);
+}
+
 } // namespace fir::acc

diff  --git a/flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp
index 34ea122f6b997..184a264c64325 100644
--- a/flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp
@@ -22,6 +22,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
     fir::SequenceType::attachInterface<OpenACCMappableModel<fir::SequenceType>>(
         *ctx);
     fir::BoxType::attachInterface<OpenACCMappableModel<fir::BaseBoxType>>(*ctx);
+
+    fir::ReferenceType::attachInterface<
+        OpenACCPointerLikeModel<fir::ReferenceType>>(*ctx);
+    fir::PointerType::attachInterface<
+        OpenACCPointerLikeModel<fir::PointerType>>(*ctx);
+    fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
+        *ctx);
+    fir::LLVMPointerType::attachInterface<
+        OpenACCPointerLikeModel<fir::LLVMPointerType>>(*ctx);
   });
 }
 

diff  --git a/flang/test/Fir/OpenACC/openacc-mappable.fir b/flang/test/Fir/OpenACC/openacc-mappable.fir
index 438cb29b991c7..005f002c491a5 100644
--- a/flang/test/Fir/OpenACC/openacc-mappable.fir
+++ b/flang/test/Fir/OpenACC/openacc-mappable.fir
@@ -19,7 +19,9 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>,
 
 // CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
 // CHECK: Mappable: !fir.box<!fir.array<10xf32>>
+// CHECK: Type category: array
 // CHECK: Size: 40
 // CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
 // CHECK: Mappable: !fir.array<10xf32>
+// CHECK: Type category: array
 // CHECK: Size: 40

diff  --git a/flang/test/Fir/OpenACC/openacc-type-categories.f90 b/flang/test/Fir/OpenACC/openacc-type-categories.f90
new file mode 100644
index 0000000000000..c25c38422b755
--- /dev/null
+++ b/flang/test/Fir/OpenACC/openacc-type-categories.f90
@@ -0,0 +1,49 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
+
+program main ! One variable per kind of data mapped by the directives below.
+  real :: scalar
+  real, allocatable :: scalaralloc
+  type tt
+    real :: field
+    real :: fieldarray(10)
+  end type tt
+  type(tt) :: ttvar
+  real :: arrayconstsize(10)
+  real, allocatable :: arrayalloc(:)
+  complex :: complexvar
+  character*1 :: charvar
+
+  !$acc enter data copyin(scalar, scalaralloc, ttvar, arrayconstsize, arrayalloc)
+  !$acc enter data copyin(complexvar, charvar, ttvar%field, ttvar%fieldarray, arrayconstsize(1))
+end program
+
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalar", structured = false}
+! CHECK: Pointer-like: !fir.ref<f32>
+! CHECK: Type category: scalar
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalaralloc", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<f32>>>
+! CHECK: Type category: nonscalar
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.type<_QFTtt{field:f32,fieldarray:!fir.array<10xf32>}>>
+! CHECK: Type category: composite
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
+! CHECK: Type category: array
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayalloc", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
+! CHECK: Type category: array
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "complexvar", structured = false}
+! CHECK: Pointer-like: !fir.ref<complex<f32>>
+! CHECK: Type category: scalar
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "charvar", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.char<1>>
+! CHECK: Type category: nonscalar
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%field", structured = false}
+! CHECK: Pointer-like: !fir.ref<f32>
+! CHECK: Type category: composite
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%fieldarray", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
+! CHECK: Type category: array
+! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize(1)", structured = false}
+! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
+! CHECK: Type category: array

diff  --git a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
index 5c14809a265e1..90aabd7d40d44 100644
--- a/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
+++ b/flang/test/lib/OpenACC/TestOpenACCInterfaces.cpp
@@ -32,15 +32,43 @@ struct TestFIROpenACCInterfaces
     mlir::OpBuilder builder(mod);
     getOperation().walk([&](Operation *op) {
       if (isa<ACC_DATA_ENTRY_OPS>(op)) {
-        Type typeOfVar = acc::getVar(op).getType();
-        llvm::errs() << "Visiting: " << *op << "\n";
+        Value var = acc::getVar(op);
+        Type typeOfVar = var.getType();
+
+        // Attempt to determine if the variable is mappable-like or if
+        // the pointee itself is mappable-like. For example, if the variable is
+        // of type !fir.ref<!fir.box<>>, we want to print both the details about
+        // the !fir.ref since it is pointer-like, and about !fir.box since it
+        // is mappable.
         auto mappableTy = dyn_cast_if_present<acc::MappableType>(typeOfVar);
         if (!mappableTy) {
           mappableTy =
               dyn_cast_if_present<acc::MappableType>(acc::getVarType(op));
         }
+
+        llvm::errs() << "Visiting: " << *op << "\n";
+        llvm::errs() << "\tVar: " << var << "\n";
+
+        if (auto ptrTy = dyn_cast_if_present<acc::PointerLikeType>(typeOfVar)) {
+          llvm::errs() << "\tPointer-like: " << typeOfVar << "\n";
+          // If the pointee is not mappable, print details about it. Otherwise,
+          // we defer to the mappable printing below to print those details.
+          if (!mappableTy) {
+            acc::VariableTypeCategory typeCategory =
+                ptrTy.getPointeeTypeCategory(
+                    cast<TypedValue<acc::PointerLikeType>>(var),
+                    acc::getVarType(op));
+            llvm::errs() << "\t\tType category: " << typeCategory << "\n";
+          }
+        }
+
         if (mappableTy) {
           llvm::errs() << "\tMappable: " << mappableTy << "\n";
+
+          acc::VariableTypeCategory typeCategory =
+              mappableTy.getTypeCategory(var);
+          llvm::errs() << "\t\tType category: " << typeCategory << "\n";
+
           if (datalayout.has_value()) {
             auto size = mappableTy.getSizeInBytes(
                 acc::getVar(op), acc::getBounds(op), datalayout.value());
@@ -61,10 +89,6 @@ struct TestFIROpenACCInterfaces
               llvm::errs() << "\t\tBound[" << idx << "]: " << bound << "\n";
             }
           }
-        } else {
-          assert(acc::isPointerLikeType(typeOfVar) &&
-              "expected to be pointer-like");
-          llvm::errs() << "\tPointer-like: " << typeOfVar << "\n";
         }
       }
     });


        


More information about the flang-commits mailing list