[Mlir-commits] [mlir] [OpenMP][OpenMPIRBuilder][NFC] Move copyInput to a passed in lambda function (PR #68124)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 3 19:02:22 PDT 2023


https://github.com/agozillon updated https://github.com/llvm/llvm-project/pull/68124

>From 4cf87ac5f73aa96d51f297c6483ee4cfad571276 Mon Sep 17 00:00:00 2001
From: Andrew Gozillon <Andrew.Gozillon at amd.com>
Date: Tue, 3 Oct 2023 10:48:03 -0500
Subject: [PATCH 1/2] [OpenMP][OpenMPIRBuilder][NFC] Move copyInput to a passed
 in lambda function

This patch replaces the static copyInput helper in
OMPIRBuilder.cpp with a callback (lambda) argument
to createTarget that is defined by each caller.

This allows more flexibility in how the argument
accessors are generated: Clang and Flang can each
use their own functions and types inside the
lambda without affecting the OMPIRBuilder
itself.

The existing copyInput logic has been moved into
OpenMPToLLVMIRTranslation.cpp. The idea is to
eventually replace or build on it with a slightly
more complex implementation that uses Flang's map
information (primarily ByRef and ByCapture
information at the moment).

For now this should be an NFC as far as
lowering behavior is concerned.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  8 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 57 +++++++------------
 .../Frontend/OpenMPIRBuilderTest.cpp          | 48 ++++++++++++++--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 40 ++++++++++++-
 4 files changed, 109 insertions(+), 44 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 75da461cfd8d95e..e3f1cddb72fa03d 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2166,6 +2166,9 @@ class OpenMPIRBuilder {
   using TargetBodyGenCallbackTy = function_ref<InsertPointTy(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
 
+  using TargetGenArgAccessorsCallbackTy = function_ref<Value *(
+      Argument &Arg, Value *Input, IRBuilderBase &Builder)>;
+
   /// Generator for '#omp target'
   ///
   /// \param Loc where the target data construct was encountered.
@@ -2177,6 +2180,8 @@ class OpenMPIRBuilder {
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
+  /// \param ArgAccessorFuncCB Callback that will generate accessor
+  /// instructions for passed-in target arguments where necessary.
   InsertPointTy createTarget(const LocationDescription &Loc,
                              OpenMPIRBuilder::InsertPointTy AllocaIP,
                              OpenMPIRBuilder::InsertPointTy CodeGenIP,
@@ -2184,7 +2189,8 @@ class OpenMPIRBuilder {
                              int32_t NumThreads,
                              SmallVectorImpl<Value *> &Inputs,
                              GenMapInfoCallbackTy GenMapInfoCB,
-                             TargetBodyGenCallbackTy BodyGenCB);
+                             TargetBodyGenCallbackTy BodyGenCB,
+                             TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB);
 
   /// Returns __kmpc_for_static_init_* runtime function for the specified
   /// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 72e1af55fe63f60..bcb5d07bbf7edc8 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4539,25 +4539,11 @@ FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
   return getOrCreateRuntimeFunction(M, Name);
 }
 
-// Copy input from pointer or i64 to the expected argument type.
-static Value *copyInput(IRBuilderBase &Builder, unsigned AddrSpace,
-                        Value *Input, Argument &Arg) {
-  auto Addr = Builder.CreateAlloca(Arg.getType()->isPointerTy()
-                                       ? Arg.getType()
-                                       : Type::getInt64Ty(Builder.getContext()),
-                                   AddrSpace);
-  auto AddrAscast =
-      Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
-  Builder.CreateStore(&Arg, AddrAscast);
-  auto Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
-
-  return Copy;
-}
-
-static Function *
-createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-                       StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
-                       OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
+static Function *createOutlinedFunction(
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
+    SmallVectorImpl<Value *> &Inputs,
+    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
+    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
   SmallVector<Type *> ParameterTypes;
   if (OMPBuilder.Config.isTargetDevice()) {
     // All parameters to target devices are passed as pointers
@@ -4603,12 +4589,7 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     Value *Input = std::get<0>(InArg);
     Argument &Arg = std::get<1>(InArg);
 
-    Value *InputCopy =
-        OMPBuilder.Config.isTargetDevice()
-            ? copyInput(Builder,
-                        OMPBuilder.M.getDataLayout().getAllocaAddrSpace(),
-                        Input, Arg)
-            : &Arg;
+    Value *InputCopy = ArgAccessorFuncCB(Arg, Input, Builder);
 
     // Collect all the instructions
     for (User *User : make_early_inc_range(Input->users()))
@@ -4623,18 +4604,19 @@ createOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   return Func;
 }
 
-static void
-emitTargetOutlinedFunction(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-                           TargetRegionEntryInfo &EntryInfo,
-                           Function *&OutlinedFn, Constant *&OutlinedFnID,
-                           int32_t NumTeams, int32_t NumThreads,
-                           SmallVectorImpl<Value *> &Inputs,
-                           OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc) {
+static void emitTargetOutlinedFunction(
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
+    Constant *&OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
+    SmallVectorImpl<Value *> &Inputs,
+    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
+    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&OMPBuilder, &Builder, &Inputs, &CBFunc](StringRef EntryFnName) {
+      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
+       &ArgAccessorFuncCB](StringRef EntryFnName) {
         return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
-                                      CBFunc);
+                                      CBFunc, ArgAccessorFuncCB);
       };
 
   OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
@@ -4698,7 +4680,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
     int32_t NumThreads, SmallVectorImpl<Value *> &Args,
-    GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy CBFunc) {
+    GenMapInfoCallbackTy GenMapInfoCB,
+    OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
+    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB) {
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -4707,7 +4691,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
   Function *OutlinedFn;
   Constant *OutlinedFnID;
   emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
-                             OutlinedFnID, NumTeams, NumThreads, Args, CBFunc);
+                             OutlinedFnID, NumTeams, NumThreads, Args, CBFunc,
+                             ArgAccessorFuncCB);
   if (!Config.isTargetDevice())
     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
                    NumThreads, Args, GenMapInfoCB);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..81733f1a2790287 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5236,6 +5236,24 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   Inputs.push_back(BPtr);
   Inputs.push_back(CPtr);
 
