[Mlir-commits] [flang] [mlir] [acc][flang] Implement acc interface for tracking type descriptors (PR #168982)

Razvan Lupusoru llvmlistbot at llvm.org
Thu Nov 20 20:03:31 PST 2025


https://github.com/razvanlupusoru updated https://github.com/llvm/llvm-project/pull/168982

>From adb9b5c516d2fd7dba02d71c3e19284069f9ca07 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 20 Nov 2025 17:38:42 -0800
Subject: [PATCH 1/3] [acc][flang] Implement acc interface for tracking type
 descriptors

FIR operations that use derived types need to have type
descriptor globals available on device when offloading.
An example of this can be seen in `CUFDeviceGlobal`, which ensures
that such type descriptor uses work on device for CUF.

Similarly, this is needed for OpenACC. This change introduces
a new interface to the OpenACC dialect named
`IndirectGlobalAccessOpInterface` which can be attached to
operations that may result in generation of accesses that use
type descriptor globals. This functionality is needed for the
`ACCImplicitDeclare` pass that is coming in a follow-up change
which implicitly ensures that all referenced globals are
available in OpenACC compute contexts.

The interface provides a `getReferencedSymbols` method that
collects all global symbols referenced by an operation.
When a symbol table is provided, the implementation for FIR
recursively walks type descriptor globals to find all
transitively referenced symbols. Without a symbol table, it
only reports the type descriptor's name, without looking it up.

Note that this could alternatively have been implemented in
other ways:
- Codegen could implicitly generate such type globals as
needed by changing the technique that relies on populating
them during lowering (e.g., generate them directly in the
gpu.module during codegen).
- This interface could be attached to types instead of operations,
for a more conservative implementation that maps all type
descriptors even when the code using them does not need such
mapping.

The technique chosen here is consistent with `CUFDeviceGlobal`
(which walks operations inside `prepareImplicitDeviceGlobals`)
and avoids conservative mapping of all type descriptors.
---
 .../OpenACC/Support/FIROpenACCOpsInterfaces.h |  10 ++
 .../Support/FIROpenACCOpsInterfaces.cpp       | 100 ++++++++++++++++++
 .../Support/RegisterOpenACCExtensions.cpp     |   9 ++
 .../Dialect/OpenACC/OpenACCOpsInterfaces.td   |  23 ++++
 4 files changed, 142 insertions(+)

diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index bf87654979cc9..87d60d489ba13 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -67,6 +67,16 @@ struct GlobalVariableModel
   bool isConstant(mlir::Operation *op) const;
 };
 
+template <typename Op>
+struct IndirectGlobalAccessModel
+    : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
+          IndirectGlobalAccessModel<Op>, Op> {
+  void getReferencedSymbols(
+      mlir::Operation *op,
+      llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+      mlir::SymbolTable *symbolTable) const;
+};
+
 } // namespace fir::acc
 
 #endif // FLANG_OPTIMIZER_OPENACC_FIROPENACC_OPS_INTERFACES_H_
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index 11fbaf2dc2bb8..aa62b5a9820ee 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -14,6 +14,9 @@
 
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
+#include "flang/Optimizer/Support/InternalNames.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/SmallSet.h"
 
 namespace fir::acc {
 
@@ -68,4 +71,101 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
   return globalOp.getConstant().has_value();
 }
 
