[Mlir-commits] [mlir] 89bb99d - [acc][flang] Implement acc interface for tracking type descriptors (#168982)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Nov 21 07:40:09 PST 2025


Author: Razvan Lupusoru
Date: 2025-11-21T07:40:04-08:00
New Revision: 89bb99dd0ed9c7e5dbbc80c6cfb369768bfee96b

URL: https://github.com/llvm/llvm-project/commit/89bb99dd0ed9c7e5dbbc80c6cfb369768bfee96b
DIFF: https://github.com/llvm/llvm-project/commit/89bb99dd0ed9c7e5dbbc80c6cfb369768bfee96b.diff

LOG: [acc][flang] Implement acc interface for tracking type descriptors (#168982)

FIR operations that use derived types need to have type descriptor
globals available on device when offloading. Examples of this can be
seen in `CUFDeviceGlobal` which ensures that such type descriptor uses
work on device for CUF.

Similarly, this is needed for OpenACC. This change introduces a new
interface to the OpenACC dialect named
`IndirectGlobalAccessOpInterface` which can be attached to operations
that may result in generation of accesses that use type descriptor
globals. This functionality is needed for the `ACCImplicitDeclare` pass
that is coming in a follow-up change which implicitly ensures that all
referenced globals are available in OpenACC compute contexts.

The interface provides a `getReferencedSymbols` method that collects all
global symbols referenced by an operation. When a symbol table is
provided, the implementation for FIR recursively walks type descriptor
globals to find all transitively referenced symbols.

Note that this could alternatively have been implemented in other
ways:
- Codegen could implicitly generate such type globals as needed by
changing the technique that relies on populating them during lowering
(e.g., generate them directly in gpu.module during codegen).
- This interface could attach to types instead of operations for a
potentially more conservative implementation which maps all type
descriptors even if the underlying implementation using it won't
necessarily need such mapping.

The technique chosen here is consistent with `CUFDeviceGlobal` (which
walks operations inside `prepareImplicitDeviceGlobals`) and avoids
conservative mapping of all type descriptors.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
    flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
    flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
    mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index bf87654979cc9..0020e1ab21a56 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -67,6 +67,15 @@ struct GlobalVariableModel
   bool isConstant(mlir::Operation *op) const;
 };
 
+/// External model that attaches `mlir::acc::IndirectGlobalAccessOpInterface`
+/// to the FIR operation `Op`. `getReferencedSymbols` is only declared here;
+/// its definitions are provided as explicit specializations, one per
+/// supported operation.
+template <typename Op>
+struct IndirectGlobalAccessModel
+    : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
+          IndirectGlobalAccessModel<Op>, Op> {
+  /// Appends to `symbols` the global symbols that `op` may reference
+  /// indirectly. `symbolTable` may be null.
+  void getReferencedSymbols(mlir::Operation *op,
+                            llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+                            mlir::SymbolTable *symbolTable) const;
+};
+
 } // namespace fir::acc
 
 #endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_

diff  --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index 11fbaf2dc2bb8..902a2ecdec35f 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -14,6 +14,9 @@
 
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallSet.h"
 
 namespace fir::acc {
 
@@ -68,4 +71,97 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
   return globalOp.getConstant().has_value();
 }
 
