[Mlir-commits] [mlir] fe7ca1a - [mlir][openacc] Initial translation for DataOp to LLVM IR

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jul 27 19:04:12 PDT 2021


Author: Valentin Clement
Date: 2021-07-27T22:04:04-04:00
New Revision: fe7ca1a9fca0ccea7495224e0e837de705e69699

URL: https://github.com/llvm/llvm-project/commit/fe7ca1a9fca0ccea7495224e0e837de705e69699
DIFF: https://github.com/llvm/llvm-project/commit/fe7ca1a9fca0ccea7495224e0e837de705e69699.diff

LOG: [mlir][openacc] Initial translation for DataOp to LLVM IR

Add basic translation of acc.data to LLVM IR with runtime calls.

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D104301

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
    mlir/test/Target/LLVMIR/openacc-llvm.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index a92c3ba381c67..8144f1527a067 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -636,6 +636,31 @@ class OpenMPIRBuilder {
   createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                         std::string VarName);
 
+  struct MapperAllocas {
+    AllocaInst *ArgsBase = nullptr;
+    AllocaInst *Args = nullptr;
+    AllocaInst *ArgSizes = nullptr;
+  };
+
+  /// Create the alloca instructions used in calls to mapper functions.
+  void createMapperAllocas(const LocationDescription &Loc,
+                           InsertPointTy AllocaIP, unsigned NumOperands,
+                           struct MapperAllocas &MapperAllocas);
+
+  /// Create the call for the target mapper function.
+  /// \param Loc The source location description.
+  /// \param MapperFunc Function to be called.
+  /// \param SrcLocInfo Source location information global.
+  /// \param MaptypesArg The map types argument passed to the call.
+  /// \param MapnamesArg The map names argument passed to the call.
+  /// \param MapperAllocas The AllocaInst used for the call.
+  /// \param DeviceID Device ID for the call.
+  /// \param NumOperands Number of operands in the call.
+  void emitMapperCall(const LocationDescription &Loc, Function *MapperFunc,
+                      Value *SrcLocInfo, Value *MaptypesArg, Value *MapnamesArg,
+                      struct MapperAllocas &MapperAllocas, int64_t DeviceID,
+                      unsigned NumOperands);
+
 public:
   /// Generator for __kmpc_copyprivate
   ///

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 60d71805c758f..76954f9a37e18 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2318,6 +2318,51 @@ OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
   return MaptypesArrayGlobal;
 }
 