+// Helper to recursively process address-of operations in derived type
+// descriptors and collect all needed fir.globals.
+static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp,
+    mlir::SymbolTable &symTab,
+    llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
+  if (auto globalOp = symTab.lookup<fir::GlobalOp>(
+          addrOfOp.getSymbol().getLeafReference().getValue())) {
+    if (globalsSet.contains(globalOp)) {
+      return;
+    }
+    globalsSet.insert(globalOp);
+    symbols.push_back(addrOfOp.getSymbolAttr());
+    globalOp.walk([&](fir::AddrOfOp op) {
+      processAddrOfOpInDerivedTypeDescriptor(
+          op, symTab, globalsSet, symbols);
+    });
+  }
+}
+
+// Utility to collect referenced symbols for type descriptors of derived types.
+// This is the common logic for operations that may require type descriptor
+// globals.
+static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) {
+  ty = fir::getDerivedType(fir::unwrapRefType(ty));
+
+  // Look for type descriptor globals only if it's a derived (record) type
+  if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
+    // If no symbol table provided, simply add the type descriptor name
+    if (!symbolTable) {
+      symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(),
+          fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
+      return;
+    }
+
+    // Otherwise, do full lookup and recursive processing
+    llvm::SmallSet<mlir::Operation *, 16> globalsSet;
+
+    fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
+        fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
+    if (!globalOp) {
+      globalOp = symbolTable->lookup<fir::GlobalOp>(
+          fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));
+    }
+    if (globalOp) {
+      globalsSet.insert(globalOp);
+      symbols.push_back(mlir::SymbolRefAttr::get(
+          op->getContext(), globalOp.getSymName()));
+      globalOp.walk([&](fir::AddrOfOp addrOp) {
+        processAddrOfOpInDerivedTypeDescriptor(
+            addrOp, *symbolTable, globalsSet, symbols);
+      });
+    }
+  }
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto allocaOp = mlir::cast<fir::AllocaOp>(op);
+  collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto emboxOp = mlir::cast<fir::EmboxOp>(op);
+  collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
+                                   symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto reboxOp = mlir::cast<fir::ReboxOp>(op);
+  collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
+                                   symbolTable);
+}
+
+template <>
+void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
+    mlir::Operation *op,
+    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::SymbolTable *symbolTable) const {
+  auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
+  collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
+                                   symbolTable);
+}
+
 } // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 5c7f9985d41ca..915518c8de6c7 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -52,6 +52,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
 
     fir::AddrOfOp::attachInterface<AddressOfGlobalModel>(*ctx);
     fir::GlobalOp::attachInterface<GlobalVariableModel>(*ctx);
+
+    fir::AllocaOp::attachInterface<IndirectGlobalAccessModel<fir::AllocaOp>>(
+        *ctx);
+    fir::EmboxOp::attachInterface<IndirectGlobalAccessModel<fir::EmboxOp>>(
+        *ctx);
+    fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
+        *ctx);
+    fir::TypeDescOp::attachInterface<IndirectGlobalAccessModel<fir::TypeDescOp>>(
+        *ctx);
   });
 
   // Register HLFIR operation interfaces
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
index 6b0c84d31d1ba..ec41826b2bbc8 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOpsInterfaces.td
@@ -75,4 +75,27 @@ def GlobalVariableOpInterface : OpInterface<"GlobalVariableOpInterface"> {
   ];
 }
 
+def IndirectGlobalAccessOpInterface : OpInterface<"IndirectGlobalAccessOpInterface"> {
+  let cppNamespace = "::mlir::acc";
+
+  let description = [{
+    An interface for operations that indirectly access global symbols.
+    This interface provides a way to query which global symbols are referenced
+    by an operation, which is useful for tracking dependencies and performing
+    analysis on global variable usage.
+
+    The symbolTable parameter is optional. Without it, implementations may only
+    report directly referenced symbols by name (no transitive walk); passing a
+    pre-existing table enables full resolution and reuse across operations.
+  }];
+
+  let methods = [
+    InterfaceMethod<"Append referenced global symbols to `symbols`",
+      "void",
+      "getReferencedSymbols",
+      (ins "::llvm::SmallVectorImpl<::mlir::SymbolRefAttr>&":$symbols,
+           "::mlir::SymbolTable *":$symbolTable)>,
+  ];
+}
+
 #endif // OPENACC_OPS_INTERFACES

>From e2ca6975219ae8d3940cb33f583a43160e3a46f1 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 20 Nov 2025 17:44:02 -0800
Subject: [PATCH 2/3] Fix formatting

