[llvm] 95b981c - [CodeExtractor] Enable partial aggregate arguments
Joseph Huber via llvm-commits
llvm-commits at lists.llvm.org
Tue Jan 25 18:24:56 PST 2022
Author: Giorgis Georgakoudis
Date: 2022-01-25T20:50:34-05:00
New Revision: 95b981ca2ae3915464a63d42eb53b0dde4a88227
URL: https://github.com/llvm/llvm-project/commit/95b981ca2ae3915464a63d42eb53b0dde4a88227
DIFF: https://github.com/llvm/llvm-project/commit/95b981ca2ae3915464a63d42eb53b0dde4a88227.diff
LOG: [CodeExtractor] Enable partial aggregate arguments
Summary:
Enable CodeExtractor to construct output functions that aggregate only
some of their inputs/outputs into a struct argument, passing the rest as
scalar arguments. One use case is the OMPIRBuilder, which creates outlined
functions for parallel regions: the region's payload variables are
aggregated in a struct, while the thread and bound identifiers are passed
as scalars.
Differential Revision: https://reviews.llvm.org/D96854
Added:
Modified:
llvm/include/llvm/Transforms/Utils/CodeExtractor.h
llvm/lib/Transforms/Utils/CodeExtractor.cpp
llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
index f08173e45a5bf..8aed3d0e40d93 100644
--- a/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
+++ b/llvm/include/llvm/Transforms/Utils/CodeExtractor.h
@@ -168,7 +168,7 @@ class CodeExtractorAnalysisCache {
///
/// Based on the blocks used when constructing the code extractor,
/// determine whether it is eligible for extraction.
- ///
+ ///
/// Checks that varargs handling (with vastart and vaend) is only done in
/// the outlined blocks.
bool isEligible() const;
@@ -214,6 +214,10 @@ class CodeExtractorAnalysisCache {
/// original block will be added to the outline region.
BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
+ /// Exclude a value from aggregate argument passing when extracting a code
+ /// region, passing it instead as a scalar.
+ void excludeArgFromAggregate(Value *Arg);
+
private:
struct LifetimeMarkerInfo {
bool SinkLifeStart = false;
@@ -222,6 +226,8 @@ class CodeExtractorAnalysisCache {
Instruction *LifeEnd = nullptr;
};
+ ValueSet ExcludeArgsFromAggregate;
+
LifetimeMarkerInfo
getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
Instruction *Addr, BasicBlock *ExitBlock) const;
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index f577643f81b0f..24cd5747c5a45 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -829,39 +829,54 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
default: RetTy = Type::getInt16Ty(header->getContext()); break;
}
- std::vector<Type *> paramTy;
+ std::vector<Type *> ParamTy;
+ std::vector<Type *> AggParamTy;
+ ValueSet StructValues;
// Add the types of the input values to the function's argument list
for (Value *value : inputs) {
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
- paramTy.push_back(value->getType());
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
+ AggParamTy.push_back(value->getType());
+ StructValues.insert(value);
+ } else
+ ParamTy.push_back(value->getType());
}
// Add the types of the output values to the function's argument list.
for (Value *output : outputs) {
LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
- if (AggregateArgs)
- paramTy.push_back(output->getType());
- else
- paramTy.push_back(PointerType::getUnqual(output->getType()));
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+ AggParamTy.push_back(output->getType());
+ StructValues.insert(output);
+ } else
+ ParamTy.push_back(PointerType::getUnqual(output->getType()));
+ }
+
+ assert(
+ (ParamTy.size() + AggParamTy.size()) ==
+ (inputs.size() + outputs.size()) &&
+ "Number of scalar and aggregate params does not match inputs, outputs");
+ assert(StructValues.empty() ||
+ AggregateArgs && "Expeced StructValues only with AggregateArgs set");
+
+ // Concatenate scalar and aggregate params in ParamTy.
+ size_t NumScalarParams = ParamTy.size();
+ StructType *StructTy = nullptr;
+ if (AggregateArgs && !AggParamTy.empty()) {
+ StructTy = StructType::get(M->getContext(), AggParamTy);
+ ParamTy.push_back(PointerType::getUnqual(StructTy));
}
LLVM_DEBUG({
dbgs() << "Function type: " << *RetTy << " f(";
- for (Type *i : paramTy)
+ for (Type *i : ParamTy)
dbgs() << *i << ", ";
dbgs() << ")\n";
});
- StructType *StructTy = nullptr;
- if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
- StructTy = StructType::get(M->getContext(), paramTy);
- paramTy.clear();
- paramTy.push_back(PointerType::getUnqual(StructTy));
- }
- FunctionType *funcType =
- FunctionType::get(RetTy, paramTy,
- AllowVarArgs && oldFunction->isVarArg());
+ FunctionType *funcType = FunctionType::get(
+ RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
std::string SuffixToUse =
Suffix.empty()
@@ -981,24 +996,27 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}
newFunction->getBasicBlockList().push_back(newRootNode);
- // Create an iterator to name all of the arguments we inserted.
- Function::arg_iterator AI = newFunction->arg_begin();
+ // Create scalar and aggregate iterators to name all of the arguments we
+ // inserted.
+ Function::arg_iterator ScalarAI = newFunction->arg_begin();
+ Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
// Rewrite all users of the inputs in the extracted region to use the
// arguments (or appropriate addressing into struct) instead.
- for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
+ for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
Value *RewriteVal;
- if (AggregateArgs) {
+ if (AggregateArgs && StructValues.contains(inputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
Instruction *TI = newFunction->begin()->getTerminator();
GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
- RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
+ StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
+ RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
"loadgep_" + inputs[i]->getName(), TI);
+ ++aggIdx;
} else
- RewriteVal = &*AI++;
+ RewriteVal = &*ScalarAI++;
std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
for (User *use : Users)
@@ -1008,12 +1026,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}
// Set names for input and output arguments.
- if (!AggregateArgs) {
- AI = newFunction->arg_begin();
- for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
- AI->setName(inputs[i]->getName());
- for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
- AI->setName(outputs[i]->getName()+".out");
+ if (NumScalarParams) {
+ ScalarAI = newFunction->arg_begin();
+ for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
+ if (!StructValues.contains(inputs[i]))
+ ScalarAI->setName(inputs[i]->getName());
+ for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
+ if (!StructValues.contains(outputs[i]))
+ ScalarAI->setName(outputs[i]->getName() + ".out");
}
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
@@ -1121,7 +1141,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
ValueSet &outputs) {
// Emit a call to the new function, passing in: *pointer to struct (if
// aggregating parameters), or plan inputs and allocated memory for outputs
- std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
+ std::vector<Value *> params, ReloadOutputs, Reloads;
+ ValueSet StructValues;
Module *M = newFunction->getParent();
LLVMContext &Context = M->getContext();
@@ -1129,23 +1150,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
CallInst *call = nullptr;
// Add inputs as params, or to be filled into the struct
- unsigned ArgNo = 0;
+ unsigned ScalarInputArgNo = 0;
SmallVector<unsigned, 1> SwiftErrorArgs;
for (Value *input : inputs) {
- if (AggregateArgs)
- StructValues.push_back(input);
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+ StructValues.insert(input);
else {
params.push_back(input);
if (input->isSwiftError())
- SwiftErrorArgs.push_back(ArgNo);
+ SwiftErrorArgs.push_back(ScalarInputArgNo);
}
- ++ArgNo;
+ ++ScalarInputArgNo;
}
// Create allocas for the outputs
+ unsigned ScalarOutputArgNo = 0;
for (Value *output : outputs) {
- if (AggregateArgs) {
- StructValues.push_back(output);
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+ StructValues.insert(output);
} else {
AllocaInst *alloca =
new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
@@ -1153,12 +1175,14 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
ReloadOutputs.push_back(alloca);
params.push_back(alloca);
+ ++ScalarOutputArgNo;
}
}
StructType *StructArgTy = nullptr;
AllocaInst *Struct = nullptr;
- if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
+ unsigned NumAggregatedInputs = 0;
+ if (AggregateArgs && !StructValues.empty()) {
std::vector<Type *> ArgTypes;
for (Value *V : StructValues)
ArgTypes.push_back(V->getType());
@@ -1170,14 +1194,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
params.push_back(Struct);
- for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
- Value *Idx[2];
- Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
- GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
- codeReplacer->getInstList().push_back(GEP);
- new StoreInst(StructValues[i], GEP, codeReplacer);
+ // Store aggregated inputs in the struct.
+ for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
+ if (inputs.contains(StructValues[i])) {
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
+ codeReplacer->getInstList().push_back(GEP);
+ new StoreInst(StructValues[i], GEP, codeReplacer);
+ NumAggregatedInputs++;
+ }
}
}
@@ -1200,24 +1228,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
}
- Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
- unsigned FirstOut = inputs.size();
- if (!AggregateArgs)
- std::advance(OutputArgBegin, inputs.size());
-
- // Reload the outputs passed in by reference.
- for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+ // Reload the outputs passed in by reference, use the struct if output is in
+ // the aggregate or reload from the scalar argument.
+ for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
+ aggIdx = NumAggregatedInputs;
+ i != e; ++i) {
Value *Output = nullptr;
- if (AggregateArgs) {
+ if (AggregateArgs && StructValues.contains(outputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
codeReplacer->getInstList().push_back(GEP);
Output = GEP;
+ ++aggIdx;
} else {
- Output = ReloadOutputs[i];
+ Output = ReloadOutputs[scalarIdx];
+ ++scalarIdx;
}
LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
outputs[i]->getName() + ".reload",
@@ -1299,8 +1327,13 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// Store the arguments right after the definition of output value.
// This should be proceeded after creating exit stubs to be ensure that invoke
// result restore will be placed in the outlined function.
- Function::arg_iterator OAI = OutputArgBegin;
- for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
+ Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
+ std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
+ Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
+ std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
+
+ for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
+ ++i) {
auto *OutI = dyn_cast<Instruction>(outputs[i]);
if (!OutI)
continue;
@@ -1320,23 +1353,27 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
assert((InsertBefore->getFunction() == newFunction ||
Blocks.count(InsertBefore->getParent())) &&
"InsertPt should be in new function");
- assert(OAI != newFunction->arg_end() &&
- "Number of output arguments should match "
- "the amount of defined values");
- if (AggregateArgs) {
+ if (AggregateArgs && StructValues.contains(outputs[i])) {
+ assert(AggOutputArgBegin != newFunction->arg_end() &&
+ "Number of aggregate output arguments should match "
+ "the number of defined values");
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
- Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
- StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
+ StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
InsertBefore);
new StoreInst(outputs[i], GEP, InsertBefore);
+ ++aggIdx;
// Since there should be only one struct argument aggregating
- // all the output values, we shouldn't increment OAI, which always
- // points to the struct argument, in this case.
+ // all the output values, we shouldn't increment AggOutputArgBegin, which
+ // always points to the struct argument, in this case.
} else {
- new StoreInst(outputs[i], &*OAI, InsertBefore);
- ++OAI;
+ assert(ScalarOutputArgBegin != newFunction->arg_end() &&
+ "Number of scalar output arguments should match "
+ "the number of defined values");
+ new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
+ ++ScalarOutputArgBegin;
}
}
@@ -1835,3 +1872,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
}
return false;
}
+
+void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
+ ExcludeArgsFromAggregate.insert(Arg);
+}
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 5f7b0111c1c62..023a41c5bfd69 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -188,7 +188,7 @@ TEST(CodeExtractor, ExitBlockOrderingPhis) {
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
-
+
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
@@ -245,7 +245,7 @@ TEST(CodeExtractor, ExitBlockOrdering) {
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
-
+
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
@@ -504,4 +504,54 @@ TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
+
+TEST(CodeExtractor, PartialAggregateArgs) {
+ LLVMContext Ctx;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M(parseAssemblyString(R"ir(
+ target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+ target triple = "x86_64-unknown-linux-gnu"
+
+ declare void @use(i32)
+
+ define void @foo(i32 %a, i32 %b, i32 %c) {
+ entry:
+ br label %extract
+
+ extract:
+ call void @use(i32 %a)
+ call void @use(i32 %b)
+ call void @use(i32 %c)
+ br label %exit
+
+ exit:
+ ret void
+ }
+ )ir",
+ Err, Ctx));
+
+ Function *Func = M->getFunction("foo");
+ SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
+
+ // Create the CodeExtractor with arguments aggregation enabled.
+ CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
+ /* AggregateArgs */ true);
+ EXPECT_TRUE(CE.isEligible());
+
+ CodeExtractorAnalysisCache CEAC(*Func);
+ SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+ BasicBlock *CommonExit = nullptr;
+ CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+ CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
+ // Exclude the first input from the argument aggregate.
+ CE.excludeArgFromAggregate(Inputs[0]);
+
+ Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+ EXPECT_TRUE(Outlined);
+ // Expect 2 arguments in the outlined function: the excluded input and the
+ // struct aggregate for the remaining inputs.
+ EXPECT_EQ(Outlined->arg_size(), 2U);
+ EXPECT_FALSE(verifyFunction(*Outlined));
+ EXPECT_FALSE(verifyFunction(*Func));
+}
} // end anonymous namespace
More information about the llvm-commits mailing list