+void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
+                                          InsertPointTy AllocaIP,
+                                          unsigned NumOperands,
+                                          struct MapperAllocas &MapperAllocas) {
+  if (!updateToLocation(Loc))
+    return;
+
+  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
+  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
+  Builder.restoreIP(AllocaIP);
+  AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI8PtrTy);
+  AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy);
+  AllocaInst *ArgSizes = Builder.CreateAlloca(ArrI64Ty);
+  Builder.restoreIP(Loc.IP);
+  MapperAllocas.ArgsBase = ArgsBase;
+  MapperAllocas.Args = Args;
+  MapperAllocas.ArgSizes = ArgSizes;
+}
+
+void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
+                                     Function *MapperFunc, Value *SrcLocInfo,
+                                     Value *MaptypesArg, Value *MapnamesArg,
+                                     struct MapperAllocas &MapperAllocas,
+                                     int64_t DeviceID, unsigned NumOperands) {
+  if (!updateToLocation(Loc))
+    return;
+
+  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
+  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
+  Value *ArgsBaseGEP =
+      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
+                                {Builder.getInt32(0), Builder.getInt32(0)});
+  Value *ArgsGEP =
+      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
+                                {Builder.getInt32(0), Builder.getInt32(0)});
+  Value *ArgSizesGEP =
+      Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
+                                {Builder.getInt32(0), Builder.getInt32(0)});
+  Value *NullPtr = Constant::getNullValue(Int8Ptr->getPointerTo());
+  Builder.CreateCall(MapperFunc,
+                     {SrcLocInfo, Builder.getInt64(DeviceID),
+                      Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
+                      ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
+}
+
 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
   assert(!(AO == AtomicOrdering::NotAtomic ||

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 0df6fe531e3c1..50887611eaf17 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -2703,4 +2703,110 @@ TEST_F(OpenMPIRBuilderTest, CreateOffloadMapnames) {
   EXPECT_EQ(Initializer->getType()->getArrayNumElements(), Names.size());
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateMapperAllocas) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  unsigned TotalNbOperand = 2;
+
+  OpenMPIRBuilder::MapperAllocas MapperAllocas;
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  OMPBuilder.createMapperAllocas(Loc, AllocaIP, TotalNbOperand, MapperAllocas);
+  EXPECT_NE(MapperAllocas.ArgsBase, nullptr);
+  EXPECT_NE(MapperAllocas.Args, nullptr);
+  EXPECT_NE(MapperAllocas.ArgSizes, nullptr);
+  EXPECT_TRUE(MapperAllocas.ArgsBase->getAllocatedType()->isArrayTy());
+  ArrayType *ArrType =
+      dyn_cast<ArrayType>(MapperAllocas.ArgsBase->getAllocatedType());
+  EXPECT_EQ(ArrType->getNumElements(), TotalNbOperand);
+  EXPECT_TRUE(MapperAllocas.ArgsBase->getAllocatedType()
+                  ->getArrayElementType()
+                  ->isPointerTy());
+  EXPECT_TRUE(MapperAllocas.ArgsBase->getAllocatedType()
+                  ->getArrayElementType()
+                  ->getPointerElementType()
+                  ->isIntegerTy(8));
+
+  EXPECT_TRUE(MapperAllocas.Args->getAllocatedType()->isArrayTy());
+  ArrType = dyn_cast<ArrayType>(MapperAllocas.Args->getAllocatedType());
+  EXPECT_EQ(ArrType->getNumElements(), TotalNbOperand);
+  EXPECT_TRUE(MapperAllocas.Args->getAllocatedType()
+                  ->getArrayElementType()
+                  ->isPointerTy());
+  EXPECT_TRUE(MapperAllocas.Args->getAllocatedType()
+                  ->getArrayElementType()
+                  ->getPointerElementType()
+                  ->isIntegerTy(8));
+
+  EXPECT_TRUE(MapperAllocas.ArgSizes->getAllocatedType()->isArrayTy());
+  ArrType = dyn_cast<ArrayType>(MapperAllocas.ArgSizes->getAllocatedType());
+  EXPECT_EQ(ArrType->getNumElements(), TotalNbOperand);
+  EXPECT_TRUE(MapperAllocas.ArgSizes->getAllocatedType()
+                  ->getArrayElementType()
+                  ->isIntegerTy(64));
+}
+
+TEST_F(OpenMPIRBuilderTest, EmitMapperCall) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  LLVMContext &Ctx = M->getContext();
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  unsigned TotalNbOperand = 2;
+
+  OpenMPIRBuilder::MapperAllocas MapperAllocas;
+  IRBuilder<>::InsertPoint AllocaIP(&F->getEntryBlock(),
+                                    F->getEntryBlock().getFirstInsertionPt());
+  OMPBuilder.createMapperAllocas(Loc, AllocaIP, TotalNbOperand, MapperAllocas);
+
+  auto *BeginMapperFunc = OMPBuilder.getOrCreateRuntimeFunctionPtr(
+      omp::OMPRTL___tgt_target_data_begin_mapper);
+
+  SmallVector<uint64_t> Flags = {0, 2};
+
+  Constant *SrcLocCst = OMPBuilder.getOrCreateSrcLocStr("", "file1", 2, 5);
+  Value *SrcLocInfo = OMPBuilder.getOrCreateIdent(SrcLocCst);
+
+  Constant *Cst1 = OMPBuilder.getOrCreateSrcLocStr("array1", "file1", 2, 5);
+  Constant *Cst2 = OMPBuilder.getOrCreateSrcLocStr("array2", "file1", 3, 5);
+  SmallVector<llvm::Constant *> Names = {Cst1, Cst2};
+
+  GlobalVariable *Maptypes =
+      OMPBuilder.createOffloadMaptypes(Flags, ".offload_maptypes");
+  Value *MaptypesArg = Builder.CreateConstInBoundsGEP2_32(
+      ArrayType::get(Type::getInt64Ty(Ctx), TotalNbOperand), Maptypes,
+      /*Idx0=*/0, /*Idx1=*/0);
+
+  GlobalVariable *Mapnames =
+      OMPBuilder.createOffloadMapnames(Names, ".offload_mapnames");
+  Value *MapnamesArg = Builder.CreateConstInBoundsGEP2_32(
+      ArrayType::get(Type::getInt8PtrTy(Ctx), TotalNbOperand), Mapnames,
+      /*Idx0=*/0, /*Idx1=*/0);
+
+  OMPBuilder.emitMapperCall(Builder.saveIP(), BeginMapperFunc, SrcLocInfo,
+                            MaptypesArg, MapnamesArg, MapperAllocas, -1,
+                            TotalNbOperand);
+
+  CallInst *MapperCall = dyn_cast<CallInst>(&BB->back());
+  EXPECT_NE(MapperCall, nullptr);
+  EXPECT_EQ(MapperCall->getNumArgOperands(), 9U);
+  EXPECT_EQ(MapperCall->getCalledFunction()->getName(),
+            "__tgt_target_data_begin_mapper");
+  EXPECT_EQ(MapperCall->getOperand(0), SrcLocInfo);
+  EXPECT_TRUE(MapperCall->getOperand(1)->getType()->isIntegerTy(64));
+  EXPECT_TRUE(MapperCall->getOperand(2)->getType()->isIntegerTy(32));
+
+  EXPECT_EQ(MapperCall->getOperand(6), MaptypesArg);
+  EXPECT_EQ(MapperCall->getOperand(7), MapnamesArg);
+  EXPECT_TRUE(MapperCall->getOperand(8)->getType()->isPointerTy());
+}
+
 } // namespace

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
index 4fbd051b65390..2900a095e3dbe 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp
@@ -32,14 +32,14 @@ using OpenACCIRBuilder = llvm::OpenMPIRBuilder;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// 0 = alloc/create
-static constexpr uint64_t kCreateFlag = 0;
-/// 1 = to/device/copyin
-static constexpr uint64_t kDeviceCopyinFlag = 1;
-/// 2 = from/copyout
-static constexpr uint64_t kHostCopyoutFlag = 2;
-/// 8 = delete
-static constexpr uint64_t kDeleteFlag = 8;
+/// Flag values are extracted from openmp/libomptarget/include/omptarget.h and
+/// mapped to corresponding OpenACC flags.
+static constexpr uint64_t kCreateFlag = 0x000;
+static constexpr uint64_t kDeviceCopyinFlag = 0x001;
+static constexpr uint64_t kHostCopyoutFlag = 0x002;
+static constexpr uint64_t kCopyFlag = kDeviceCopyinFlag | kHostCopyoutFlag;
+static constexpr uint64_t kPresentFlag = 0x1000;
+static constexpr uint64_t kDeleteFlag = 0x008;
 
 /// Default value for the device id
 static constexpr int64_t kDefaultDevice = -1;