+  auto SimpleArgAccessorCB = [&OMPBuilder](llvm::Argument &Arg,
+                                           llvm::Value *Input,
+                                           IRBuilderBase &Builder) {
+    if (!OMPBuilder.Config.isTargetDevice())
+      return cast<llvm::Value>(&Arg);
+
+    llvm::Value *Addr = Builder.CreateAlloca(
+        Arg.getType()->isPointerTy() ? Arg.getType()
+                                     : Type::getInt64Ty(Builder.getContext()),
+        OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
+    llvm::Value *AddrAscast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+    Builder.CreateStore(&Arg, AddrAscast);
+    llvm::Value *Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+    return Copy;
+  };
+
   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
       -> llvm::OpenMPIRBuilder::MapInfosTy & {
@@ -5245,9 +5263,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
-  Builder.restoreIP(OMPBuilder.createTarget(OmpLoc, Builder.saveIP(),
-                                            Builder.saveIP(), EntryInfo, -1, 0,
-                                            Inputs, GenMapInfoCB, BodyGenCB));
+  Builder.restoreIP(OMPBuilder.createTarget(
+      OmpLoc, Builder.saveIP(), Builder.saveIP(), EntryInfo, -1, 0, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   OMPBuilder.finalize();
   Builder.CreateRetVoid();
 
@@ -5301,6 +5319,23 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
       Constant::getNullValue(PointerType::get(Ctx, 0)),
       Constant::getNullValue(PointerType::get(Ctx, 0))};
 
+  auto SimpleArgAccessorCB = [&](llvm::Argument &Arg, llvm::Value *Input,
+                                 IRBuilderBase &Builder) {
+    if (!OMPBuilder.Config.isTargetDevice())
+      return cast<llvm::Value>(&Arg);
+
+    llvm::Value *Addr = Builder.CreateAlloca(
+        Arg.getType()->isPointerTy() ? Arg.getType()
+                                     : Type::getInt64Ty(Builder.getContext()),
+        OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
+    llvm::Value *AddrAscast =
+        Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+    Builder.CreateStore(&Arg, AddrAscast);
+    llvm::Value *Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+    return Copy;
+  };
+
   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
       -> llvm::OpenMPIRBuilder::MapInfosTy & {
@@ -5322,9 +5357,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
 
-  Builder.restoreIP(OMPBuilder.createTarget(
-      Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
-      /*NumThreads=*/0, CapturedArgs, GenMapInfoCB, BodyGenCB));
+  Builder.restoreIP(
+      OMPBuilder.createTarget(Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
+                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
+                              BodyGenCB, SimpleArgAccessorCB));
 
   Builder.CreateRetVoid();
   OMPBuilder.finalize();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 14bcbc3018f72bd..8326a20a15c4629 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1993,6 +1993,27 @@ handleDeclareTargetMapVar(llvm::ArrayRef<Value> mapOperands,
   }
 }
 
+static llvm::Value *
+createDeviceArgumentAccessor(llvm::Argument &arg, llvm::Value *input,
+                             llvm::IRBuilderBase &builder,
+                             llvm::OpenMPIRBuilder &ompBuilder,
+                             LLVM::ModuleTranslation &moduleTranslation) {
+  if (!ompBuilder.Config.isTargetDevice())
+    return cast<llvm::Value>(&arg);
+
+  llvm::Value *addr =
+      builder.CreateAlloca(arg.getType()->isPointerTy()
+                               ? arg.getType()
+                               : llvm::Type::getInt64Ty(builder.getContext()),
+                           ompBuilder.M.getDataLayout().getAllocaAddrSpace());
+  llvm::Value *addrAscast =
+      builder.CreatePointerBitCastOrAddrSpaceCast(addr, input->getType());
+  builder.CreateStore(&arg, addrAscast);
+  llvm::Value *copy = builder.CreateLoad(arg.getType(), addrAscast);
+
+  return copy;
+}
+
 static LogicalResult
 convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
@@ -2084,9 +2105,26 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     return combinedInfos;
   };
 
+  auto argAccessorCB = [&moduleTranslation](llvm::Argument &arg,
+                                            llvm::Value *input,
+                                            llvm::IRBuilderBase &builder) {
+    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+    // We just return the unaltered argument for the host function
+    // for now, some alterations may be required in the future to
+    // keep host fallback functions working identically to the device
+    // version (e.g. pass ByCopy values should be treated as such on
+    // host and device, currently not always the case)
+    if (!ompBuilder->Config.isTargetDevice())
+      return cast<llvm::Value>(&arg);
+
+    return createDeviceArgumentAccessor(arg, input, builder, *ompBuilder,
+                                        moduleTranslation);
+  };
+
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
       ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams,
-      defaultValThreads, inputs, genMapInfoCB, bodyCB));
+      defaultValThreads, inputs, genMapInfoCB, bodyCB, argAccessorCB));
 
   // Remap access operations to declare target reference pointers for the
   // device, essentially generating extra loadop's as necessary

