[llvm] 95b981c - [CodeExtractor] Enable partial aggregate arguments

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 25 18:24:56 PST 2022


Author: Giorgis Georgakoudis
Date: 2022-01-25T20:50:34-05:00
New Revision: 95b981ca2ae3915464a63d42eb53b0dde4a88227

URL: https://github.com/llvm/llvm-project/commit/95b981ca2ae3915464a63d42eb53b0dde4a88227
DIFF: https://github.com/llvm/llvm-project/commit/95b981ca2ae3915464a63d42eb53b0dde4a88227.diff

LOG: [CodeExtractor] Enable partial aggregate arguments

Summary:
Enable CodeExtractor to construct outlined functions that aggregate only
some of their inputs/outputs into a struct argument, passing the rest as
scalar arguments. One use case is the OMPIRBuilder, which creates outlined
functions for parallel regions that aggregate the region's payload
variables in a struct while passing the thread and bound identifiers as
scalars.

Differential Revision: https://reviews.llvm.org/D96854

Added: 
    

Modified: 
    llvm/include/llvm/Transforms/Utils/CodeExtractor.h
    llvm/lib/Transforms/Utils/CodeExtractor.cpp
    llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index f08173e45a5bf..8aed3d0e40d93 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -168,7 +168,7 @@ class CodeExtractorAnalysisCache {
     ///
     /// Based on the blocks used when constructing the code extractor,
     /// determine whether it is eligible for extraction.
-    /// 
+    ///
     /// Checks that varargs handling (with vastart and vaend) is only done in
     /// the outlined blocks.
     bool isEligible() const;
@@ -214,6 +214,10 @@ class CodeExtractorAnalysisCache {
     /// original block will be added to the outline region.
     BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
 
+    /// Exclude a value from aggregate argument passing when extracting a code
+    /// region, passing it instead as a scalar.
+    void excludeArgFromAggregate(Value *Arg);
+
   private:
     struct LifetimeMarkerInfo {
       bool SinkLifeStart = false;
@@ -222,6 +226,8 @@ class CodeExtractorAnalysisCache {
       Instruction *LifeEnd = nullptr;
     };
 
+    ValueSet ExcludeArgsFromAggregate;
+
     LifetimeMarkerInfo
     getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
                        Instruction *Addr, BasicBlock *ExitBlock) const;

diff  --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index f577643f81b0f..24cd5747c5a45 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -829,39 +829,54 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   default: RetTy = Type::getInt16Ty(header->getContext()); break;
   }
 
-  std::vector<Type *> paramTy;
+  std::vector<Type *> ParamTy;
+  std::vector<Type *> AggParamTy;
+  ValueSet StructValues;
 
   // Add the types of the input values to the function's argument list
   for (Value *value : inputs) {
     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
-    paramTy.push_back(value->getType());
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
+      AggParamTy.push_back(value->getType());
+      StructValues.insert(value);
+    } else
+      ParamTy.push_back(value->getType());
   }
 
   // Add the types of the output values to the function's argument list.
   for (Value *output : outputs) {
     LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
-    if (AggregateArgs)
-      paramTy.push_back(output->getType());
-    else
-      paramTy.push_back(PointerType::getUnqual(output->getType()));
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      AggParamTy.push_back(output->getType());
+      StructValues.insert(output);
+    } else
+      ParamTy.push_back(PointerType::getUnqual(output->getType()));
+  }
+
+  assert(
+      (ParamTy.size() + AggParamTy.size()) ==
+          (inputs.size() + outputs.size()) &&
+      "Number of scalar and aggregate params does not match inputs, outputs");
+  assert((StructValues.empty() || AggregateArgs) &&
+         "Expected StructValues only with AggregateArgs set");
+
+  // Concatenate scalar and aggregate params in ParamTy.
+  size_t NumScalarParams = ParamTy.size();
+  StructType *StructTy = nullptr;
+  if (AggregateArgs && !AggParamTy.empty()) {
+    StructTy = StructType::get(M->getContext(), AggParamTy);
+    ParamTy.push_back(PointerType::getUnqual(StructTy));
   }
 
   LLVM_DEBUG({
     dbgs() << "Function type: " << *RetTy << " f(";
-    for (Type *i : paramTy)
+    for (Type *i : ParamTy)
       dbgs() << *i << ", ";
     dbgs() << ")\n";
   });
 