@@ -123,9 +123,8 @@ processOperands(llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation, Operation *op,
                 ValueRange operands, unsigned totalNbOperand,
                 uint64_t operandFlag, SmallVector<uint64_t> &flags,
-                SmallVector<llvm::Constant *> &names, unsigned &index,
-                llvm::AllocaInst *argsBase, llvm::AllocaInst *args,
-                llvm::AllocaInst *argSizes) {
+                SmallVectorImpl<llvm::Constant *> &names, unsigned &index,
+                struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
   OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
   llvm::LLVMContext &ctx = builder.getContext();
   auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
@@ -160,21 +159,24 @@ processOperands(llvm::IRBuilderBase &builder,
     // Store base pointer extracted from operand into the i-th position of
     // argBase.
     llvm::Value *ptrBaseGEP = builder.CreateInBoundsGEP(
-        arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(index)});
+        arrI8PtrTy, mapperAllocas.ArgsBase,
+        {builder.getInt32(0), builder.getInt32(index)});
     llvm::Value *ptrBaseCast = builder.CreateBitCast(
         ptrBaseGEP, dataPtrBase->getType()->getPointerTo());
     builder.CreateStore(dataPtrBase, ptrBaseCast);
 
     // Store pointer extracted from operand into the i-th position of args.
     llvm::Value *ptrGEP = builder.CreateInBoundsGEP(
-        arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(index)});
+        arrI8PtrTy, mapperAllocas.Args,
+        {builder.getInt32(0), builder.getInt32(index)});
     llvm::Value *ptrCast =
         builder.CreateBitCast(ptrGEP, dataPtr->getType()->getPointerTo());
     builder.CreateStore(dataPtr, ptrCast);
 
     // Store size extracted from operand into the i-th position of argSizes.
     llvm::Value *sizeGEP = builder.CreateInBoundsGEP(
-        arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(index)});
+        arrI64Ty, mapperAllocas.ArgSizes,
+        {builder.getInt32(0), builder.getInt32(index)});
     builder.CreateStore(dataSize, sizeGEP);
 
     flags.push_back(operandFlag);
@@ -187,11 +189,12 @@ processOperands(llvm::IRBuilderBase &builder,
 }
 
 /// Process data operands from acc::EnterDataOp