---
 .../OpenACC/Support/FIROpenACCOpsInterfaces.h |  7 ++--
 .../Support/FIROpenACCOpsInterfaces.cpp       | 39 +++++++++----------
 .../Support/RegisterOpenACCExtensions.cpp     |  4 +-
 3 files changed, 23 insertions(+), 27 deletions(-)

diff --git a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
index 87d60d489ba13..0020e1ab21a56 100644
--- a/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
+++ b/flang/include/flang/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.h
@@ -71,10 +71,9 @@ template <typename Op>
 struct IndirectGlobalAccessModel
     : public mlir::acc::IndirectGlobalAccessOpInterface::ExternalModel<
           IndirectGlobalAccessModel<Op>, Op> {
-  void getReferencedSymbols(
-      mlir::Operation *op,
-      llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
-      mlir::SymbolTable *symbolTable) const;
+  void getReferencedSymbols(mlir::Operation *op,
+                            llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+                            mlir::SymbolTable *symbolTable) const;
 };
 
 } // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index aa62b5a9820ee..2e5d8a61b5b32 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -73,8 +73,8 @@ bool GlobalVariableModel::isConstant(mlir::Operation *op) const {
 
 // Helper to recursively process address-of operations in derived type
 // descriptors and collect all needed fir.globals.
-static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp,
-    mlir::SymbolTable &symTab,
+static void processAddrOfOpInDerivedTypeDescriptor(
+    fir::AddrOfOp addrOfOp, mlir::SymbolTable &symTab,
     llvm::SmallSet<mlir::Operation *, 16> &globalsSet,
     llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
   if (auto globalOp = symTab.lookup<fir::GlobalOp>(
@@ -85,8 +85,7 @@ static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp,
     globalsSet.insert(globalOp);
     symbols.push_back(addrOfOp.getSymbolAttr());
     globalOp.walk([&](fir::AddrOfOp op) {
-      processAddrOfOpInDerivedTypeDescriptor(
-          op, symTab, globalsSet, symbols);
+      processAddrOfOpInDerivedTypeDescriptor(op, symTab, globalsSet, symbols);
     });
   }
 }
@@ -94,7 +93,8 @@ static void processAddrOfOpInDerivedTypeDescriptor(fir::AddrOfOp addrOfOp,
 // Utility to collect referenced symbols for type descriptors of derived types.
 // This is the common logic for operations that may require type descriptor
 // globals.
-static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
+static void collectReferencedSymbolsForType(
+    mlir::Type ty, mlir::Operation *op,
     llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
     mlir::SymbolTable *symbolTable) {
   ty = fir::getDerivedType(fir::unwrapRefType(ty));
@@ -103,7 +103,8 @@ static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
   if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(ty)) {
     // If no symbol table provided, simply add the type descriptor name
     if (!symbolTable) {
-      symbols.push_back(mlir::SymbolRefAttr::get(op->getContext(),
+      symbols.push_back(mlir::SymbolRefAttr::get(
+          op->getContext(),
           fir::NameUniquer::getTypeDescriptorName(recTy.getName())));
       return;
     }
@@ -119,11 +120,11 @@ static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
     }
     if (globalOp) {
       globalsSet.insert(globalOp);
-      symbols.push_back(mlir::SymbolRefAttr::get(
-          op->getContext(), globalOp.getSymName()));
+      symbols.push_back(
+          mlir::SymbolRefAttr::get(op->getContext(), globalOp.getSymName()));
       globalOp.walk([&](fir::AddrOfOp addrOp) {
-        processAddrOfOpInDerivedTypeDescriptor(
-            addrOp, *symbolTable, globalsSet, symbols);
+        processAddrOfOpInDerivedTypeDescriptor(addrOp, *symbolTable, globalsSet,
+                                               symbols);
       });
     }
   }
@@ -131,8 +132,7 @@ static void collectReferencedSymbolsForType(mlir::Type ty, mlir::Operation *op,
 
 template <>
 void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
-    mlir::Operation *op,
-    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
     mlir::SymbolTable *symbolTable) const {
   auto allocaOp = mlir::cast<fir::AllocaOp>(op);
   collectReferencedSymbolsForType(allocaOp.getType(), op, symbols, symbolTable);
@@ -140,32 +140,29 @@ void IndirectGlobalAccessModel<fir::AllocaOp>::getReferencedSymbols(
 
 template <>
 void IndirectGlobalAccessModel<fir::EmboxOp>::getReferencedSymbols(
-    mlir::Operation *op,
-    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
     mlir::SymbolTable *symbolTable) const {
   auto emboxOp = mlir::cast<fir::EmboxOp>(op);
   collectReferencedSymbolsForType(emboxOp.getMemref().getType(), op, symbols,
-                                   symbolTable);
+                                  symbolTable);
 }
 
 template <>
 void IndirectGlobalAccessModel<fir::ReboxOp>::getReferencedSymbols(
-    mlir::Operation *op,
-    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
     mlir::SymbolTable *symbolTable) const {
   auto reboxOp = mlir::cast<fir::ReboxOp>(op);
   collectReferencedSymbolsForType(reboxOp.getBox().getType(), op, symbols,
-                                   symbolTable);
+                                  symbolTable);
 }
 
 template <>
 void IndirectGlobalAccessModel<fir::TypeDescOp>::getReferencedSymbols(
-    mlir::Operation *op,
-    llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
+    mlir::Operation *op, llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols,
     mlir::SymbolTable *symbolTable) const {
   auto typeDescOp = mlir::cast<fir::TypeDescOp>(op);
   collectReferencedSymbolsForType(typeDescOp.getInType(), op, symbols,
-                                   symbolTable);
+                                  symbolTable);
 }
 
 } // namespace fir::acc
diff --git a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
index 915518c8de6c7..acd1d01ef1e87 100644
--- a/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/RegisterOpenACCExtensions.cpp
@@ -59,8 +59,8 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
         *ctx);
     fir::ReboxOp::attachInterface<IndirectGlobalAccessModel<fir::ReboxOp>>(
         *ctx);