>From e4ce1cb72afadffa9a93e20bca260941279d05e0 Mon Sep 17 00:00:00 2001
From: Andrew Gozillon <Andrew.Gozillon at amd.com>
Date: Tue, 3 Oct 2023 20:56:38 -0500
Subject: [PATCH 2/2] Convert the accessor callback to use InsertPoints and test
 Alloca/CodeGen insert points that produce more Clang-like results.

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   5 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |  11 +-
 .../Frontend/OpenMPIRBuilderTest.cpp          | 111 ++++++++++--------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  35 +++---
 .../LLVMIR/omptarget-region-device-llvm.mlir  |  12 +-
 5 files changed, 104 insertions(+), 70 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index e3f1cddb72fa03d..b505daae7e75f80 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2166,8 +2166,9 @@ class OpenMPIRBuilder {
   using TargetBodyGenCallbackTy = function_ref<InsertPointTy(
       InsertPointTy AllocaIP, InsertPointTy CodeGenIP)>;
 
-  using TargetGenArgAccessorsCallbackTy = function_ref<Value *(
-      Argument &Arg, Value *Input, IRBuilderBase &Builder)>;
+  using TargetGenArgAccessorsCallbackTy = function_ref<InsertPointTy(
+      Argument &Arg, Value *Input, Value *&RetVal, InsertPointTy AllocaIP,
+      InsertPointTy CodeGenIP)>;
 
   /// Generator for '#omp target'
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index bcb5d07bbf7edc8..c95dbbe996660e3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4583,13 +4583,20 @@ static Function *createOutlinedFunction(
   // Insert return instruction.
   Builder.CreateRetVoid();
 
-  // Rewrite uses of input valus to parameters.
+  // New Alloca IP at entry point of created device function.
+  Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
+  auto AllocaIP = Builder.saveIP();
+
   Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());
+
+  // Rewrite uses of input values to parameters.
   for (auto InArg : zip(Inputs, Func->args())) {
     Value *Input = std::get<0>(InArg);
     Argument &Arg = std::get<1>(InArg);
+    Value *InputCopy = nullptr;
 
-    Value *InputCopy = ArgAccessorFuncCB(Arg, Input, Builder);
+    Builder.restoreIP(
+        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));
 
     // Collect all the instructions
     for (User *User : make_early_inc_range(Input->users()))
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 81733f1a2790287..9bdd9f371008c80 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5236,23 +5236,32 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   Inputs.push_back(BPtr);
   Inputs.push_back(CPtr);
 
-  auto SimpleArgAccessorCB = [&OMPBuilder](llvm::Argument &Arg,
-                                           llvm::Value *Input,
-                                           IRBuilderBase &Builder) {
-    if (!OMPBuilder.Config.isTargetDevice())
-      return cast<llvm::Value>(&Arg);
-
-    llvm::Value *Addr = Builder.CreateAlloca(
-        Arg.getType()->isPointerTy() ? Arg.getType()
-                                     : Type::getInt64Ty(Builder.getContext()),
-        OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
-    llvm::Value *AddrAscast =
-        Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
-    Builder.CreateStore(&Arg, AddrAscast);
-    llvm::Value *Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
-
-    return Copy;
-  };
+  auto SimpleArgAccessorCB =
+      [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
+          llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+        if (!OMPBuilder.Config.isTargetDevice()) {
+          RetVal = cast<llvm::Value>(&Arg);
+          return CodeGenIP;
+        }
+
+        Builder.restoreIP(AllocaIP);
+
+        llvm::Value *Addr = Builder.CreateAlloca(
+            Arg.getType()->isPointerTy()
+                ? Arg.getType()
+                : Type::getInt64Ty(Builder.getContext()),
+            OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
+        llvm::Value *AddrAscast =
+            Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+        Builder.CreateStore(&Arg, AddrAscast);
+
+        Builder.restoreIP(CodeGenIP);
+
+        RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+        return Builder.saveIP();
+      };
 
   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
@@ -5319,22 +5328,32 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
       Constant::getNullValue(PointerType::get(Ctx, 0)),
       Constant::getNullValue(PointerType::get(Ctx, 0))};
 
-  auto SimpleArgAccessorCB = [&](llvm::Argument &Arg, llvm::Value *Input,
-                                 IRBuilderBase &Builder) {
-    if (!OMPBuilder.Config.isTargetDevice())
-      return cast<llvm::Value>(&Arg);
-
-    llvm::Value *Addr = Builder.CreateAlloca(
-        Arg.getType()->isPointerTy() ? Arg.getType()
-                                     : Type::getInt64Ty(Builder.getContext()),
-        OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
-    llvm::Value *AddrAscast =
-        Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
-    Builder.CreateStore(&Arg, AddrAscast);
-    llvm::Value *Copy = Builder.CreateLoad(Arg.getType(), AddrAscast);
-
-    return Copy;
-  };
+  auto SimpleArgAccessorCB =
+      [&](llvm::Argument &Arg, llvm::Value *Input, llvm::Value *&RetVal,
+          llvm::OpenMPIRBuilder::InsertPointTy AllocaIP,
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+        if (!OMPBuilder.Config.isTargetDevice()) {
+          RetVal = cast<llvm::Value>(&Arg);
+          return CodeGenIP;
+        }
+
+        Builder.restoreIP(AllocaIP);
+
+        llvm::Value *Addr = Builder.CreateAlloca(
+            Arg.getType()->isPointerTy()
+                ? Arg.getType()
+                : Type::getInt64Ty(Builder.getContext()),
+            OMPBuilder.M.getDataLayout().getAllocaAddrSpace());
+        llvm::Value *AddrAscast =
+            Builder.CreatePointerBitCastOrAddrSpaceCast(Addr, Input->getType());
+        Builder.CreateStore(&Arg, AddrAscast);
+
+        Builder.restoreIP(CodeGenIP);
+
+        RetVal = Builder.CreateLoad(Arg.getType(), AddrAscast);
+
+        return Builder.saveIP();
+      };
 
   llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
   auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP)
@@ -5379,10 +5398,18 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
 
   // Check entry block
   auto &EntryBlock = OutlinedFn->getEntryBlock();
-  Instruction *Init = EntryBlock.getFirstNonPHI();
-  EXPECT_NE(Init, nullptr);
+  Instruction *Alloca1 = EntryBlock.getFirstNonPHI();
+  EXPECT_NE(Alloca1, nullptr);
 
-  auto *InitCall = dyn_cast<CallInst>(Init);
+  EXPECT_TRUE(isa<AllocaInst>(Alloca1));
+  auto *Store1 = Alloca1->getNextNode();
+  EXPECT_TRUE(isa<StoreInst>(Store1));
+  auto *Alloca2 = Store1->getNextNode();
+  EXPECT_TRUE(isa<AllocaInst>(Alloca2));
+  auto *Store2 = Alloca2->getNextNode();
+  EXPECT_TRUE(isa<StoreInst>(Store2));
+
+  auto *InitCall = dyn_cast<CallInst>(Store2->getNextNode());
   EXPECT_NE(InitCall, nullptr);
   EXPECT_EQ(InitCall->getCalledFunction()->getName(), "__kmpc_target_init");
   EXPECT_EQ(InitCall->arg_size(), 1U);
@@ -5406,17 +5433,9 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   // Check user code block
   auto *UserCodeBlock = EntryBlockBranch->getSuccessor(0);
   EXPECT_EQ(UserCodeBlock->getName(), "user_code.entry");
-  auto *Alloca1 = UserCodeBlock->getFirstNonPHI();
-  EXPECT_TRUE(isa<AllocaInst>(Alloca1));
-  auto *Store1 = Alloca1->getNextNode();
-  EXPECT_TRUE(isa<StoreInst>(Store1));
-  auto *Load1 = Store1->getNextNode();
+  auto *Load1 = UserCodeBlock->getFirstNonPHI();
   EXPECT_TRUE(isa<LoadInst>(Load1));
-  auto *Alloca2 = Load1->getNextNode();
-  EXPECT_TRUE(isa<AllocaInst>(Alloca2));
-  auto *Store2 = Alloca2->getNextNode();
-  EXPECT_TRUE(isa<StoreInst>(Store2));
-  auto *Load2 = Store2->getNextNode();
+  auto *Load2 = Load1->getNextNode();
   EXPECT_TRUE(isa<LoadInst>(Load2));
 
   auto *Value1 = Load2->getNextNode();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8326a20a15c4629..c5566a011a3b0c6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1993,13 +1993,14 @@ handleDeclareTargetMapVar(llvm::ArrayRef<Value> mapOperands,
   }
 }
 
-static llvm::Value *
+static llvm::IRBuilderBase::InsertPoint
 createDeviceArgumentAccessor(llvm::Argument &arg, llvm::Value *input,
-                             llvm::IRBuilderBase &builder,
+                             llvm::Value *&retVal, llvm::IRBuilderBase &builder,
                              llvm::OpenMPIRBuilder &ompBuilder,
-                             LLVM::ModuleTranslation &moduleTranslation) {
-  if (!ompBuilder.Config.isTargetDevice())
-    return cast<llvm::Value>(&arg);
+                             LLVM::ModuleTranslation &moduleTranslation,
+                             llvm::IRBuilderBase::InsertPoint allocaIP,
+                             llvm::IRBuilderBase::InsertPoint codeGenIP) {
+  builder.restoreIP(allocaIP);
 
   llvm::Value *addr =
       builder.CreateAlloca(arg.getType()->isPointerTy()
@@ -2009,9 +2010,12 @@ createDeviceArgumentAccessor(llvm::Argument &arg, llvm::Value *input,
   llvm::Value *addrAscast =
       builder.CreatePointerBitCastOrAddrSpaceCast(addr, input->getType());
   builder.CreateStore(&arg, addrAscast);
-  llvm::Value *copy = builder.CreateLoad(arg.getType(), addrAscast);
 
-  return copy;
+  builder.restoreIP(codeGenIP);
+
+  retVal = builder.CreateLoad(arg.getType(), addrAscast);
+
+  return builder.saveIP();
 }
 
 static LogicalResult
@@ -2105,9 +2109,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     return combinedInfos;
   };
 
-  auto argAccessorCB = [&moduleTranslation](llvm::Argument &arg,
-                                            llvm::Value *input,
-                                            llvm::IRBuilderBase &builder) {
+  auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
+                           llvm::Value *&retVal, InsertPointTy allocaIP,
+                           InsertPointTy codeGenIP) {
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
     // We just return the unaltered argument for the host function
@@ -2115,11 +2119,14 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
     // keep host fallback functions working identically to the device
     // version (e.g. pass ByCopy values should be treated as such on
     // host and device, currently not always the case)
-    if (!ompBuilder->Config.isTargetDevice())
-      return cast<llvm::Value>(&arg);
+    if (!ompBuilder->Config.isTargetDevice()) {
+      retVal = cast<llvm::Value>(&arg);
+      return codeGenIP;
+    }
 
-    return createDeviceArgumentAccessor(arg, input, builder, *ompBuilder,
-                                        moduleTranslation);
+    return createDeviceArgumentAccessor(arg, input, retVal, builder,
+                                        *ompBuilder, moduleTranslation,
+                                        allocaIP, codeGenIP);
   };
 
   builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index cf70469e7484f64..99f1a3b072ad8fe 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -31,18 +31,18 @@ module attributes {omp.is_target_device = true} {
 // CHECK:      @[[DYNA_ENV:.*]] = weak_odr protected global %struct.DynamicEnvironmentTy zeroinitializer
 // CHECK:      @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] }
 // CHECK:      define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
-// CHECK:        %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[KERNEL_ENV]])
-// CHECK-NEXT:   %[[CMP:.*]] = icmp eq i32 %3, -1
-// CHECK-NEXT:   br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
-// CHECK:        [[LABEL_ENTRY]]:
 // CHECK:        %[[TMP_A:.*]] = alloca ptr, align 8
 // CHECK:        store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8
-// CHECK:        %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8
 // CHECK:        %[[TMP_B:.*]] = alloca ptr, align 8
 // CHECK:        store ptr %[[ADDR_B]], ptr %[[TMP_B]], align 8
-// CHECK:        %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8
 // CHECK:        %[[TMP_C:.*]] = alloca ptr, align 8
 // CHECK:        store ptr %[[ADDR_C]], ptr %[[TMP_C]], align 8
+// CHECK:        %[[INIT:.*]] = call i32 @__kmpc_target_init(ptr @[[KERNEL_ENV]])
+// CHECK-NEXT:   %[[CMP:.*]] = icmp eq i32 %[[INIT]], -1
+// CHECK-NEXT:   br i1 %[[CMP]], label %[[LABEL_ENTRY:.*]], label %[[LABEL_EXIT:.*]]
+// CHECK:        [[LABEL_ENTRY]]:
+// CHECK:        %[[PTR_A:.*]] = load ptr, ptr %[[TMP_A]], align 8
+// CHECK:        %[[PTR_B:.*]] = load ptr, ptr %[[TMP_B]], align 8
 // CHECK:        %[[PTR_C:.*]] = load ptr, ptr %[[TMP_C]], align 8
 // CHECK-NEXT:   br label %[[LABEL_TARGET:.*]]
 // CHECK:        [[LABEL_TARGET]]:



More information about the Mlir-commits mailing list