[llvm] [AMDGPU] Introduce "amdgpu-sw-lower-lds" pass to lower LDS accesses to use device global memory. (PR #87265)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 19 04:08:54 PDT 2024


================
@@ -65,6 +68,219 @@ bool isLDSVariableToLower(const GlobalVariable &GV) {
   return true;
 }
 
+/// Gather every global in \p M that AMDGPU::isLDSVariableToLower accepts and
+/// pass them to convertUsersOfConstantsToInstructions.
+///
+/// \returns the value reported by convertUsersOfConstantsToInstructions.
+bool eliminateConstantExprUsesOfLDSFromAllInstructions(Module &M) {
+  // LLVM uniques constants, so one ConstantExpr that refers to an LDS global
+  // can end up with uses in several different functions. This pass
+  // specialises LDS variables per allocating kernel, so such shared
+  // expressions have to be split up first.
+  //
+  // Semantically this matches the nested walk below, which is not what is
+  // implemented because it would be slow:
+  // for (auto &F : M.functions())
+  //   for (auto &BB : F)
+  //     for (auto &I : BB)
+  //       for (Use &Op : I.operands())
+  //         if (constantExprUsesLDS(Op))
+  //           replaceConstantExprInFunction(I, Op);
+
+  SmallVector<Constant *> LoweredLDS;
+  for (GlobalVariable &Global : M.globals()) {
+    // Only LDS variables selected for lowering are of interest.
+    if (!AMDGPU::isLDSVariableToLower(Global))
+      continue;
+    LoweredLDS.push_back(&Global);
+  }
+
+  return convertUsersOfConstantsToInstructions(LoweredLDS);
+}
+
+/// Record, per function, the LDS variables that instructions in that function
+/// use directly (uses made by callees are not followed here).
+///
+/// \param CG        Call graph; not consulted by this function.
+/// \param M         Module whose globals are scanned.
+/// \param kernels   Output: entries for functions where isKernelLDS is true.
+/// \param functions Output: entries for every other function.
+///
+/// Both outputs are filled in one pass so the globals list is walked once.
+void getUsesOfLDSByFunction(CallGraph const &CG, Module &M,
+                            FunctionVariableMap &kernels,
+                            FunctionVariableMap &functions) {
+  for (auto &LDSVar : M.globals()) {
+    if (!AMDGPU::isLDSVariableToLower(LDSVar))
+      continue;
+
+    for (User *U : LDSVar.users()) {
+      // Users that are not instructions are ignored.
+      auto *Inst = dyn_cast<Instruction>(U);
+      if (!Inst)
+        continue;
+
+      Function *Parent = Inst->getFunction();
+      FunctionVariableMap &Dest = isKernelLDS(Parent) ? kernels : functions;
+      Dest[Parent].insert(&LDSVar);
+    }
+  }
+}
+
+/// Return true when the calling convention of \p F satisfies
+/// AMDGPU::isKernel.
+bool isKernelLDS(const Function *F) {
+  // This intentionally does not use AMDGPU::isKernelCC. That helper does not
+  // forward the calling convention to AMDGPU::isKernel; it goes through
+  // isModuleEntryFunction, which accepts more calling conventions than
+  // AMDGPU::isKernel does (AMDGPU::isKernel carries a FIXME about this).
+  // A test also checks that LDS lowering leaves a graphics shader
+  // (amdgpu_ps) alone, so the narrower predicate is kept. "LDS" is in the
+  // function name to draw attention to the distinction.
+  const auto CC = F->getCallingConv();
+  return AMDGPU::isKernel(CC);
+}
+
+/// Work out, for each kernel in \p M, which LDS variables it uses directly
+/// and which it can reach through calls.
+///
+/// Steps:
+///  1. Collect direct uses per kernel and per non-kernel function.
+///  2. Collect the variables used by non-kernel functions whose address is
+///     taken; any call with an unknown callee is assumed able to reach all
+///     of them.
+///  3. For every defined non-kernel function, union the direct uses of
+///     everything reachable from it in \p CG.
+///  4. For every defined kernel, union the transitive sets of its callees
+///     (or the address-taken set for unknown callees).
+///  5. Check that the collected variables are either all absolute or all
+///     non-absolute; a mix is a fatal error.
+///
+/// \returns an LDSUsesInfoTy built from (direct kernel map, indirect kernel
+/// map). Both maps are empty when every collected variable is absolute.
+LDSUsesInfoTy getTransitiveUsesOfLDS(CallGraph const &CG, Module &M) {
+
+  FunctionVariableMap DirectMapKernel;
+  FunctionVariableMap DirectMapFunction;
+  getUsesOfLDSByFunction(CG, M, DirectMapKernel, DirectMapFunction);
+
+  // Collect variables that are used by functions whose address has escaped
+  DenseSet<GlobalVariable *> VariablesReachableThroughFunctionPointer;
+  for (Function &F : M.functions()) {
+    if (!isKernelLDS(&F))
+      if (F.hasAddressTaken(nullptr,
+                            /* IgnoreCallbackUses */ false,
+                            /* IgnoreAssumeLikeCalls */ false,
+                            /* IgnoreLLVMUsed */ true,
+                            /* IgnoreArcAttachedCall */ false)) {
+        set_union(VariablesReachableThroughFunctionPointer,
+                  DirectMapFunction[&F]);
+      }
+  }
+
+  // A call record without a callee function is a call to an unknown target.
+  auto FunctionMakesUnknownCall = [&](const Function *F) -> bool {
+    assert(!F->isDeclaration());
+    for (const CallGraphNode::CallRecord &R : *CG[F]) {
+      if (!R.second->getFunction()) {
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // Work out which variables are reachable through function calls
+  FunctionVariableMap TransitiveMapFunction = DirectMapFunction;
+
+  // If the function makes any unknown call, assume the worst case that it can
+  // access all variables accessed by functions whose address escaped
+  for (Function &F : M.functions()) {
+    if (!F.isDeclaration() && FunctionMakesUnknownCall(&F)) {
+      if (!isKernelLDS(&F)) {
+        set_union(TransitiveMapFunction[&F],
+                  VariablesReachableThroughFunctionPointer);
+      }
+    }
+  }
+
+  // Direct implementation of collecting all variables reachable from each
+  // function
+  for (Function &Func : M.functions()) {
+    if (Func.isDeclaration() || isKernelLDS(&Func))
+      continue;
+
+    // Catches cycles. Seeded with the root so that a recursive call chain
+    // does not push the root a second time; the computed sets are unaffected
+    // because set_union is idempotent.
+    DenseSet<Function *> Seen;
+    Seen.insert(&Func);
+    SmallVector<Function *, 4> WIP{&Func};
+
+    while (!WIP.empty()) {
+      Function *F = WIP.pop_back_val();
+
+      // Can accelerate this by referring to transitive map for functions that
+      // have already been computed, with more care than this
+      set_union(TransitiveMapFunction[&Func], DirectMapFunction[F]);
+
+      for (const CallGraphNode::CallRecord &R : *CG[F]) {
+        Function *Callee = R.second->getFunction();
+        // insert() reports whether the element was newly added, so a single
+        // lookup both tests and marks the callee.
+        if (Callee && Seen.insert(Callee).second)
+          WIP.push_back(Callee);
+      }
+    }
+  }
+
+  // DirectMapKernel lists which variables are used by the kernel
+  // find the variables which are used through a function call
+  FunctionVariableMap IndirectMapKernel;
+
+  for (Function &Func : M.functions()) {
+    if (Func.isDeclaration() || !isKernelLDS(&Func))
+      continue;
+
+    for (const CallGraphNode::CallRecord &R : *CG[&Func]) {
+      Function *Callee = R.second->getFunction();
+      if (Callee) {
+        set_union(IndirectMapKernel[&Func], TransitiveMapFunction[Callee]);
+      } else {
+        set_union(IndirectMapKernel[&Func],
+                  VariablesReachableThroughFunctionPointer);
+      }
+    }
+  }
+
+  // Verify that we fall into one of 2 cases:
+  //    - All variables are absolute: this is a re-run of the pass
+  //      so we don't have anything to do.
+  //    - No variables are absolute.
+  //
+  // Iterate over pointers to the maps: a braced list of the maps themselves
+  // would copy both maps into the temporary initializer_list.
+  std::optional<bool> HasAbsoluteGVs;
+  for (const FunctionVariableMap *Map : {&DirectMapKernel, &IndirectMapKernel}) {
+    for (auto &[Fn, GVs] : *Map) {
+      for (auto *GV : GVs) {
+        bool IsAbsolute = GV->isAbsoluteSymbolRef();
+        if (HasAbsoluteGVs.has_value()) {
+          if (*HasAbsoluteGVs != IsAbsolute) {
+            report_fatal_error(
+                "Module cannot mix absolute and non-absolute LDS GVs");
+          }
+        } else
+          HasAbsoluteGVs = IsAbsolute;
+      }
+    }
+  }
+
+  // If we only had absolute GVs, we have nothing to do, return an empty
+  // result.
+  if (HasAbsoluteGVs && *HasAbsoluteGVs)
+    return {FunctionVariableMap(), FunctionVariableMap()};
+
+  return {std::move(DirectMapKernel), std::move(IndirectMapKernel)};
+}
+
+void removeFnAttrFromReachable(CallGraph &CG, Function *KernelRoot,
+                               StringRef FnAttr) {
+  KernelRoot->removeFnAttr(FnAttr);
+
+  SmallVector<Function *> WorkList({CG[KernelRoot]->getFunction()});
----------------
skc7 wrote:

Updated.

https://github.com/llvm/llvm-project/pull/87265


More information about the llvm-commits mailing list