[llvm] caa35a1 - [OpenMP][OpenMPIRBuilder] Make outlined function parameters i64 or ptr

Jan Sjodin via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 25 10:06:57 PDT 2023


Author: Jan Sjodin
Date: 2023-07-25T13:01:40-04:00
New Revision: caa35a1ad9afc45079830966258ae24ca8f5ec0a

URL: https://github.com/llvm/llvm-project/commit/caa35a1ad9afc45079830966258ae24ca8f5ec0a
DIFF: https://github.com/llvm/llvm-project/commit/caa35a1ad9afc45079830966258ae24ca8f5ec0a.diff

LOG: [OpenMP][OpenMPIRBuilder] Make outlined function parameters i64 or ptr

This patch ensures that all outlined function parameters are i64 or ptr when
compiling for a target device, which is what the OpenMP runtime expects. The
values are then cast to the correct type inside the kernel.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D155628

Added: 
    

Modified: 
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
    mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 4c3696f9c342ab..0d64c21a6cc07f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4293,13 +4293,37 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
   return Builder.saveIP();
 }
 
+// Copy input from pointer or i64 to the expected argument type.
+static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace,
+                        Value *Input, Argument &Arg) {
+  auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy()
+                                       ? Arg.getType()
+                                       : Type::getInt64Ty(Builder.getContext()),
+                                   AddrSpace);
+  auto AddrAscast =
+      Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+  Builder.CreateStore(&Arg, AddrAscast);
+  auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+  return Copy;
+}
+
 static Function *
 createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                        StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
                        OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
   SmallVector<Type *> ParameterTypes;
-  for (auto &Arg : Inputs)
-    ParameterTypes.push_back(Arg->getType());
+  if (OMPBuilder.Config.isTargetDevice()) {
+    // All parameters to target devices are passed as pointers
+    // or i64. This assumes 64-bit address spaces/pointers.
+    for (auto &Arg : Inputs)
+      ParameterTypes.push_back(Arg->getType()->isPointerTy()
+                                   ? Arg->getType()
+                                   : Type::getInt64Ty(Builder.getContext()));
+  } else {
+    for (auto &Arg : Inputs)
+      ParameterTypes.push_back(Arg->getType());
+  }
 
   auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
                                     /*isVarArg*/ false);
@@ -4317,9 +4341,10 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
 
-  Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
+  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
   // Insert target deinit call in the device compilation pass.
+  Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
   if (OMPBuilder.Config.isTargetDevice())
     OMPBuilder.createTargetDeinit(Builder, /*IsSPMD*/ false);
 
@@ -4327,15 +4352,23 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Builder.CreateRetVoid();
 
   // Rewrite uses of input valus to parameters.
+  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
   for (auto InArg : zip(Inputs, Func->args())) {
     Value *Input = std::get<0>(InArg);
     Argument &Arg = std::get<1>(InArg);
 
+    Value *InputCopy =
+        OMPBuilder.Config.isTargetDevice()
+            ? copyInput(Builder,
+                        OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
+                        Input, Arg)
+            : &Arg;
+
     // Collect all the instructions
     for (User *User : make_early_inc_range(Input->users()))
       if (auto Instr = dyn_cast<Instruction>(User))
         if (Instr->getFunction() == Func)
-          Instr->replaceUsesOfWith(Input, &Arg);
+          Instr->replaceUsesOfWith(Input, InputCopy);
   }
 
   // Restore insert point.

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 82ab4c9f4c93ab..5679d26ccc6eaa 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5122,16 +5122,18 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
+  LoadInst *Value = nullptr;
   StoreInst *TargetStore = nullptr;
   llvm::SmallVector<llvm::Value *, 2> CapturedArgs = {
-      Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
+      Constant::getNullValue(Type::getInt32PtrTy(Ctx)),
       Constant::getNullValue(Type::getInt32PtrTy(Ctx))};
 
   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
                        OpenMPIRBuilder::InsertPointTy CodeGenIP)
       -> OpenMPIRBuilder::InsertPointTy {
     Builder.restoreIP(CodeGenIP);
-    TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]);
+    Value = Builder.CreateLoad(Type::getInt32Ty(Ctx), CapturedArgs[0]);
+    TargetStore = Builder.CreateStore(Value, CapturedArgs[1]);
     return Builder.saveIP();
   };
 
