[llvm] 06b3f60 - [Flang][OpenMP][OMPIRBuilder] Add lowering of TargetOp for device codegen

Sergio Afonso via llvm-commits llvm-commits at lists.llvm.org
Thu May 18 07:20:03 PDT 2023


Author: Sergio Afonso
Date: 2023-05-18T15:14:03+01:00
New Revision: 06b3f6086c7b46c4caccb868b0f7d09702848222

URL: https://github.com/llvm/llvm-project/commit/06b3f6086c7b46c4caccb868b0f7d09702848222
DIFF: https://github.com/llvm/llvm-project/commit/06b3f6086c7b46c4caccb868b0f7d09702848222.diff

LOG: [Flang][OpenMP][OMPIRBuilder] Add lowering of TargetOp for device codegen

This patch adds support in the `OpenMPIRBuilder` for generating working
device code for OpenMP target regions. It generates and handles the
result of a call to `__kmpc_target_init()` at the beginning of the
function resulting from outlining each target region, and it also
generates the matching `__kmpc_target_deinit()` call before returning.

It relies on the implementation of target region outlining for host
codegen to handle the production of the new function and the lowering of
its body based on the contents of the associated target region.

Depends on D147172

Differential Revision: https://reviews.llvm.org/D147940

Added: 
    mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir

Modified: 
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 99f454ecc50aa..4b4a71062e93d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4148,8 +4148,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
 }
 
 static Function *
-createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName,
-                       SmallVectorImpl<Value *> &Inputs,
+createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                       StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
                        OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
   SmallVector<Type *> ParameterTypes;
   for (auto &Arg : Inputs)
@@ -4166,8 +4166,17 @@ createOutlinedFunction(IRBuilderBase &Builder, StringRef FuncName,
   // Generate the region into the function.
   BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
   Builder.SetInsertPoint(EntryBB);
+
+  // Insert target init call in the device compilation pass.
+  if (OMPBuilder.Config.isEmbedded())
+    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+
   Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
 
+  // Insert target deinit call in the device compilation pass.
+  if (OMPBuilder.Config.isEmbedded())
+    OMPBuilder.createTargetDeinit(Builder, /*IsSPMD*/ false);
+
   // Insert return instruction.
   Builder.CreateRetVoid();
 
@@ -4197,8 +4206,9 @@ emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                            OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
-        return createOutlinedFunction(Builder, EntryFnName, Inputs, CBFunc);
+      [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
+        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
+                                      CBFunc);
       };
 
   Constant *OutlinedFnID;
@@ -4209,7 +4219,7 @@ emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
 static void emitTargetCall(IRBuilderBase &Builder, Function *OutlinedFn,
                            SmallVectorImpl<Value *> &Args) {
-  // TODO: Add kernel launch call when device codegen is supported.
+  // TODO: Add kernel launch call
   Builder.CreateCall(OutlinedFn, Args);
 }
 
@@ -4225,7 +4235,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
   Function *OutlinedFn;
   emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn, NumTeams,
                              NumThreads, Args, CBFunc);
-  emitTargetCall(Builder, OutlinedFn, Args);
+  if (!Config.isEmbedded())
+    emitTargetCall(Builder, OutlinedFn, Args);
   return Builder.saveIP();
 }
 

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 11637f44d97b5..d4c0bf8aab6a8 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DIBuilder.h"
@@ -5175,6 +5176,94 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.setConfig(OpenMPIRBuilderConfig(true, false, false, false));
+  OMPBuilder.initialize();
+
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  StoreInst *TargetStore = nullptr;
+  llvm::SmallVector<llvm::Value *, 2> CapturedArgs = {
+      Constant::getIntegerValue(Type::getInt32Ty(Ctx), APInt(32, 0)),
+      Constant::getNullValue(Type::getInt32PtrTy(Ctx))};
+
+  auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy AllocaIP,
+                       OpenMPIRBuilder::InsertPointTy CodeGenIP)
+      -> OpenMPIRBuilder::InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    TargetStore = Builder.CreateStore(CapturedArgs[0], CapturedArgs[1]);
+    return Builder.saveIP();
+  };
+
+  IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+                                   F->getEntryBlock().getFirstInsertionPt());
+  TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+                                  /*Line=*/3, /*Count=*/0);
+
+  Builder.restoreIP(
+      OMPBuilder.createTarget(Loc, EntryIP, EntryInfo, /*NumTeams=*/-1,
+                              /*NumThreads=*/-1, CapturedArgs, BodyGenCB));
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  // Check outlined function
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  EXPECT_NE(TargetStore, nullptr);
+  Function *OutlinedFn = TargetStore->getFunction();
+  EXPECT_NE(F, OutlinedFn);
+
+  EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+  EXPECT_EQ(OutlinedFn->arg_size(), 2U);
+  EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+  EXPECT_TRUE(OutlinedFn->getArg(0)->getType()->isIntegerTy(32));
+  EXPECT_TRUE(OutlinedFn->getArg(1)->getType()->isPointerTy());
+
+  // Check entry block
+  auto &EntryBlock = OutlinedFn->getEntryBlock();
+  Instruction *Init = EntryBlock.getFirstNonPHI();
+  EXPECT_NE(Init, nullptr);
+
+  auto *InitCall = dyn_cast<CallInst>(Init);
+  EXPECT_NE(InitCall, nullptr);
+  EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init");
+  EXPECT_EQ(InitCall->arg_size(), 3U);
+  EXPECT_TRUE(isa<GlobalVariable>(InitCall->getArgOperand(0)));
+  EXPECT_EQ(InitCall->getArgOperand(1),
+            ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
+  EXPECT_EQ(InitCall->getArgOperand(2),
+            ConstantInt::get(Type::getInt1Ty(Ctx), true));
+
+  auto *EntryBlockBranch = EntryBlock.getTerminator();
+  EXPECT_NE(EntryBlockBranch, nullptr);
+  EXPECT_EQ(EntryBlockBranch->getNumSuccessors(), 2U);
+
+  // Check user code block
+  auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
+  EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
+  EXPECT_EQ(UserCodeBlock->getFirstNonPHI(), TargetStore);
+
+  auto *Deinit = TargetStore->getNextNode();
+  EXPECT_NE(Deinit, nullptr);
+
+  auto *DeinitCall = dyn_cast<CallInst>(Deinit);
+  EXPECT_NE(DeinitCall, nullptr);
+  EXPECT_EQ(DeinitCall->getCalledFunction()->getName(), "__kmpc_target_deinit");
+  EXPECT_EQ(DeinitCall->arg_size(), 2U);
+  EXPECT_TRUE(isa<GlobalVariable>(DeinitCall->getArgOperand(0)));
+  EXPECT_EQ(DeinitCall->getArgOperand(1),
+            ConstantInt::get(Type::getInt8Ty(Ctx), OMP_TGT_EXEC_MODE_GENERIC));
+
+  EXPECT_TRUE(isa<ReturnInst>(DeinitCall->getNextNode()));
+
+  // Check exit block
+  auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
+  EXPECT_EQ(ExitBlock->getName(), "worker.exit");
+  EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTask) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 750f7157bf53f..27e3cec7c786a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1629,15 +1629,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (!targetOpSupported(opInst))
     return failure();
 