-  StructType *StructTy = nullptr;
-  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
-    StructTy = StructType::get(M->getContext(), paramTy);
-    paramTy.clear();
-    paramTy.push_back(PointerType::getUnqual(StructTy));
-  }
-  FunctionType *funcType =
-                  FunctionType::get(RetTy, paramTy,
-                                    AllowVarArgs && oldFunction->isVarArg());
+  FunctionType *funcType = FunctionType::get(
+      RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
 
   std::string SuffixToUse =
       Suffix.empty()
@@ -981,24 +996,27 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   }
   newFunction->getBasicBlockList().push_back(newRootNode);
 
-  // Create an iterator to name all of the arguments we inserted.
-  Function::arg_iterator AI = newFunction->arg_begin();
+  // Create scalar and aggregate iterators to name all of the arguments we
+  // inserted.
+  Function::arg_iterator ScalarAI = newFunction->arg_begin();
+  Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
 
   // Rewrite all users of the inputs in the extracted region to use the
   // arguments (or appropriate addressing into struct) instead.
-  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
+  for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
     Value *RewriteVal;
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(inputs[i])) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
       Instruction *TI = newFunction->begin()->getTerminator();
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
-      RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
+          StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
+      RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
                                 "loadgep_" + inputs[i]->getName(), TI);
+      ++aggIdx;
     } else
-      RewriteVal = &*AI++;
+      RewriteVal = &*ScalarAI++;
 
     std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
     for (User *use : Users)
@@ -1008,12 +1026,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   }
 
   // Set names for input and output arguments.
-  if (!AggregateArgs) {
-    AI = newFunction->arg_begin();
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
-      AI->setName(inputs[i]->getName());
-    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
-      AI->setName(outputs[i]->getName()+".out");
+  if (NumScalarParams) {
+    ScalarAI = newFunction->arg_begin();
+    for (Value *Input : inputs)
+      if (!StructValues.contains(Input))
+        (ScalarAI++)->setName(Input->getName());
+    for (Value *Output : outputs)
+      if (!StructValues.contains(Output))
+        (ScalarAI++)->setName(Output->getName() + ".out");
   }
 
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
@@ -1121,7 +1141,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
                                                     ValueSet &outputs) {
   // Emit a call to the new function, passing in: *pointer to struct (if
   // aggregating parameters), or plan inputs and allocated memory for outputs
-  std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
+  std::vector<Value *> params, ReloadOutputs, Reloads;
+  ValueSet StructValues;
 
   Module *M = newFunction->getParent();
   LLVMContext &Context = M->getContext();
@@ -1129,23 +1150,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
   CallInst *call = nullptr;
 
   // Add inputs as params, or to be filled into the struct
-  unsigned ArgNo = 0;
+  unsigned ScalarInputArgNo = 0;
   SmallVector<unsigned, 1> SwiftErrorArgs;
   for (Value *input : inputs) {
-    if (AggregateArgs)
-      StructValues.push_back(input);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+      StructValues.insert(input);
     else {
       params.push_back(input);
       if (input->isSwiftError())
-        SwiftErrorArgs.push_back(ArgNo);
+        SwiftErrorArgs.push_back(ScalarInputArgNo);
+      ++ScalarInputArgNo;
     }
-    ++ArgNo;
   }
 
   // Create allocas for the outputs
+  unsigned ScalarOutputArgNo = 0;
   for (Value *output : outputs) {
-    if (AggregateArgs) {
-      StructValues.push_back(output);
+    if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+      StructValues.insert(output);
     } else {
       AllocaInst *alloca =
         new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
@@ -1153,12 +1175,14 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
                        &codeReplacer->getParent()->front().front());
       ReloadOutputs.push_back(alloca);
       params.push_back(alloca);
+      ++ScalarOutputArgNo;
     }
   }
 
   StructType *StructArgTy = nullptr;
   AllocaInst *Struct = nullptr;
-  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
+  unsigned NumAggregatedInputs = 0;
+  if (AggregateArgs && !StructValues.empty()) {
     std::vector<Type *> ArgTypes;
     for (Value *V : StructValues)
       ArgTypes.push_back(V->getType());
@@ -1170,14 +1194,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
                             &codeReplacer->getParent()->front().front());
     params.push_back(Struct);
 
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
-      Value *Idx[2];
-      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
-      GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
-      codeReplacer->getInstList().push_back(GEP);
-      new StoreInst(StructValues[i], GEP, codeReplacer);
+    // Store aggregated inputs in the struct.
+    for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
+      if (inputs.contains(StructValues[i])) {
+        Value *Idx[2];
+        Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+        Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
+        GetElementPtrInst *GEP = GetElementPtrInst::Create(
+            StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
+        codeReplacer->getInstList().push_back(GEP);
+        new StoreInst(StructValues[i], GEP, codeReplacer);
+        NumAggregatedInputs++;
+      }
     }
   }
 