-static LogicalResult processDataOperands(
-    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
-    acc::EnterDataOp op, SmallVector<uint64_t> &flags,
-    SmallVector<llvm::Constant *> &names, llvm::AllocaInst *argsBase,
-    llvm::AllocaInst *args, llvm::AllocaInst *argSizes) {
+static LogicalResult
+processDataOperands(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    acc::EnterDataOp op, SmallVector<uint64_t> &flags,
+                    SmallVectorImpl<llvm::Constant *> &names,
+                    struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
   // TODO add `create_zero` and `attach` operands
 
   unsigned index = 0;
@@ -199,26 +202,26 @@ static LogicalResult processDataOperands(
   // Create operands are handled as `alloc` call.
   if (failed(processOperands(builder, moduleTranslation, op,
                              op.createOperands(), op.getNumDataOperands(),
-                             kCreateFlag, flags, names, index, argsBase, args,
-                             argSizes)))
+                             kCreateFlag, flags, names, index, mapperAllocas)))
     return failure();
 
   // Copyin operands are handled as `to` call.
   if (failed(processOperands(builder, moduleTranslation, op,
                              op.copyinOperands(), op.getNumDataOperands(),
-                             kDeviceCopyinFlag, flags, names, index, argsBase,
-                             args, argSizes)))
+                             kDeviceCopyinFlag, flags, names, index,
+                             mapperAllocas)))
     return failure();
 
   return success();
 }
 
 /// Process data operands from acc::ExitDataOp
-static LogicalResult processDataOperands(
-    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
-    acc::ExitDataOp op, SmallVector<uint64_t> &flags,
-    SmallVector<llvm::Constant *> &names, llvm::AllocaInst *argsBase,
-    llvm::AllocaInst *args, llvm::AllocaInst *argSizes) {
+static LogicalResult
+processDataOperands(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    acc::ExitDataOp op, SmallVector<uint64_t> &flags,
+                    SmallVectorImpl<llvm::Constant *> &names,
+                    struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
   // TODO add `detach` operands
 
   unsigned index = 0;
@@ -226,39 +229,39 @@ static LogicalResult processDataOperands(
   // Delete operands are handled as `delete` call.
   if (failed(processOperands(builder, moduleTranslation, op,
                              op.deleteOperands(), op.getNumDataOperands(),
-                             kDeleteFlag, flags, names, index, argsBase, args,
-                             argSizes)))
+                             kDeleteFlag, flags, names, index, mapperAllocas)))
     return failure();
 
   // Copyout operands are handled as `from` call.
   if (failed(processOperands(builder, moduleTranslation, op,
                              op.copyoutOperands(), op.getNumDataOperands(),
-                             kHostCopyoutFlag, flags, names, index, argsBase,
-                             args, argSizes)))
+                             kHostCopyoutFlag, flags, names, index,
+                             mapperAllocas)))
     return failure();
 
   return success();
 }
 
 /// Process data operands from acc::UpdateOp
-static LogicalResult processDataOperands(
-    llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
-    acc::UpdateOp op, SmallVector<uint64_t> &flags,
-    SmallVector<llvm::Constant *> &names, llvm::AllocaInst *argsBase,
-    llvm::AllocaInst *args, llvm::AllocaInst *argSizes) {
+static LogicalResult
+processDataOperands(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    acc::UpdateOp op, SmallVector<uint64_t> &flags,
+                    SmallVectorImpl<llvm::Constant *> &names,
+                    struct OpenACCIRBuilder::MapperAllocas &mapperAllocas) {
   unsigned index = 0;
 
   // Host operands are handled as `from` call.
   if (failed(processOperands(builder, moduleTranslation, op, op.hostOperands(),
                              op.getNumDataOperands(), kHostCopyoutFlag, flags,
-                             names, index, argsBase, args, argSizes)))
+                             names, index, mapperAllocas)))
     return failure();
 
   // Device operands are handled as `to` call.
   if (failed(processOperands(builder, moduleTranslation, op,
                              op.deviceOperands(), op.getNumDataOperands(),
-                             kDeviceCopyinFlag, flags, names, index, argsBase,
-                             args, argSizes)))
+                             kDeviceCopyinFlag, flags, names, index,
+                             mapperAllocas)))
     return failure();
 
   return success();