-  bool isDevice = false;
-  if (auto offloadMod = dyn_cast<mlir::omp::OffloadModuleInterface>(
-          opInst.getParentOfType<mlir::ModuleOp>().getOperation())) {
-    isDevice = offloadMod.getIsDevice();
-  }
-
-  if (isDevice) // TODO: Implement device codegen.
-    return success();
-
   auto targetOp = cast<omp::TargetOp>(opInst);
   auto &targetRegion = targetOp.getRegion();
 

diff  --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
new file mode 100644
index 0000000000000..3e385e0d7d367
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+module attributes {omp.is_device = #omp.isdevice<is_device = true>} {
+  llvm.func @omp_target_region_() {
+    %0 = llvm.mlir.constant(20 : i32) : i32
+    %1 = llvm.mlir.constant(10 : i32) : i32
+    %2 = llvm.mlir.constant(1 : i64) : i64
+    %3 = llvm.alloca %2 x i32 {bindc_name = "a", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEa"} : (i64) -> !llvm.ptr<i32>
+    %4 = llvm.mlir.constant(1 : i64) : i64
+    %5 = llvm.alloca %4 x i32 {bindc_name = "b", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEb"} : (i64) -> !llvm.ptr<i32>
+    %6 = llvm.mlir.constant(1 : i64) : i64
+    %7 = llvm.alloca %6 x i32 {bindc_name = "c", in_type = i32, operand_segment_sizes = array<i32: 0, 0>, uniq_name = "_QFomp_target_regionEc"} : (i64) -> !llvm.ptr<i32>
+    llvm.store %1, %3 : !llvm.ptr<i32>
+    llvm.store %0, %5 : !llvm.ptr<i32>
+    omp.target   {
+      %8 = llvm.load %3 : !llvm.ptr<i32>
+      %9 = llvm.load %5 : !llvm.ptr<i32>
+      %10 = llvm.add %8, %9  : i32
+      llvm.store %10, %7 : !llvm.ptr<i32>
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// CHECK:      @[[SRC_LOC:.*]] = private unnamed_addr constant [23 x i8] c"{{[^"]*}}", align 1
+// CHECK:      @[[IDENT:.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @[[SRC_LOC]] }, align 8
+// CHECK:      define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
+// CHECK:        %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[IDENT]], i8 1, i1 true)
+// CHECK-NEXT:   %[[CMP:.*]] = icmp eq i32 %3, -1
+// CHECK-NEXT:   br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
+// CHECK:        [[LABEL_ENTRY]]:
+// CHECK-NEXT:   br label %[[LABEL_TARGET:.*]]
+// CHECK:        [[LABEL_TARGET]]:
+// CHECK:        %[[A:.*]] = load i32, ptr %[[ADDR_A]], align 4
+// CHECK:        %[[B:.*]] = load i32, ptr %[[ADDR_B]], align 4
+// CHECK:        %[[C:.*]] = add i32 %[[A]], %[[B]]
+// CHECK:        store i32 %[[C]], ptr %[[ADDR_C]], align 4
+// CHECK:        br label %[[LABEL_DEINIT:.*]]
+// CHECK:        [[LABEL_DEINIT]]:
+// CHECK-NEXT:   call void @__kmpc_target_deinit(ptr @[[IDENT]], i8 1)
+// CHECK-NEXT:   ret void
+// CHECK:        [[LABEL_EXIT]]:
+// CHECK-NEXT:   ret void


        


More information about the llvm-commits mailing list