[llvm] eee8dd9 - [CodeExtractor] Allow to use 0 addr space for aggregate arg (#66998)

via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 18 11:12:35 PDT 2023


Author: Dominik Adamski
Date: 2023-10-18T20:12:31+02:00
New Revision: eee8dd90887cbf86fa0fea1ff770377a87af0257

URL: https://github.com/llvm/llvm-project/commit/eee8dd90887cbf86fa0fea1ff770377a87af0257
DIFF: https://github.com/llvm/llvm-project/commit/eee8dd90887cbf86fa0fea1ff770377a87af0257.diff

LOG: [CodeExtractor] Allow to use 0 addr space for aggregate arg (#66998)

The user of CodeExtractor should be able to specify that
the aggregate argument is passed as a pointer in address
space zero.

CodeExtractor is used to generate the outlined functions required by the
OpenMP runtime. For OpenMP GPU code, the arguments of these outlined
functions are in address space 0. However, address space 0 is not
necessarily the default (alloca) address space of the GPU device. That is
why the user of CodeExtractor needs a way to specify that the allocated
aggregate parameter is passed as a pointer in address space 0.

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/CodeExtractor.h
    llvm/lib/Transforms/Utils/CodeExtractor.cpp
    llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index bb23cf4a9a3cbbb..27b34ef023db729 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -114,6 +114,10 @@ class CodeExtractorAnalysisCache {
     // label, if non-empty, otherwise "extracted".
     std::string Suffix;
 
+    // If true, the outlined function has aggregate argument in zero address
+    // space.
+    bool ArgsInZeroAddressSpace;
+
   public:
     /// Create a code extractor for a sequence of blocks.
     ///
@@ -128,13 +132,16 @@ class CodeExtractorAnalysisCache {
     /// Any new allocations will be placed in the AllocationBlock, unless
     /// it is null, in which case it will be placed in the entry block of
     /// the function from which the code is being extracted.
+    /// If ArgsInZeroAddressSpace param is set to true, then the aggregate
+    /// param pointer of the outlined function is declared in zero address
+    /// space.
     CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT = nullptr,
                   bool AggregateArgs = false, BlockFrequencyInfo *BFI = nullptr,
                   BranchProbabilityInfo *BPI = nullptr,
                   AssumptionCache *AC = nullptr, bool AllowVarArgs = false,
                   bool AllowAlloca = false,
                   BasicBlock *AllocationBlock = nullptr,
-                  std::string Suffix = "");
+                  std::string Suffix = "", bool ArgsInZeroAddressSpace = false);
 
     /// Create a code extractor for a loop body.
     ///

diff  --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ae7ed296c45ea88..b251a85cf85f92a 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -245,12 +245,13 @@ CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
                              bool AggregateArgs, BlockFrequencyInfo *BFI,
                              BranchProbabilityInfo *BPI, AssumptionCache *AC,
                              bool AllowVarArgs, bool AllowAlloca,
-                             BasicBlock *AllocationBlock, std::string Suffix)
+                             BasicBlock *AllocationBlock, std::string Suffix,
+                             bool ArgsInZeroAddressSpace)
     : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
       BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
       AllowVarArgs(AllowVarArgs),
       Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
-      Suffix(Suffix) {}
+      Suffix(Suffix), ArgsInZeroAddressSpace(ArgsInZeroAddressSpace) {}
 
 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
                              BlockFrequencyInfo *BFI,
@@ -866,7 +867,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   StructType *StructTy = nullptr;
   if (AggregateArgs && !AggParamTy.empty()) {
     StructTy = StructType::get(M->getContext(), AggParamTy);
-    ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
+    ParamTy.push_back(PointerType::get(
+        StructTy, ArgsInZeroAddressSpace ? 0 : DL.getAllocaAddrSpace()));
   }
 
   LLVM_DEBUG({
@@ -1187,8 +1189,15 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
         StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
         AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
                         : &codeReplacer->getParent()->front().front());
-    params.push_back(Struct);
 
+    if (ArgsInZeroAddressSpace && DL.getAllocaAddrSpace() != 0) {
+      auto *StructSpaceCast = new AddrSpaceCastInst(
+          Struct, PointerType ::get(Context, 0), "structArg.ascast");
+      StructSpaceCast->insertAfter(Struct);
+      params.push_back(StructSpaceCast);
+    } else {
+      params.push_back(Struct);
+    }
     // Store aggregated inputs in the struct.
     for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
       if (inputs.contains(StructValues[i])) {

diff  --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index c142729e2c6f424..528d33239332645 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -555,4 +555,64 @@ TEST(CodeExtractor, PartialAggregateArgs) {
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
+
+TEST(CodeExtractor, OpenMPAggregateArgs) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
+    target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8"
+    target triple = "amdgcn-amd-amdhsa"
+
+    define void @foo(ptr %0) {
+      %2= alloca ptr, align 8, addrspace(5)
+      %3 = addrspacecast ptr addrspace(5) %2 to ptr
+      store ptr %0, ptr %3, align 8
+      %4 = load ptr, ptr %3, align 8
+      br label %entry
+
+   entry:
+      br label %extract
+
+    extract:
+      store i64 10, ptr %4, align 4
+      br label %exit
+
+    exit:
+      ret void
+    }
+  )ir",
+                                                Err, Ctx));
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
+
+  // Create the CodeExtractor with argument aggregation enabled. The outlined
+  // function's aggregate argument must be a pointer in address space 0 even
+  // though the datalayout's alloca address space is 5 ("A5").
+  CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
+                   /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
+                   /* BranchProbabilityInfo */ nullptr,
+                   /* AssumptionCache */ nullptr,
+                   /* AllowVarArgs */ true,
+                   /* AllowAlloca */ true,
+                   /* AllocaBlock*/ &Func->getEntryBlock(),
+                   /* Suffix */ ".outlined",
+                   /* ArgsInZeroAddressSpace */ true);
+
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
+  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
+
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  EXPECT_EQ(Outlined->arg_size(), 1U);
+  // The single (aggregate) argument must be a pointer in address space 0.
+  EXPECT_EQ(Outlined->getArg(0)->getType(),
+            PointerType::get(M->getContext(), 0));
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
 } // end anonymous namespace


        


More information about the llvm-commits mailing list