@@ -268,6 +271,153 @@ static LogicalResult processDataOperands(
 // Conversion functions
 //===----------------------------------------------------------------------===//
 
+/// Converts an OpenACC data operation into LLVM IR.
+static LogicalResult convertDataOp(acc::DataOp &op,
+                                   llvm::IRBuilderBase &builder,
+                                   LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::LLVMContext &ctx = builder.getContext();
+  auto enclosingFuncOp = op.getOperation()->getParentOfType<LLVM::LLVMFuncOp>();
+  llvm::Function *enclosingFunction =
+      moduleTranslation.lookupFunction(enclosingFuncOp.getName());
+
+  OpenACCIRBuilder *accBuilder = moduleTranslation.getOpenMPBuilder();
+
+  llvm::Value *srcLocInfo = createSourceLocationInfo(*accBuilder, op);
+
+  llvm::Function *beginMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
+      llvm::omp::OMPRTL___tgt_target_data_begin_mapper);
+
+  llvm::Function *endMapperFunc = accBuilder->getOrCreateRuntimeFunctionPtr(
+      llvm::omp::OMPRTL___tgt_target_data_end_mapper);
+
+  // Number of arguments in the data operation.
+  unsigned totalNbOperand = op.getNumDataOperands();
+
+  struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
+  OpenACCIRBuilder::InsertPointTy allocaIP(
+      &enclosingFunction->getEntryBlock(),
+      enclosingFunction->getEntryBlock().getFirstInsertionPt());
+  accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
+                                  mapperAllocas);
+
+  SmallVector<uint64_t> flags;
+  SmallVector<llvm::Constant *> names;
+  unsigned index = 0;
+
+  // TODO handle no_create, deviceptr and attach operands.
+
+  if (failed(processOperands(builder, moduleTranslation, op, op.copyOperands(),
+                             totalNbOperand, kCopyFlag, flags, names, index,
+                             mapperAllocas)))
+    return failure();
+
+  if (failed(processOperands(
+          builder, moduleTranslation, op, op.copyinOperands(), totalNbOperand,
+          kDeviceCopyinFlag, flags, names, index, mapperAllocas)))
+    return failure();
+
+  // TODO copyin readonly currently handled as copyin. Update when extension
+  // available.
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             op.copyinReadonlyOperands(), totalNbOperand,
+                             kDeviceCopyinFlag, flags, names, index,
+                             mapperAllocas)))
+    return failure();
+
+  if (failed(processOperands(
+          builder, moduleTranslation, op, op.copyoutOperands(), totalNbOperand,
+          kHostCopyoutFlag, flags, names, index, mapperAllocas)))
+    return failure();
+
+  // TODO copyout zero currently handled as copyout. Update when extension
+  // available.
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             op.copyoutZeroOperands(), totalNbOperand,
+                             kHostCopyoutFlag, flags, names, index,
+                             mapperAllocas)))
+    return failure();
+
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             op.createOperands(), totalNbOperand, kCreateFlag,
+                             flags, names, index, mapperAllocas)))
+    return failure();
+
+  // TODO create zero currently handled as create. Update when extension
+  // available.
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             op.createZeroOperands(), totalNbOperand,
+                             kCreateFlag, flags, names, index, mapperAllocas)))
+    return failure();
+
+  if (failed(processOperands(builder, moduleTranslation, op,
+                             op.presentOperands(), totalNbOperand, kPresentFlag,
+                             flags, names, index, mapperAllocas)))
+    return failure();
+
+  llvm::GlobalVariable *maptypes =
+      accBuilder->createOffloadMaptypes(flags, ".offload_maptypes");
+  llvm::Value *maptypesArg = builder.CreateConstInBoundsGEP2_32(
+      llvm::ArrayType::get(llvm::Type::getInt64Ty(ctx), totalNbOperand),
+      maptypes, /*Idx0=*/0, /*Idx1=*/0);
+
+  llvm::GlobalVariable *mapnames =
+      accBuilder->createOffloadMapnames(names, ".offload_mapnames");
+  llvm::Value *mapnamesArg = builder.CreateConstInBoundsGEP2_32(
+      llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
+      mapnames, /*Idx0=*/0, /*Idx1=*/0);
+
+  // Create call to start the data region.
+  accBuilder->emitMapperCall(builder.saveIP(), beginMapperFunc, srcLocInfo,
+                             maptypesArg, mapnamesArg, mapperAllocas,
+                             kDefaultDevice, totalNbOperand);
+
+  // Convert the region.
+  llvm::BasicBlock *entryBlock = nullptr;
+
+  for (Block &bb : op.region()) {
+    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
+        ctx, "acc.data", builder.GetInsertBlock()->getParent());
+    if (entryBlock == nullptr)
+      entryBlock = llvmBB;
+    moduleTranslation.mapBlock(&bb, llvmBB);
+  }
+
+  auto afterDataRegion = builder.saveIP();
+
+  llvm::BranchInst *sourceTerminator = builder.CreateBr(entryBlock);
+
+  builder.restoreIP(afterDataRegion);
+  llvm::BasicBlock *endDataBlock = llvm::BasicBlock::Create(
+      ctx, "acc.end_data", builder.GetInsertBlock()->getParent());
+
+  SetVector<Block *> blocks =
+      LLVM::detail::getTopologicallySortedBlocks(op.region());
+  for (Block *bb : blocks) {
+    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
+    if (bb->isEntryBlock()) {
+      assert(sourceTerminator->getNumSuccessors() == 1 &&
+             "provided entry block has multiple successors");
+      sourceTerminator->setSuccessor(0, llvmBB);
+    }
+
+    if (failed(
+            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder))) {
+      return failure();
+    }
+
+    if (isa<acc::TerminatorOp, acc::YieldOp>(bb->getTerminator()))
+      builder.CreateBr(endDataBlock);
+  }
+
+  // Create call to end the data region.
+  builder.SetInsertPoint(endDataBlock);
+  accBuilder->emitMapperCall(builder.saveIP(), endMapperFunc, srcLocInfo,
+                             maptypesArg, mapnamesArg, mapperAllocas,
+                             kDefaultDevice, totalNbOperand);
+
+  return success();
+}
+
 /// Converts an OpenACC standalone data operation into LLVM IR.
 template <typename OpTy>
 static LogicalResult