-    fir::TypeDescOp::attachInterface<IndirectGlobalAccessModel<fir::TypeDescOp>>(
-        *ctx);
+    fir::TypeDescOp::attachInterface<
+        IndirectGlobalAccessModel<fir::TypeDescOp>>(*ctx);
   });
 
   // Register HLFIR operation interfaces

>From 234a531e863ddbf7303979e73301afa2d70c2d94 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 20 Nov 2025 20:03:18 -0800
Subject: [PATCH 3/3] Fix braces

---
 .../Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp  | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
index 2e5d8a61b5b32..902a2ecdec35f 100644
--- a/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
+++ b/flang/lib/Optimizer/OpenACC/Support/FIROpenACCOpsInterfaces.cpp
@@ -79,9 +79,8 @@ static void processAddrOfOpInDerivedTypeDescriptor(
     llvm::SmallVectorImpl<mlir::SymbolRefAttr> &symbols) {
   if (auto globalOp = symTab.lookup<fir::GlobalOp>(
           addrOfOp.getSymbol().getLeafReference().getValue())) {
-    if (globalsSet.contains(globalOp)) {
+    if (globalsSet.contains(globalOp))
       return;
-    }
     globalsSet.insert(globalOp);
     symbols.push_back(addrOfOp.getSymbolAttr());
     globalOp.walk([&](fir::AddrOfOp op) {
@@ -114,10 +113,10 @@ static void collectReferencedSymbolsForType(
 
     fir::GlobalOp globalOp = symbolTable->lookup<fir::GlobalOp>(
         fir::NameUniquer::getTypeDescriptorName(recTy.getName()));
-    if (!globalOp) {
+    if (!globalOp)
       globalOp = symbolTable->lookup<fir::GlobalOp>(
           fir::NameUniquer::getTypeDescriptorAssemblyName(recTy.getName()));
-    }
+
     if (globalOp) {
       globalsSet.insert(globalOp);
       symbols.push_back(



More information about the Mlir-commits mailing list