[flang-commits] [flang] [flang][acc] Implement type categorization for FIR types (PR #126964)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Wed Feb 12 13:30:31 PST 2025


================
@@ -224,4 +226,157 @@ OpenACCMappableModel<fir::BaseBoxType>::generateAccBounds(
   return {};
 }
 
+/// Returns true for types treated as scalars when categorizing a variable:
+/// integer/index/floating-point types, mlir::ComplexType, fir::LogicalType,
+/// and FIR reference types (an address is counted as a scalar value too).
+static bool isScalarLike(mlir::Type type) {
+  if (type.isIntOrIndexOrFloat())
+    return true;
+  if (mlir::isa<mlir::ComplexType, fir::LogicalType>(type))
+    return true;
+  return fir::isa_ref_type(type);
+}
+
+/// Returns true when `type` is a fir::SequenceType.
+static bool isArrayLike(mlir::Type type) {
+  const bool isSequence = mlir::isa<fir::SequenceType>(type);
+  return isSequence;
+}
+
+/// Returns true for member-based aggregate types: fir::RecordType,
+/// fir::ClassType and mlir::TupleType.
+static bool isCompositeLike(mlir::Type type) {
+  if (mlir::isa<fir::RecordType>(type) || mlir::isa<fir::ClassType>(type))
+    return true;
+  return mlir::isa<mlir::TupleType>(type);
+}
+
+/// Every fir::SequenceType is categorized as an array; neither the type's
+/// details nor the associated value influence the answer.
+template <>
+mlir::acc::VariableTypeCategory
+OpenACCMappableModel<fir::SequenceType>::getTypeCategory(
+    mlir::Type type, mlir::Value var) const {
+  using Category = mlir::acc::VariableTypeCategory;
+  return Category::array;
+}
+
+/// Categorizes a boxed (descriptor-based) entity by the type it wraps, as
+/// returned by fir::dyn_cast_ptrOrBoxEleTy.
+///
+/// Precedence:
+///   1. a wrapped type implementing MappableType categorizes itself;
+///   2. array-like wrapped types are "array";
+///   3. composite-like wrapped types are "composite";
+///   4. everything else, including a boxed scalar, is "nonscalar".
+template <>
+mlir::acc::VariableTypeCategory
+OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
+                                                        mlir::Value var) const {
+  using Category = mlir::acc::VariableTypeCategory;
+  mlir::Type boxedTy = fir::dyn_cast_ptrOrBoxEleTy(type);
+
+  // Let a mappable wrapped type provide its own category.
+  if (auto mappable = mlir::dyn_cast<mlir::acc::MappableType>(boxedTy))
+    return mappable.getTypeCategory(var);
+
+  // Arrays are reported as "array" whether they are allocatable, pointer,
+  // assumed, etc.
+  if (isArrayLike(boxedTy))
+    return Category::array;
+
+  // Having a descriptor already makes this an aggregate; "composite" is the
+  // more precise answer when it applies. A scalar wrapped in a box is still
+  // reported as "nonscalar", and anything else would be non-scalar anyway.
+  return isCompositeLike(boxedTy) ? Category::composite : Category::nonscalar;
+}
+
+/// Looks through one interior-pointer computation to find the value it was
+/// derived from.
+///
+/// If `varPtr` is produced by an hlfir::DesignateOp, fir::ArrayCoorOp,
+/// fir::cg::XArrayCoorOp or fir::CoordinateOp, the base operand of that op
+/// is returned. If `varPtr` has no defining op, or any other op defines it,
+/// `varPtr` itself is returned.
+///
+/// The base operand is only returned when it also has a PointerLikeType.
+/// Nothing here guarantees that property: NOTE(review) the base may, for
+/// instance, be a descriptor/box value; confirm against the op definitions.
+/// An unconditional mlir::cast would assert on such a value, so in that case
+/// `varPtr` is conservatively kept.
+static mlir::TypedValue<mlir::acc::PointerLikeType>
+getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
+  // If there is no defining op - the unwrapped reference is the base one.
+  mlir::Operation *defOp = varPtr.getDefiningOp();
+  if (!defOp)
+    return varPtr;
+
+  // Look to find if this value originates from an interior pointer
+  // calculation op.
+  mlir::Value baseRef =
+      llvm::TypeSwitch<mlir::Operation *, mlir::Value>(defOp)
+          .Case<hlfir::DesignateOp>([](auto op) {
+            // Get the base object.
+            return op.getMemref();
+          })
+          .Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp>([](auto op) {
+            // Get the base array on which the coordinate is being applied.
+            return op.getMemref();
+          })
+          .Case<fir::CoordinateOp>([](auto op) {
+            // For coordinate operation which is applied on derived type
+            // object, get the base object.
+            return op.getRef();
+          })
+          .Default([&](mlir::Operation *) { return varPtr; });
+
+  // Only accept the base when it is itself pointer-like; otherwise keep the
+  // (interior) pointer we were given instead of asserting in mlir::cast.
+  if (auto basePtr =
+          mlir::dyn_cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef))
+    return basePtr;
+  return varPtr;
+}
+
+static mlir::acc::VariableTypeCategory
+categorizePointee(mlir::Type pointer,
+                  mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
+                  mlir::Type varType) {
+  // FIR uses operations to compute interior pointers.
+  // So for example, an array element or composite field access to a float
+  // value would both be represented as !fir.ref<f32>. We do not want to treat
+  // such a reference as a scalar. Thus unwrap interior pointer calculations.
+  auto baseRef = getBaseRef(varPtr);
+  mlir::Type eleTy = baseRef.getType().getElementType();
+
+  if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy)) {
+    return mappableTy.getTypeCategory(varPtr);
+  }
+
+  if (isScalarLike(eleTy)) {
+    return mlir::acc::VariableTypeCategory::scalar;
+  }
+  if (isArrayLike(eleTy)) {
+    return mlir::acc::VariableTypeCategory::array;
+  }
+  if (isCompositeLike(eleTy)) {
+    return mlir::acc::VariableTypeCategory::composite;
+  }
+  if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy)) {
+    return mlir::acc::VariableTypeCategory::nonscalar;
+  }
+  // "pointers" - in the sense of raw address point-of-view, are considered
+  // scalars. However
+  if (mlir::isa<fir::LLVMPointerType>(eleTy)) {
+    return mlir::acc::VariableTypeCategory::scalar;
+  }
----------------
clementval wrote:

braces on each if statement

https://github.com/llvm/llvm-project/pull/126964


More information about the flang-commits mailing list