[llvm] [CodeExtractor] Allow to use 0 addr space for aggregate arg (PR #66998)

Dominik Adamski via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 28 02:29:22 PDT 2023


https://github.com/DominikAdamski updated https://github.com/llvm/llvm-project/pull/66998

>From c24e3190f35c5ca1ca9462010c5b22a605a3cb06 Mon Sep 17 00:00:00 2001
From: Dominik Adamski <dominik.adamski at amd.com>
Date: Wed, 27 Sep 2023 07:55:03 -0400
Subject: [PATCH] [CodeExtractor] Allow to use 0 addr space for aggregate arg

The user of CodeExtractor should be able to specify that
the aggregate argument should be passed as a pointer in zero address
space.

CodeExtractor is used to generate the outlined functions required by the
OpenMP runtime. The arguments of the outlined functions for OpenMP GPU
code are in address space 0. Address space 0 is not necessarily the
default (alloca) address space of a GPU device. That is why the user of
CodeExtractor needs to be able to specify that the allocated aggregate
param is passed as a pointer in address space zero.
---
 .../llvm/Transforms/Utils/CodeExtractor.h       |  9 ++++++++-
 llvm/lib/Transforms/Utils/CodeExtractor.cpp     | 17 +++++++++++++----
 2 files changed, 21 insertions(+), 5 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index bb23cf4a9a3cbbb..27b34ef023db729 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -114,6 +114,10 @@ class CodeExtractorAnalysisCache {
     // label, if non-empty, otherwise "extracted".
     std::string Suffix;
 
+    // If true, the aggregate argument of the outlined function is passed as a
+    // pointer in address space zero.
+    bool ArgsInZeroAddressSpace;
+
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -128,13 +132,16 @@ class CodeExtractorAnalysisCache {
     /// Any new allocations will be placed in the AllocationBlock, unless
     /// it is null, in which case it will be placed in the entry block of
     /// the function from which the code is being extracted.
+    /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
+    /// param pointer of the outlined function is declared in zero address
+    /// space.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
                   BranchProbabilityInfo *BPI = nullptr,
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "");
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
 
     /// Create a code extractor for a loop body.
     ///
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index cafa99491f5b5f6..627fce094589c49 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix)
+                             BasicBlock *AllocationBlock, std::string Suffix,
+                             bool ArgsInZeroAddressSpace)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
                              BlockFrequencyInfo *BFI,
@@ -866,7 +867,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   StructType *StructTy = nullptr;
   if (AggregateArgs && !AggParamTy.empty()) {
     StructTy = StructType::get(M->getContext(), AggParamTy);
-    ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
+    ParamTy.push_back(PointerType::get(
+        StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
   }
 
   LLVM_DEBUG({
@@ -1186,8 +1188,15 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
         StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
         AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
                         : &codeReplacer->getParent()->front().front());
-    params.push_back(Struct);
 
+    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+      AddrSpaceCastInst *StructSpaceCast = new AddrSpaceCastInst(
+          Struct, Struct->getType()->getPointerTo(), "structArg.ascast");
+      StructSpaceCast->insertAfter(Struct);
+      params.push_back(StructSpaceCast);
+    } else {
+      params.push_back(Struct);
+    }
     // Store aggregated inputs in the struct.
     for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
       if (inputs.contains(StructValues[i])) {



More information about the llvm-commits mailing list