@@ -1200,24 +1228,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
     newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
   }
 
-  Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
-  unsigned FirstOut = inputs.size();
-  if (!AggregateArgs)
-    std::advance(OutputArgBegin, inputs.size());
-
-  // Reload the outputs passed in by reference.
-  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+  // Reload the outputs passed in by reference, use the struct if output is in
+  // the aggregate or reload from the scalar argument.
+  for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
+                aggIdx = NumAggregatedInputs;
+       i != e; ++i) {
     Value *Output = nullptr;
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(outputs[i])) {
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
           StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
       codeReplacer->getInstList().push_back(GEP);
       Output = GEP;
+      ++aggIdx;
     } else {
-      Output = ReloadOutputs[i];
+      Output = ReloadOutputs[scalarIdx];
+      ++scalarIdx;
     }
     LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
                                   outputs[i]->getName() + ".reload",
@@ -1299,8 +1327,13 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
   // Store the arguments right after the definition of output value.
   // This should be proceeded after creating exit stubs to be ensure that invoke
   // result restore will be placed in the outlined function.
-  Function::arg_iterator OAI = OutputArgBegin;
-  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+  Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
+  std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
+  Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
+  std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
+
+  for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
+       ++i) {
     auto *OutI = dyn_cast<Instruction>(outputs[i]);
     if (!OutI)
       continue;
@@ -1320,23 +1353,27 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
     assert((InsertBefore->getFunction() == newFunction ||
             Blocks.count(InsertBefore->getParent())) &&
            "InsertPt should be in new function");
-    assert(OAI != newFunction->arg_end() &&
-           "Number of output arguments should match "
-           "the amount of defined values");
-    if (AggregateArgs) {
+    if (AggregateArgs && StructValues.contains(outputs[i])) {
+      assert(AggOutputArgBegin != newFunction->arg_end() &&
+             "Number of aggregate output arguments should match "
+             "the number of defined values");
       Value *Idx[2];
       Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
-      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
       GetElementPtrInst *GEP = GetElementPtrInst::Create(
-          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+          StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
           InsertBefore);
       new StoreInst(outputs[i], GEP, InsertBefore);
+      ++aggIdx;
       // Since there should be only one struct argument aggregating
-      // all the output values, we shouldn't increment OAI, which always
-      // points to the struct argument, in this case.
+      // all the output values, we shouldn't increment AggOutputArgBegin, which
+      // always points to the struct argument, in this case.
     } else {
-      new StoreInst(outputs[i], &*OAI, InsertBefore);
-      ++OAI;
+      assert(ScalarOutputArgBegin != newFunction->arg_end() &&
+             "Number of scalar output arguments should match "
+             "the number of defined values");
+      new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
+      ++ScalarOutputArgBegin;
     }
   }
 
@@ -1835,3 +1872,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
   }
   return false;
 }
+
+void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
+  ExcludeArgsFromAggregate.insert(Arg);
+}

diff  --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 5f7b0111c1c62..023a41c5bfd69 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -188,7 +188,7 @@ TEST(CodeExtractor, ExitBlockOrderingPhis) {
   EXPECT_TRUE(NextReturn);
   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
-  
+
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
@@ -245,7 +245,7 @@ TEST(CodeExtractor, ExitBlockOrdering) {
   EXPECT_TRUE(NextReturn);
   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
-  
+
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
@@ -504,4 +504,54 @@ TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
+
+TEST(CodeExtractor, PartialAggregateArgs) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
+    target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+    target triple = "x86_64-unknown-linux-gnu"
+
+    declare void @use(i32)
+
+    define void @foo(i32 %a, i32 %b, i32 %c) {
+    entry:
+      br label %extract
+
+    extract:
+      call void @use(i32 %a)
+      call void @use(i32 %b)
+      call void @use(i32 %c)
+      br label %exit
+
+    exit:
+      ret void
+    }
+  )ir",
+                                                Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
+
+  // Create the CodeExtractor with arguments aggregation enabled.
+  CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
+                   /* AggregateArgs */ true);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
+  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
+  // Exclude the first input from the argument aggregate.
+  CE.excludeArgFromAggregate(Inputs[0]);
+
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  // Expect 2 arguments in the outlined function: the excluded input and the
+  // struct aggregate for the remaining inputs.
+  EXPECT_EQ(Outlined->arg_size(), 2U);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
 } // end anonymous namespace


        


More information about the llvm-commits mailing list