[llvm-branch-commits] [flang] [Flang][OpenMP] Derived type explicit allocatable member mapping (PR #113557)
Sergio Afonso via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Nov 4 04:29:52 PST 2024
================
@@ -145,11 +145,294 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
builder.getStringAttr(name), builder.getBoolAttr(partialMap));
-
return op;
}
-static int
+// This function gathers the individual omp::Object's that make up a
+// larger omp::Object symbol.
+//
+// For example, given the larger symbol "parent%child%member", this function
+// breaks it up into its constituent components ("parent", "child",
+// "member") so that each individual component can be accessed and its
+// details introspected. The chain is discovered from RHS to LHS ("member"
+// to "parent") by repeatedly asking getBaseObject for the enclosing object
+// until none is returned; the collected list is then reversed so that the
+// returned omp::ObjectList is ordered LHS to RHS, with "parent" first.
+omp::ObjectList gatherObjectsOf(omp::Object derivedTypeMember,
+                                semantics::SemanticsContext &semaCtx) {
+  omp::ObjectList memberChain;
+  for (std::optional<omp::Object> current = derivedTypeMember;
+       current.has_value(); current = getBaseObject(*current, semaCtx))
+    memberChain.push_back(*current);
+  // Collected innermost-first; callers expect outermost ("parent") first.
+  return omp::ObjectList{llvm::reverse(memberChain)};
+}
+
+// This function generates a series of zero-based indices from a provided
+// omp::Object that devolves to an ArrayRef symbol. For example, given
+// "array(2,3,4)" this function generates the indices "[1][2][3]": 1 is
+// subtracted from each subscript to convert Fortran's default one-based
+// indexing into the zero-based indexing used by coordinate operations.
+//
+// These indices can then be provided to a coordinate operation or other
+// GEP-like operation to access the relevant positional member of the
+// array.
+//
+// NOTE(review): the offset is always the constant 1, so arrays declared
+// with a non-default lower bound are not accounted for here -- confirm that
+// callers only reach this path with default lower bounds.
+//
+// Only scalar integer subscripts are currently supported; a Triplet
+// subscript (e.g. Array(1:2:3)) aborts compilation with a fatal error.
+// Objects that are not an ArrayRef leave `indices` untouched.
+static void generateArrayIndices(lower::AbstractConverter &converter,
+                                 fir::FirOpBuilder &firOpBuilder,
+                                 lower::StatementContext &stmtCtx,
+                                 mlir::Location clauseLocation,
+                                 llvm::SmallVectorImpl<mlir::Value> &indices,
+                                 omp::Object object) {
+  auto maybeRef = evaluate::ExtractDataRef(*object.ref());
+  if (!maybeRef)
+    return;
+
+  auto *arr = std::get_if<evaluate::ArrayRef>(&maybeRef->u);
+  if (!arr || arr->subscript().empty())
+    return;
+
+  // The one-based -> zero-based offset is loop-invariant; materialize the
+  // constant once instead of once per subscript.
+  mlir::Value one = firOpBuilder.createIntegerConstant(
+      clauseLocation, firOpBuilder.getIndexType(), 1);
+
+  // Bind by const reference: each subscript holds an indirect expression
+  // that would otherwise be deep-copied on every iteration.
+  for (const auto &v : arr->subscript()) {
+    // A triplet here comes straight from user source, so this path is
+    // reachable; llvm_unreachable would be undefined behaviour in release
+    // builds, hence fail deterministically instead.
+    if (std::holds_alternative<Triplet>(v.u))
+      llvm::report_fatal_error(
+          "Triplet indexing in map clause is unsupported");
+
+    const auto &expr =
+        std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(v.u);
+    mlir::Value subscript =
+        fir::getBase(converter.genExprValue(toEvExpr(expr.value()), stmtCtx));
+    subscript = firOpBuilder.createConvert(
+        clauseLocation, firOpBuilder.getIndexType(), subscript);
+    indices.push_back(firOpBuilder.create<mlir::arith::SubIOp>(
+        clauseLocation, subscript, one));
+  }
+}
+
+/// When mapping members of derived types, there is a chance that one of the
+/// members along the way to a mapped member is a descriptor, in which case
+/// we have to make sure we generate a map for those along the way otherwise
+/// we will be missing a chunk of data required to actually map the member
+/// type to device. This function effectively generates these maps and the
+/// appropriate data accesses required to generate these maps. It will avoid
+/// creating duplicate maps, as duplicates are just as bad as unmapped
+/// descriptor data in a lot of cases for the runtime (and unnecessary
+/// data movement should be avoided where possible).
+///
+/// As an example for the following mapping:
+///
+/// type :: vertexes
+/// integer(4), allocatable :: vertexx(:)
+/// integer(4), allocatable :: vertexy(:)
+/// end type vertexes
+///
+/// type :: dtype
+/// real(4) :: i
+/// type(vertexes), allocatable :: vertexes(:)
+/// end type dtype
+///
+/// type(dtype), allocatable :: alloca_dtype
+///
+/// !$omp target map(tofrom: alloca_dtype%vertexes(N1)%vertexx)
+///
+/// The below HLFIR/FIR is generated (trimmed for conciseness):
+///
+/// On the first iteration we index into the record type alloca_dtype
+/// to access "vertexes", we then generate a map for this descriptor
+/// alongside bounds to indicate we only need the 1 member, rather than
+/// the whole array block in this case (In theory we could map its
+/// entirety at the cost of data transfer bandwidth).
+///
+/// %13:2 = hlfir.declare ... "alloca_dtype" ...
+/// %39 = fir.load %13#0 : ...
+/// %40 = fir.coordinate_of %39, %c1 : ...
+/// %51 = omp.map.info var_ptr(%40 : ...) map_clauses(to) capture(ByRef) ...
+/// %52 = fir.load %40 : ...
+///
+/// Second iteration generating access to "vertexes(N1)" utilising the N1 index
+/// %53 = load N1 ...
+/// %54 = fir.convert %53 : (i32) -> i64
+/// %55 = fir.convert %54 : (i64) -> index
+/// %56 = arith.subi %55, %c1 : index
+/// %57 = fir.coordinate_of %52, %56 : ...
+///
+/// Still in the second iteration we access the allocatable member "vertexx",
+/// we return %58 from the function and provide it to the final and "main"
+/// map of processMap (generated by the record type segment of the below
+/// function), if this were not the final symbol in the list, i.e. we accessed
+/// a member below vertexx, we would have generated the map below as we did in
+/// the first iteration and then continue to generate further coordinates to
+/// access further components as required.
+///
+/// %58 = fir.coordinate_of %57, %c0 : ...
+/// %61 = omp.map.info var_ptr(%58 : ...) map_clauses(to) capture(ByRef) ...
+///
+/// Parent mapping containing prior generated mapped members, generated at
+/// a later step but here to showcase the "end" result
+///
+/// omp.map.info var_ptr(%13#1 : ...) map_clauses(to) capture(ByRef)
+/// members(%50, %61 : [0, 1, 0], [0, 1, 0] : ...
+///
+/// \param objectList - The list of omp::Object symbol data for each parent
+/// to the mapped member (also includes the mapped member), generated via
+/// gatherObjectsOf.
+/// \param indices - List of index data associated with the mapped member
+/// symbol, which identifies the placement of the member in its parent,
+/// this helps generate the appropriate member accesses. These indices
+/// can be generated via generateMemberPlacementIndices.
+/// \param asFortran - A string generated from the mapped variable to be
+/// associated with the main map, generally (but not restricted to)
+/// generated via gatherDataOperandAddrAndBounds or other
+/// DirectiveCommons.hpp utilities.
+/// \param mapTypeBits - The map flags that will be associated with the
+/// generated maps, minus alterations of the TO and FROM bits for the
+/// intermediate components to prevent accidental overwriting on device
+/// write back.
+mlir::Value createParentSymAndGenIntermediateMaps(
+ mlir::Location clauseLocation, lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx, lower::StatementContext &stmtCtx,
+ omp::ObjectList &objectList, llvm::SmallVector<int64_t> &indices,
+ OmpMapParentAndMemberData &parentMemberIndices, std::string asFortran,
+ llvm::omp::OpenMPOffloadMappingFlags mapTypeBits) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ /// Checks if an omp::Object is an array expression with a subscript, e.g.
+ /// array(1,2).
+ auto arrayExprWithSubscript = [](omp::Object obj) {
+ if (auto maybeRef = evaluate::ExtractDataRef(*obj.ref())) {
+ evaluate::DataRef ref = *maybeRef;
+ if (auto *arr = std::get_if<evaluate::ArrayRef>(&ref.u))
+ return !arr->subscript().empty();
+ }
+ return false;
+ };
+
+ // Generate the access to the original parent base address.
+ lower::AddrAndBoundsInfo parentBaseAddr = lower::getDataOperandBaseAddr(
+ converter, firOpBuilder, *objectList[0].sym(), clauseLocation);
+ mlir::Value curValue = parentBaseAddr.addr;
+
+ // Iterate over all objects in the objectList, this should consist of all
+ // record types between the parent and the member being mapped (including
+ // the parent). The object list may also contain array objects as well,
+ // this can occur when specifying bounds or a specific element access
+ // within a member map, we skip these.
+ size_t currentIndex = 0;
+ for (size_t i = 0; i < objectList.size(); ++i) {
+ // If we encounter a sequence type, i.e. an array, we must generate the
+ // correct coordinate operation to index into the array to proceed further,
+ // this is only relevant in cases where we encounter subscripts currently.
+ //
+ // For example in the following case:
+ //
+ // map(tofrom: array_dtype(4)%internal_dtypes(3)%float_elements(4))
+ //
+ // We must generate coordinate operation accesses for each subscript
+ // we encounter.
+ if (fir::SequenceType arrType = mlir::dyn_cast<fir::SequenceType>(
+ fir::unwrapPassByRefType(curValue.getType()))) {
+ if (arrayExprWithSubscript(objectList[i])) {
+ llvm::SmallVector<mlir::Value> indices;
----------------
skatrak wrote:
Nit: Rename local variable to avoid confusing it with the function argument.
https://github.com/llvm/llvm-project/pull/113557
More information about the llvm-branch-commits
mailing list