@@ -286,27 +436,20 @@ convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
   // Number of arguments in the enter_data operation.
   unsigned totalNbOperand = op.getNumDataOperands();
 
-  // TODO could be moved to OpenXXIRBuilder?
   llvm::LLVMContext &ctx = builder.getContext();
-  auto *i8PtrTy = llvm::Type::getInt8PtrTy(ctx);
-  auto *arrI8PtrTy = llvm::ArrayType::get(i8PtrTy, totalNbOperand);
-  auto *i64Ty = llvm::Type::getInt64Ty(ctx);
-  auto *arrI64Ty = llvm::ArrayType::get(i64Ty, totalNbOperand);
-  llvm::IRBuilder<>::InsertPoint allocaIP(
+
+  struct OpenACCIRBuilder::MapperAllocas mapperAllocas;
+  OpenACCIRBuilder::InsertPointTy allocaIP(
       &enclosingFunction->getEntryBlock(),
       enclosingFunction->getEntryBlock().getFirstInsertionPt());
-  llvm::IRBuilder<>::InsertPoint currentIP = builder.saveIP();
-  builder.restoreIP(allocaIP);
-  llvm::AllocaInst *argsBase = builder.CreateAlloca(arrI8PtrTy);
-  llvm::AllocaInst *args = builder.CreateAlloca(arrI8PtrTy);
-  llvm::AllocaInst *argSizes = builder.CreateAlloca(arrI64Ty);
-  builder.restoreIP(currentIP);
+  accBuilder->createMapperAllocas(builder.saveIP(), allocaIP, totalNbOperand,
+                                  mapperAllocas);
 
   SmallVector<uint64_t> flags;
   SmallVector<llvm::Constant *> names;
 
   if (failed(processDataOperands(builder, moduleTranslation, op, flags, names,
-                                 argsBase, args, argSizes)))
+                                 mapperAllocas)))
     return failure();
 
   llvm::GlobalVariable *maptypes =
@@ -321,19 +464,9 @@ convertStandaloneDataOp(OpTy &op, llvm::IRBuilderBase &builder,
       llvm::ArrayType::get(llvm::Type::getInt8PtrTy(ctx), totalNbOperand),
       mapnames, /*Idx0=*/0, /*Idx1=*/0);
 
-  llvm::Value *argsBaseGEP = builder.CreateInBoundsGEP(
-      arrI8PtrTy, argsBase, {builder.getInt32(0), builder.getInt32(0)});
-  llvm::Value *argsGEP = builder.CreateInBoundsGEP(
-      arrI8PtrTy, args, {builder.getInt32(0), builder.getInt32(0)});
-  llvm::Value *argSizesGEP = builder.CreateInBoundsGEP(
-      arrI64Ty, argSizes, {builder.getInt32(0), builder.getInt32(0)});
-  llvm::Value *nullPtr = llvm::Constant::getNullValue(
-      llvm::Type::getInt8PtrTy(ctx)->getPointerTo());
-
-  builder.CreateCall(mapperFunc,
-                     {srcLocInfo, builder.getInt64(kDefaultDevice),
-                      builder.getInt32(totalNbOperand), argsBaseGEP, argsGEP,
-                      argSizesGEP, maptypesArg, mapnamesArg, nullPtr});
+  accBuilder->emitMapperCall(builder.saveIP(), mapperFunc, srcLocInfo,
+                             maptypesArg, mapnamesArg, mapperAllocas,
+                             kDefaultDevice, totalNbOperand);
 
   return success();
 }