+// Resolves the symbol addressed by `addrOfOp` in `symTab`. If it names a
+// fir.global that is not yet in `globalsSet`, the global is recorded, the
+// symbol is appended to `symbols`, and every address-of operation nested
+// inside that global is processed the same way.
+static void processAddrOfOpInDerivedTypeDescriptor(
+    fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab,
+    llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
+  auto target = symTab.lookup<fir::GlobalOp>(
+      addrOfOp.getSymbol().getLeafReference().getValue());
+  // Nothing to do when the symbol is not a fir.global in this table, or when
+  // the global was already visited (insert reports whether it was new).
+  if (!target || !globalsSet.insert(target).second)
+    return;
+  symbols.push_back(addrOfOp.getSymbolAttr());
+  target.walk([&](fir::AddrOfOp inner) {
+    processAddrOfOpInDerivedTypeDescriptor(inner, symTab, globalsSet, symbols);
+  });
+}
+
+// Appends to `symbols` the type descriptor global(s) needed by `ty`.
+// `ty` is first passed through fir::unwrapRefType and fir::getDerivedType;
+// only a resulting derived (record) type contributes anything. This is the
+// shared logic for operations that may require type descriptor globals.
+static void collectReferencedSymbolsForType(
+    mlir::Type ty, mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) {
+  auto recordTy = mlir::dyn_cast_if_present<fir::RecordType>(
+      fir::getDerivedType(fir::unwrapRefType(ty)));
+  if (!recordTy)
+    return;
+
+  const auto descName =
+      fir::NameUniquer::getTypeDescriptorName(recordTy.getName());
+
+  // Without a symbol table nothing can be resolved: report the type
+  // descriptor name as-is and stop.
+  if (!symbolTable) {
+    symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(), descName));
+    return;
+  }
+
+  // Look the descriptor up by its regular name first, then by its assembly
+  // name. If neither exists there is nothing to report.
+  auto descGlobal = symbolTable->lookup<fir::GlobalOp>(descName);
+  if (!descGlobal)
+    descGlobal = symbolTable->lookup<fir::GlobalOp>(
+        fir::NameUniquer::getTypeDescriptorAssemblyName(recordTy.getName()));
+  if (!descGlobal)
+    return;
+
+  // Seed the visited set with the descriptor itself, then follow every
+  // address-of nested in it to pick up the globals it references.
+  llvm::SmallSet<mlir::Operation *, 16> visited;
+  visited.insert(descGlobal);
+  symbols.push_back(
+      mlir::SymbolRefAttr::get(op->getContext(), descGlobal.getSymName()));
+  descGlobal.walk([&](fir::AddrOfOp addrOf) {
+    processAddrOfOpInDerivedTypeDescriptor(addrOf, *symbolTable, visited,
+                                           symbols);
+  });
+}
+
+// fir::AllocaOp model: symbols are collected from the type returned by the
+// operation's `getType()`.
+template <>
+void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  collectReferencedSymbolsForType(mlir::cast<fir::AllocaOp>(op).getType(), op,
+                                  symbols, symbolTable);
+}
+
+// fir::EmboxOp model: symbols are collected from the type of the memref
+// operand.
+template <>
+void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto memref = mlir::cast<fir::EmboxOp>(op).getMemref();
+  collectReferencedSymbolsForType(memref.getType(), op, symbols, symbolTable);
+}
+
+// fir::ReboxOp model: symbols are collected from the type of the box
+// operand.
+template <>
+void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto box = mlir::cast<fir::ReboxOp>(op).getBox();
+  collectReferencedSymbolsForType(box.getType(), op, symbols, symbolTable);
+}
+
+// fir::TypeDescOp model: symbols are collected from the type returned by
+// the operation's `getInType()`.
+template <>
+void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto describedTy = mlir::cast<fir::TypeDescOp>(op).getInType();
+  collectReferencedSymbolsForType(describedTy, op, symbols, symbolTable);
+}
+
 } // namespace fir::acc

diff  --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 5c7f9985d41ca..acd1d01ef1e87 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
 
     fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
     fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
+
+    fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
+        *ctx);
+    fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
+        *ctx);
+    fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
+        *ctx);
+    fir::TypeDescOp::attachInterface<
+        IndirectGlobalAccessModel<fir::TypeDescOp>>(*ctx);
   });
 
   // Register HLFIR operation interfaces

diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 6b0c84d31d1ba..ec41826b2bbc8 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
   ];
 }
 
+// Op interface exposing a single query, `getReferencedSymbols`, which fills
+// a caller-provided vector with the global symbols an operation references
+// indirectly.
+//
+// NOTE(review): the description below states that implementations look up
+// their own symbol table when `symbolTable` is null. The FIR external models
+// added alongside this interface do not: with a null table they only append
+// the type descriptor name, without resolving it or following nested
+// references. Confirm which behavior is the intended contract.
+def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> {
+  let cppNamespace = "::mlir::acc";
+
+  let description = [{
+    An interface for operations that indirectly access global symbols.
+    This interface provides a way to query which global symbols are referenced
+    by an operation, which is useful for tracking dependencies and performing
+    analysis on global variable usage.
+
+    The symbolTable parameter is optional. If null, implementations will look up
+    their own symbol table. This allows callers to pass a pre-existing symbol
+    table for efficiency when querying multiple operations.
+  }];
+
+  let methods = [
+    // `symbols` is an output parameter; `symbolTable` may be null.
+    InterfaceMethod<"Get the symbols referenced by this operation",
+      "void",
+      "getReferencedSymbols",
+      (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols,
+           "::mlir::SymbolTable *":$symbolTable)>,
+  ];
+}
+
 #endif // OPENACC_OPS_INTERFACES


        


More information about the Mlir-commits mailing list