[llvm] caa35a1 - [OpenMP][OpenMPIRBuilder] Make outlined function parameters i64 or ptr
Jan Sjodin via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 25 10:06:57 PDT 2023
Author: Jan Sjodin
Date: 2023-07-25T13:01:40-04:00
New Revision: caa35a1ad9afc45079830966258ae24ca8f5ec0a
URL: https://github.com/llvm/llvm-project/commit/caa35a1ad9afc45079830966258ae24ca8f5ec0a
DIFF: https://github.com/llvm/llvm-project/commit/caa35a1ad9afc45079830966258ae24ca8f5ec0a.diff
LOG: [OpenMP][OpenMPIRBuilder] Make outlined function parameters i64 or ptr
This patch ensures that all outlined function parameters are i64 or ptr when
compiling for a target device, which is what the OpenMP runtime expects. The
values are then cast to the correct type inside the kernel.
Reviewed By: jdoerfert
Differential Revision: https://reviews.llvm.org/D155628
Added:
Modified:
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
Removed:
################################################################################
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 4c3696f9c342ab..0d64c21a6cc07f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4293,13 +4293,37 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
return Builder.saveIP();
}
+// Copy input from pointer or i64 to the expected argument type.
+static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace,
+ Value *Input, Argument &Arg) {
+ auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy()
+ ? Arg.getType()
+ : Type::getInt64Ty(Builder.getContext()),
+ AddrSpace);
+ auto AddrAscast =
+ Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+ Builder.CreateStore(&Arg, AddrAscast);
+ auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+ return Copy;
+}
+
static Function *
createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
SmallVector<Type *> ParameterTypes;
- for (auto &Arg : Inputs)
- ParameterTypes.push_back(Arg->getType());
+ if (OMPBuilder.Config.isTargetDevice()) {
+ // All parameters to target devices are passed as pointers
+ // or i64. This assumes 64-bit address spaces/pointers.
+ for (auto &Arg : Inputs)
+ ParameterTypes.push_back(Arg->getType()->isPointerTy()
+ ? Arg->getType()
+ : Type::getInt64Ty(Builder.getContext()));
+ } else {
+ for (auto &Arg : Inputs)
+ ParameterTypes.push_back(Arg->getType());
+ }
auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
/*isVarArg*/ false);
@@ -4317,9 +4341,10 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
- Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
+ BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
// Insert target deinit call in the device compilation pass.
+ Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
if (OMPBuilder.Config.isTargetDevice())
OMPBuilder.createTargetDeinit(Builder, /*IsSPMD*/ false);
@@ -4327,15 +4352,23 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Builder.CreateRetVoid();
// Rewrite uses of input valus to parameters.
+ Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
for (auto InArg : zip(Inputs, Func->args())) {
Value *Input = std::get<0>(InArg);
Argument &Arg = std::get<1>(InArg);
+ Value *InputCopy =
+ OMPBuilder.Config.isTargetDevice()
+ ? copyInput(Builder,
+ OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
+ Input, Arg)
+ : &Arg;
+
// Collect all the instructions
for (User *User : make_early_inc_range(Input->users()))
if (auto Instr = dyn_cast<Instruction>(User))
if (Instr->getFunction() == Func)
- Instr->replaceUsesOfWith(Input, &Arg);
+ Instr->replaceUsesOfWith(Input, InputCopy);
}
// Restore insert point.
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 82ab4c9f4c93ab..5679d26ccc6eaa 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5122,16 +5122,18 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+ LoadInst *Value = nullptr;
StoreInst *TargetStore = nullptr;
llvm::SmallVector<llvm::Value *, 2> CapturedArgs = {
- Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
+ Constant::getNullValue(Type::getInt32PtrTy(Ctx)),
Constant::getNullValue(Type::getInt32PtrTy(Ctx))};
auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP)
-> OpenMPIRBuilder::InsertPointTy {
Builder.restoreIP(CodeGenIP);
- TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]);
+ Value = Builder.CreateLoad(Type::getInt32Ty(Ctx), CapturedArgs[0]);
+ TargetStore = Builder.CreateStore(Value, CapturedArgs[1]);
return Builder.saveIP();
};
@@ -5155,7 +5157,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
EXPECT_EQ(OutlinedFn->arg_size(), 2U);
EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
- EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32));
+ EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isPointerTy());
EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
// Check entry block
@@ -5180,8 +5182,22 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
// Check user code block
auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
- EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore);
-
+ auto *Alloca1 = UserCodeBlock->getFirstNonPHI();
+ EXPECT_TRUE(isa<AllocaInst>(Alloca1));
+ auto *Store1 = Alloca1->getNextNode();
+ EXPECT_TRUE(isa<StoreInst>(Store1));
+ auto *Load1 = Store1->getNextNode();
+ EXPECT_TRUE(isa<LoadInst>(Load1));
+ auto *Alloca2 = Load1->getNextNode();
+ EXPECT_TRUE(isa<AllocaInst>(Alloca2));
+ auto *Store2 = Alloca2->getNextNode();
+ EXPECT_TRUE(isa<StoreInst>(Store2));
+ auto *Load2 = Store2->getNextNode();
+ EXPECT_TRUE(isa<LoadInst>(Load2));
+
+ auto *Value1 = Load2->getNextNode();
+ EXPECT_EQ(Value1, Value);
+ EXPECT_EQ(Value1->getNextNode(), TargetStore);
auto *Deinit = TargetStore->getNextNode();
EXPECT_NE(Deinit, nullptr);
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index 126fff70ce3b1f..502ef9176b1118 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -30,12 +30,21 @@ module attributes {omp.is_target_device = true} {
// CHECK-NEXT: %[[CMP:.*]] = icmp eq i32 %3, -1
// CHECK-NEXT: br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
// CHECK: [[LABEL_ENTRY]]:
+// CHECK: %[[TMP_A:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8
+// CHECK: %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8
+// CHECK: %[[TMP_B:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ADDR_B]], ptr %[[TMP_B]], align 8
+// CHECK: %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8
+// CHECK: %[[TMP_C:.*]] = alloca ptr, align 8
+// CHECK: store ptr %[[ADDR_C]], ptr %[[TMP_C]], align 8
+// CHECK: %[[PTR_C:.*]] = load ptr, ptr %[[TMP_C]], align 8
// CHECK-NEXT: br label %[[LABEL_TARGET:.*]]
// CHECK: [[LABEL_TARGET]]:
-// CHECK: %[[A:.*]] = load i32, ptr %[[ADDR_A]], align 4
-// CHECK: %[[B:.*]] = load i32, ptr %[[ADDR_B]], align 4
+// CHECK: %[[A:.*]] = load i32, ptr %[[PTR_A]], align 4
+// CHECK: %[[B:.*]] = load i32, ptr %[[PTR_B]], align 4
// CHECK: %[[C:.*]] = add i32 %[[A]], %[[B]]
-// CHECK: store i32 %[[C]], ptr %[[ADDR_C]], align 4
+// CHECK: store i32 %[[C]], ptr %[[PTR_C]], align 4
// CHECK: br label %[[LABEL_DEINIT:.*]]
// CHECK: [[LABEL_DEINIT]]:
// CHECK-NEXT: call void @__kmpc_target_deinit(ptr @[[IDENT]], i8 1)
More information about the llvm-commits mailing list