@@ -363,6 +496,9 @@ LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
     LLVM::ModuleTranslation &moduleTranslation) const {
 
   return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+      .Case([&](acc::DataOp dataOp) {
+        return convertDataOp(dataOp, builder, moduleTranslation);
+      })
       .Case([&](acc::EnterDataOp enterDataOp) {
         return convertStandaloneDataOp<acc::EnterDataOp>(enterDataOp, builder,
                                                          moduleTranslation);
@@ -375,6 +511,13 @@ LogicalResult OpenACCDialectLLVMIRTranslationInterface::convertOperation(
         return convertStandaloneDataOp<acc::UpdateOp>(updateOp, builder,
                                                       moduleTranslation);
       })
+      .Case<acc::TerminatorOp, acc::YieldOp>([](auto op) {
+        // `yield` and `terminator` can be just omitted. The block structure was
+        // created in the function that handles their parent operation.
+        assert(op->getNumOperands() == 0 &&
+               "unexpected OpenACC terminator with operands");
+        return success();
+      })
       .Default([&](Operation *op) {
         return op->emitError("unsupported OpenACC operation: ")
                << op->getName();

diff  --git a/mlir/test/Target/LLVMIR/openacc-llvm.mlir b/mlir/test/Target/LLVMIR/openacc-llvm.mlir
index 52656f9705bc2..f0fb7ff82c9bd 100644
--- a/mlir/test/Target/LLVMIR/openacc-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openacc-llvm.mlir
@@ -182,3 +182,76 @@ llvm.func @testupdateop(%arg0: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x
 // CHECK: call void @__tgt_target_data_update_mapper(%struct.ident_t* [[LOCGLOBAL]], i64 -1, i32 2, i8** [[ARGBASE_ALLOCA_GEP]], i8** [[ARG_ALLOCA_GEP]], i64* [[SIZE_ALLOCA_GEP]], i64* getelementptr inbounds ([{{[0-9]*}} x i64], [{{[0-9]*}} x i64]* [[MAPTYPES]], i32 0, i32 0), i8** getelementptr inbounds ([{{[0-9]*}} x i8*], [{{[0-9]*}} x i8*]* [[MAPNAMES]], i32 0, i32 0), i8** null)
 
 // CHECK: declare void @__tgt_target_data_update_mapper(%struct.ident_t*, i64, i32, i8**, i8**, i64*, i64*, i8**, i8**) #0
+
+// -----
+
+llvm.func @testdataop(%arg0: !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, %arg1: !llvm.ptr<f32>, %arg2: !llvm.ptr<i32>) {
+  %0 = llvm.mlir.constant(10 : index) : i64
+  %1 = llvm.mlir.null : !llvm.ptr<f32>
+  %2 = llvm.getelementptr %1[%0] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  %3 = llvm.ptrtoint %2 : !llvm.ptr<f32> to i64
+  %4 = llvm.extractvalue %arg0[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  %5 = llvm.mlir.undef : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
+  %6 = llvm.insertvalue %arg0, %5[0] : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
+  %7 = llvm.insertvalue %4, %6[1] : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
+  %8 = llvm.insertvalue %3, %7[2] : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>
+  acc.data copy(%8 : !llvm.struct<"openacc_data", (struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>, ptr<f32>, i64)>) copyout(%arg1 : !llvm.ptr<f32>) {
+    %9 = llvm.mlir.constant(2 : i32) : i32
+    llvm.store %9, %arg2 : !llvm.ptr<i32>
+    acc.terminator
+  }
+  llvm.return
+}
+
+// CHECK: %struct.ident_t = type { i32, i32, i32, i32, i8* }
+// CHECK: [[LOCSTR:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};testdataop;{{[0-9]*}};{{[0-9]*}};;\00", align 1
+// CHECK: [[LOCGLOBAL:@.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[LOCSTR]], i32 0, i32 0) }, align 8
+// CHECK: [[MAPNAME1:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;{{[0-9]*}};{{[0-9]*}};;\00", align 1
+// CHECK: [[MAPNAME2:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i8] c";{{.*}};unknown;{{[0-9]*}};{{[0-9]*}};;\00", align 1
+// CHECK: [[MAPTYPES:@.*]] = private unnamed_addr constant [{{[0-9]*}} x i64] [i64 3, i64 2]
+// CHECK: [[MAPNAMES:@.*]] = private constant [{{[0-9]*}} x i8*] [i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[MAPNAME1]], i32 0, i32 0), i8* getelementptr inbounds ([{{[0-9]*}} x i8], [{{[0-9]*}} x i8]* [[MAPNAME2]], i32 0, i32 0)]
+
+// CHECK: define void @testdataop({ float*, float*, i64, [1 x i64], [1 x i64] } %{{.*}}, float* [[SIMPLEPTR:%.*]], i32* %{{.*}})
+// CHECK: [[ARGBASE_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i8*], align 8
+// CHECK: [[ARG_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i8*], align 8
+// CHECK: [[SIZE_ALLOCA:%.*]] = alloca [{{[0-9]*}} x i64], align 8
+
+// CHECK: [[ARGBASE:%.*]] = extractvalue %openacc_data %{{.*}}, 0
+// CHECK: [[ARG:%.*]] = extractvalue %openacc_data %{{.*}}, 1
+// CHECK: [[ARGSIZE:%.*]] = extractvalue %openacc_data %{{.*}}, 2
+// CHECK: [[ARGBASEGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 0
+// CHECK: [[ARGBASEGEPCAST:%.*]] = bitcast i8** [[ARGBASEGEP]] to { float*, float*, i64, [1 x i64], [1 x i64] }*
+// CHECK: store { float*, float*, i64, [1 x i64], [1 x i64] } [[ARGBASE]], { float*, float*, i64, [1 x i64], [1 x i64] }* [[ARGBASEGEPCAST]], align 8
+// CHECK: [[ARGGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 0
+// CHECK: [[ARGGEPCAST:%.*]] = bitcast i8** [[ARGGEP]] to float**
+// CHECK: store float* [[ARG]], float** [[ARGGEPCAST]], align 8
+// CHECK: [[SIZEGEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 0
+// CHECK: store i64 [[ARGSIZE]], i64* [[SIZEGEP]], align 4
+
+// CHECK: [[ARGBASEGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 1
+// CHECK: [[ARGBASEGEPCAST:%.*]] = bitcast i8** [[ARGBASEGEP]] to float**
+// CHECK: store float* [[SIMPLEPTR]], float** [[ARGBASEGEPCAST]], align 8
+// CHECK: [[ARGGEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 1
+// CHECK: [[ARGGEPCAST:%.*]] = bitcast i8** [[ARGGEP]] to float**
+// CHECK: store float* [[SIMPLEPTR]], float** [[ARGGEPCAST]], align 8
+// CHECK: [[SIZEGEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 1
+// CHECK: store i64 ptrtoint (i1** getelementptr (i1*, i1** null, i32 1) to i64), i64* [[SIZEGEP]], align 4
+
+// CHECK: [[ARGBASE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 0
+// CHECK: [[ARG_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 0
+// CHECK: [[SIZE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 0
+// CHECK: call void @__tgt_target_data_begin_mapper(%struct.ident_t* [[LOCGLOBAL]], i64 -1, i32 2, i8** [[ARGBASE_ALLOCA_GEP]], i8** [[ARG_ALLOCA_GEP]], i64* [[SIZE_ALLOCA_GEP]], i64* getelementptr inbounds ([{{[0-9]*}} x i64], [{{[0-9]*}} x i64]* [[MAPTYPES]], i32 0, i32 0), i8** getelementptr inbounds ([{{[0-9]*}} x i8*], [{{[0-9]*}} x i8*]* [[MAPNAMES]], i32 0, i32 0), i8** null)
+// CHECK: br label %acc.data
+
+// CHECK:      acc.data:
+// CHECK-NEXT:   store i32 2, i32* %{{.*}}
+// CHECK-NEXT:   br label %acc.end_data
+
+// CHECK: acc.end_data:
+// CHECK:   [[ARGBASE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARGBASE_ALLOCA]], i32 0, i32 0
+// CHECK:   [[ARG_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[ARG_ALLOCA]], i32 0, i32 0
+// CHECK:   [[SIZE_ALLOCA_GEP:%.*]] = getelementptr inbounds [2 x i64], [2 x i64]* [[SIZE_ALLOCA]], i32 0, i32 0
+// CHECK:   call void @__tgt_target_data_end_mapper(%struct.ident_t* [[LOCGLOBAL]], i64 -1, i32 2, i8** [[ARGBASE_ALLOCA_GEP]], i8** [[ARG_ALLOCA_GEP]], i64* [[SIZE_ALLOCA_GEP]], i64* getelementptr inbounds ([{{[0-9]*}} x i64], [{{[0-9]*}} x i64]* [[MAPTYPES]], i32 0, i32 0), i8** getelementptr inbounds ([{{[0-9]*}} x i8*], [{{[0-9]*}} x i8*]* [[MAPNAMES]], i32 0, i32 0), i8** null)
+
+// CHECK: declare void @__tgt_target_data_begin_mapper(%struct.ident_t*, i64, i32, i8**, i8**, i64*, i64*, i8**, i8**)
+// CHECK: declare void @__tgt_target_data_end_mapper(%struct.ident_t*, i64, i32, i8**, i8**, i64*, i64*, i8**, i8**)


        


More information about the Mlir-commits mailing list