@@ -5155,7 +5157,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
   EXPECT_EQ(OutlinedFn->arg_size(), 2U);
   EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
-  EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32));
+  EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isPointerTy());
   EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
 
   // Check entry block
@@ -5180,8 +5182,22 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   // Check user code block
   auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
   EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
-  EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore);
-
+  auto *Alloca1 = UserCodeBlock->getFirstNonPHI();
+  EXPECT_TRUE(isa<AllocaInst>(Alloca1));
+  auto *Store1 = Alloca1->getNextNode();
+  EXPECT_TRUE(isa<StoreInst>(Store1));
+  auto *Load1 = Store1->getNextNode();
+  EXPECT_TRUE(isa<LoadInst>(Load1));
+  auto *Alloca2 = Load1->getNextNode();
+  EXPECT_TRUE(isa<AllocaInst>(Alloca2));
+  auto *Store2 = Alloca2->getNextNode();
+  EXPECT_TRUE(isa<StoreInst>(Store2));
+  auto *Load2 = Store2->getNextNode();
+  EXPECT_TRUE(isa<LoadInst>(Load2));
+
+  auto *Value1 = Load2->getNextNode();
+  EXPECT_EQ(Value1, Value);
+  EXPECT_EQ(Value1->getNextNode(), TargetStore);
   auto *Deinit = TargetStore->getNextNode();
   EXPECT_NE(Deinit, nullptr);
 

diff  --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index 126fff70ce3b1f..502ef9176b1118 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -30,12 +30,21 @@ module attributes {omp.is_target_device = true} {
 // CHECK-NEXT:   %[[CMP:.*]] = icmp eq i32 %3, -1
 // CHECK-NEXT:   br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
 // CHECK:        [[LABEL_ENTRY]]:
+// CHECK:        %[[TMP_A:.*]] = alloca ptr, align 8
+// CHECK:        store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8
+// CHECK:        %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8
+// CHECK:        %[[TMP_B:.*]] = alloca ptr, align 8
+// CHECK:        store ptr %[[ADDR_B]], ptr %[[TMP_B]], align 8
+// CHECK:        %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8
+// CHECK:        %[[TMP_C:.*]] = alloca ptr, align 8
+// CHECK:        store ptr %[[ADDR_C]], ptr %[[TMP_C]], align 8
+// CHECK:        %[[PTR_C:.*]] = load ptr, ptr %[[TMP_C]], align 8
 // CHECK-NEXT:   br label %[[LABEL_TARGET:.*]]
 // CHECK:        [[LABEL_TARGET]]:
-// CHECK:        %[[A:.*]] = load i32, ptr %[[ADDR_A]], align 4
-// CHECK:        %[[B:.*]] = load i32, ptr %[[ADDR_B]], align 4
+// CHECK:        %[[A:.*]] = load i32, ptr %[[PTR_A]], align 4
+// CHECK:        %[[B:.*]] = load i32, ptr %[[PTR_B]], align 4
 // CHECK:        %[[C:.*]] = add i32 %[[A]], %[[B]]
-// CHECK:        store i32 %[[C]], ptr %[[ADDR_C]], align 4
+// CHECK:        store i32 %[[C]], ptr %[[PTR_C]], align 4
 // CHECK:        br label %[[LABEL_DEINIT:.*]]
 // CHECK:        [[LABEL_DEINIT]]:
 // CHECK-NEXT:   call void @__kmpc_target_deinit(ptr @[[IDENT]], i8 1)


        


More information about the llvm-commits mailing list