[Mlir-commits] [mlir] [OpenMP][MLIR] Add "IsolatedFromAbove" trait to omp.target (PR #67318)

Akash Banerjee llvmlistbot at llvm.org
Mon Nov 6 04:33:46 PST 2023


https://github.com/TIFitis updated https://github.com/llvm/llvm-project/pull/67318

>From da0c78f022cf5993565da484eeaaa39a001062b9 Mon Sep 17 00:00:00 2001
From: Akash Banerjee <Akash.Banerjee at amd.com>
Date: Fri, 22 Sep 2023 18:13:46 +0100
Subject: [PATCH] [OpenMP][MLIR] Add "IsolatedFromAbove" trait to omp.target

This patch adds the MLIR translation changes required to add the IsolatedFromAbove and OutlineableOpenMPOpInterface traits to omp.target. It links the newly added block arguments to their corresponding LLVM values.

Depends on #67164.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 51 +++++++++++++++----
 .../OpenMPToLLVM/convert-to-llvmir.mlir       | 20 +++++---
 mlir/test/Dialect/OpenMP/canonicalize.mlir    |  5 +-
 mlir/test/Dialect/OpenMP/invalid.mlir         | 40 ++++++++++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             | 21 +++++---
 .../omptarget-array-sectioning-host.mlir      | 15 +++---
 ...target-byref-bycopy-generation-device.mlir |  9 ++--
 ...mptarget-byref-bycopy-generation-host.mlir |  7 +--
 .../omptarget-declare-target-llvm-device.mlir | 13 ++---
 .../LLVMIR/omptarget-region-device-llvm.mlir  |  9 ++--
 .../omptarget-region-llvm-target-device.mlir  |  9 ++--
 .../Target/LLVMIR/omptarget-region-llvm.mlir  |  9 ++--
 .../omptarget-region-parallel-llvm.mlir       |  9 ++--
 13 files changed, 156 insertions(+), 61 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1daf60b8659bb66..8e741bf83be32f0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1733,12 +1733,13 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
            "missing map info operation or incorrect map info operation type");
     if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
             mapValue.getDefiningOp())) {
-      mapData.OriginalValue.push_back(
-          moduleTranslation.lookupValue(mapOp.getVarPtr()));
+      mapData.OriginalValue.push_back(moduleTranslation.lookupValue(
+          mapOp.getVarPtr() ? mapOp.getVarPtr() : mapOp.getVal()));
       mapData.Pointers.push_back(mapData.OriginalValue.back());
 
       if (llvm::Value *refPtr = getRefPtrIfDeclareTarget(
-              mapOp.getVarPtr(), moduleTranslation)) { // declare target
+              mapOp.getVarPtr() ? mapOp.getVarPtr() : mapOp.getVal(),
+              moduleTranslation)) { // declare target
         mapData.IsDeclareTarget.push_back(true);
         mapData.BasePointers.push_back(refPtr);
       } else { // regular mapped variable
@@ -1746,10 +1747,14 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
         mapData.BasePointers.push_back(mapData.OriginalValue.back());
       }
 
-      mapData.Sizes.push_back(getSizeInBytes(dl, mapOp.getVarType(), mapOp,
-                                             builder, moduleTranslation));
-      mapData.BaseType.push_back(
-          moduleTranslation.convertType(mapOp.getVarType()));
+      mapData.Sizes.push_back(
+          getSizeInBytes(dl,
+                         mapOp.getVal() ? mapOp.getVal().getType()
+                                        : mapOp.getVarType().value(),
+                         mapOp, builder, moduleTranslation));
+      mapData.BaseType.push_back(moduleTranslation.convertType(
+          mapOp.getVal() ? mapOp.getVal().getType()
+                         : mapOp.getVarType().value()));
       mapData.MapClause.push_back(mapOp.getOperation());
       mapData.Types.push_back(
           llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType().value()));
@@ -1796,6 +1801,13 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
     else if (isTargetParams)
       mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
 
+    if (auto mapInfoOp = dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]))
+      if (mapInfoOp.getMapCaptureType().value() ==
+              mlir::omp::VariableCaptureKind::ByCopy &&
+          !(mapInfoOp.getVarType().has_value() &&
+            mapInfoOp.getVarType()->isa<LLVM::LLVMPointerType>()))
+        mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
+
     combinedInfo.BasePointers.emplace_back(mapData.BasePointers[i]);
     combinedInfo.Pointers.emplace_back(mapData.Pointers[i]);
     combinedInfo.DevicePointers.emplace_back(mapData.DevicePointers[i]);
@@ -2318,6 +2330,19 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   auto targetOp = cast<omp::TargetOp>(opInst);
   auto &targetRegion = targetOp.getRegion();
+  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
+  SmallVector<Value> mapOperands = targetOp.getMapOperands();
+
+  // Remove mapOperands/blockArgs that have no use inside the region.
+  assert(mapOperands.size() == targetRegion.getNumArguments() &&
+         "Number of mapOperands must be same as block_arguments");
+  for (size_t i = 0; i < mapOperands.size(); i++) {
+    if (targetRegion.getArgument(i).use_empty()) {
+      targetRegion.eraseArgument(i);
+      mapOperands.erase(&mapOperands[i]);
+      i--;
+    }
+  }
 
   LogicalResult bodyGenStatus = success();
 
@@ -2325,6 +2350,16 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   auto bodyCB = [&](InsertPointTy allocaIP,
                     InsertPointTy codeGenIP) -> InsertPointTy {
     builder.restoreIP(codeGenIP);
+    unsigned argIndex = 0;
+    for (auto &mapOp : mapOperands) {
+      auto mapInfoOp =
+          mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp());
+      llvm::Value *mapOpValue = moduleTranslation.lookupValue(
+          mapInfoOp.getVarPtr() ? mapInfoOp.getVarPtr() : mapInfoOp.getVal());
+      const auto &arg = targetRegion.front().getArgument(argIndex);
+      moduleTranslation.mapValue(arg, mapOpValue);
+      argIndex++;
+    }
     llvm::BasicBlock *exitBlock = convertOmpOpRegions(
         targetRegion, "omp.target", builder, moduleTranslation, bodyGenStatus);
     builder.SetInsertPoint(exitBlock);
@@ -2352,8 +2387,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
 
-  DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
-  llvm::SmallVector<Value> mapOperands = targetOp.getMapOperands();
   MapInfoData mapData;
   collectMapDataFromMapOperands(mapData, mapOperands, moduleTranslation, dl,
                                 builder);
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index bbf50617edf9448..cf267740418f856 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -244,10 +244,12 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr, %i : !llvm.ptr) {
 // CHECK:                             %[[ARG_0:.*]]: !llvm.ptr,
 // CHECK:                             %[[ARG_1:.*]]: !llvm.ptr) {
 // CHECK:           %[[VAL_0:.*]] = llvm.mlir.constant(64 : i32) : i32
-// CHECK:           %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.array<1024 x i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-// CHECK:           omp.target   thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP]] : !llvm.ptr) {
+// CHECK:           %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG_0]] : !llvm.ptr, !llvm.array<1024 x i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+// CHECK:           %[[MAP2:.*]] = omp.map_info var_ptr(%[[ARG_1]] : !llvm.ptr, i32)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = ""}
+// CHECK:           omp.target   thread_limit(%[[VAL_0]] : i32) map_entries(%[[MAP1]] -> %[[BB_ARG0:.*]], %[[MAP2]] -> %[[BB_ARG1:.*]] : !llvm.ptr, !llvm.ptr) {
+// CHECK:           ^bb0(%[[BB_ARG0]]: !llvm.ptr, %[[BB_ARG1]]: !llvm.ptr):
 // CHECK:             %[[VAL_1:.*]] = llvm.mlir.constant(10 : i32) : i32
-// CHECK:             llvm.store %[[VAL_1]], %[[ARG_1]] : i32, !llvm.ptr
+// CHECK:             llvm.store %[[VAL_1]], %[[BB_ARG1]] : !llvm.ptr
 // CHECK:             omp.terminator
 // CHECK:           }
 // CHECK:           llvm.return
@@ -256,9 +258,11 @@ llvm.func @_QPomp_target_data_region(%a : !llvm.ptr, %i : !llvm.ptr) {
 llvm.func @_QPomp_target(%a : !llvm.ptr, %i : !llvm.ptr) {
   %0 = llvm.mlir.constant(64 : i32) : i32
   %1 = omp.map_info var_ptr(%a : !llvm.ptr, !llvm.array<1024 x i32>)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-  omp.target   thread_limit(%0 : i32) map_entries(%1 : !llvm.ptr) {
+  %3 = omp.map_info var_ptr(%i : !llvm.ptr, i32)   map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = ""}
+  omp.target   thread_limit(%0 : i32) map_entries(%1 -> %arg0, %3 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
     %2 = llvm.mlir.constant(10 : i32) : i32
-    llvm.store %2, %i : i32, !llvm.ptr
+    llvm.store %2, %arg1 : !llvm.ptr
     omp.terminator
   }
   llvm.return
@@ -449,7 +453,8 @@ llvm.func @sub_() {
 // CHECK: %[[C_14:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK: %[[BOUNDS1:.*]] = omp.bounds   lower_bound(%[[C_12]] : i64) upper_bound(%[[C_11]] : i64) stride(%[[C_14]] : i64) start_idx(%[[C_14]] : i64)
 // CHECK: %[[MAP1:.*]] = omp.map_info var_ptr(%[[ARG_2]] : !llvm.ptr, !llvm.array<10 x i32>)   map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !llvm.ptr {name = ""}
-// CHECK: omp.target   map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr, !llvm.ptr) {
+// CHECK: omp.target   map_entries(%[[MAP0]] -> %[[BB_ARG0:.*]], %[[MAP1]]  -> %[[BB_ARG1:.*]] : !llvm.ptr, !llvm.ptr) {
+// CHECK: ^bb0(%[[BB_ARG0]]: !llvm.ptr, %[[BB_ARG1]]: !llvm.ptr):
 // CHECK:   omp.terminator
 // CHECK: }
 // CHECK: llvm.return
@@ -468,7 +473,8 @@ llvm.func @_QPtarget_map_with_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2:
   %9 = llvm.mlir.constant(1 : index) : i64
   %10 = omp.bounds   lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%9 : i64) start_idx(%9 : i64)
   %11 = omp.map_info var_ptr(%arg2 : !llvm.ptr, !llvm.array<10 x i32>)   map_clauses(tofrom) capture(ByRef) bounds(%10) -> !llvm.ptr {name = ""}
-  omp.target   map_entries(%5, %11 : !llvm.ptr, !llvm.ptr) {
+  omp.target   map_entries(%5 -> %arg3, %11 -> %arg4: !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg3: !llvm.ptr, %arg4: !llvm.ptr):
     omp.terminator
   }
   llvm.return
diff --git a/mlir/test/Dialect/OpenMP/canonicalize.mlir b/mlir/test/Dialect/OpenMP/canonicalize.mlir
index 8aff8f81188be50..de6c931ecc5fd92 100644
--- a/mlir/test/Dialect/OpenMP/canonicalize.mlir
+++ b/mlir/test/Dialect/OpenMP/canonicalize.mlir
@@ -131,8 +131,9 @@ func.func private @foo() -> ()
 
 func.func @constant_hoisting_target(%x : !llvm.ptr) {
   omp.target {
+    ^bb0(%arg0: !llvm.ptr):
     %c1 = arith.constant 10 : i32
-    llvm.store %c1, %x : i32, !llvm.ptr
+    llvm.store %c1, %arg0 : i32, !llvm.ptr
     omp.terminator
   }
   return
@@ -141,4 +142,4 @@ func.func @constant_hoisting_target(%x : !llvm.ptr) {
 // CHECK-LABEL: func.func @constant_hoisting_target
 // CHECK-NOT: arith.constant
 // CHECK: omp.target
-// CHECK-NEXT: arith.constant
+// CHECK: arith.constant
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 6f75f2a62e64136..42e9fb1c64baec7 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1617,7 +1617,9 @@ func.func @omp_threadprivate() {
 func.func @omp_target(%map1: memref<?xi32>) {
   %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(delete) capture(ByRef) -> memref<?xi32> {name = ""}
   // expected-error @below {{to, from, tofrom and alloc map types are permitted}}
-  omp.target map_entries(%mapv : memref<?xi32>){}
+  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>):
+  }
   return
 }
 
@@ -1656,4 +1658,40 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
   return
 }
 
+// -----
+
+func.func @omp_map1(%map1: memref<?xi32>, %map2: i32) {
+  %mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) val(%map2 : i32)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  // expected-error @below {{only one of val or var_ptr must be used}}
+  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>):
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_map2(%map1: memref<?xi32>, %map2: i32) {
+  %mapv = omp.map_info var_ptr( : , tensor<?xi32>) val(%map2 : i32)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  // expected-error @below {{var_type supplied without var_ptr}}
+  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>):
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_map3(%map1: memref<?xi32>, %map2: i32) {
+  %mapv = omp.map_info   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  // expected-error @below {{missing val or var_ptr}}
+  omp.target map_entries(%mapv -> %arg0: memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>):
+    omp.terminator
+  }
+  return
+}
+
 llvm.mlir.global internal @_QFsubEx() : i32
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d59a4f428118bf3..4d88d9ac86fe16c 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -492,16 +492,22 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %map1:
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
     // CHECK: %[[MAP_B:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target map_entries(%[[MAP_A]], %[[MAP_B]] : memref<?xi32>, memref<?xi32>) {
+    // CHECK: omp.target map_entries(%[[MAP_A]] -> {{.*}}, %[[MAP_B]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
     %mapv1 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
     %mapv2 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>)   map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target map_entries(%mapv1, %mapv2 : memref<?xi32>, memref<?xi32>){}
+    omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
+      omp.terminator
+    }
     // CHECK: %[[MAP_C:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
     // CHECK: %[[MAP_D:.*]] = omp.map_info var_ptr(%[[VAL_2:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
-    // CHECK: omp.target map_entries(%[[MAP_C]], %[[MAP_D]] : memref<?xi32>, memref<?xi32>) {
+    // CHECK: omp.target map_entries(%[[MAP_C]] -> {{.*}}, %[[MAP_D]] -> {{.*}} : memref<?xi32>, memref<?xi32>) {
     %mapv3 = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>)   map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
     %mapv4 = omp.map_info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>)   map_clauses(always, from) capture(ByRef) -> memref<?xi32> {name = ""}
-    omp.target map_entries(%mapv3, %mapv4 : memref<?xi32>, memref<?xi32>) {}
+    omp.target map_entries(%mapv3 -> %arg0, %mapv4 -> %arg1 : memref<?xi32>, memref<?xi32>) {
+    ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>):
+      omp.terminator
+    }
     // CHECK: omp.barrier
     omp.barrier
 
@@ -2055,8 +2061,11 @@ func.func @omp_targets_with_map_bounds(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
     %10 = omp.bounds   lower_bound(%7 : i64) upper_bound(%6 : i64) stride(%8 : i64) start_idx(%9 : i64)
     %mapv2 = omp.map_info var_ptr(%arg1 : !llvm.ptr, !llvm.array<10 x i32>)   map_clauses(exit_release_or_enter_alloc) capture(ByCopy) bounds(%10) -> !llvm.ptr {name = ""}
 
-    // CHECK: omp.target map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr, !llvm.ptr)
-    omp.target map_entries(%mapv1, %mapv2 : !llvm.ptr, !llvm.ptr){}
+    // CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
+    omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+      omp.terminator
+    }
 
     // CHECK: omp.target_data map_entries(%[[MAP0]], %[[MAP1]] : !llvm.ptr, !llvm.ptr)
     omp.target_data map_entries(%mapv1, %mapv2 : !llvm.ptr, !llvm.ptr){}
diff --git a/mlir/test/Target/LLVMIR/omptarget-array-sectioning-host.mlir b/mlir/test/Target/LLVMIR/omptarget-array-sectioning-host.mlir
index 056085123480baf..307d8a02ce61daf 100644
--- a/mlir/test/Target/LLVMIR/omptarget-array-sectioning-host.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-array-sectioning-host.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 // This test checks the offload sizes provided to the OpenMP kernel argument
-// structure are correct when lowering to LLVM-IR from MLIR with 3-D bounds 
-// provided for a 3-D array. One with full default size, and the other with 
-// a user specified OpenMP array sectioning. We expect the default sized 
-// array bounds to lower to the full size of the array and the sectioned 
+// structure are correct when lowering to LLVM-IR from MLIR with 3-D bounds
+// provided for a 3-D array. One with full default size, and the other with
+// a user specified OpenMP array sectioning. We expect the default sized
+// array bounds to lower to the full size of the array and the sectioned
 // array to be the size of 3*3*1*element-byte-size (36 bytes in this case).
 
 module attributes {omp.is_target_device = false} {
@@ -18,12 +18,13 @@ module attributes {omp.is_target_device = false} {
     %6 = omp.bounds   lower_bound(%2 : i64) upper_bound(%2 : i64) stride(%2 : i64) start_idx(%2 : i64)
     %7 = omp.map_info var_ptr(%0 : !llvm.ptr, !llvm.array<3 x array<3 x array<3 x i32>>>)   map_clauses(tofrom) capture(ByRef) bounds(%5, %5, %6) -> !llvm.ptr {name = "inarray(1:3,1:3,2:2)"}
     %8 = omp.map_info var_ptr(%1 : !llvm.ptr, !llvm.array<3 x array<3 x array<3 x i32>>>)   map_clauses(tofrom) capture(ByRef) bounds(%5, %5, %5) -> !llvm.ptr {name = "outarray(1:3,1:3,1:3)"}
-    omp.target   map_entries(%7, %8 : !llvm.ptr, !llvm.ptr) {
+    omp.target   map_entries(%7 -> %arg0, %8 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+      ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
       %9 = llvm.mlir.constant(0 : i64) : i64
       %10 = llvm.mlir.constant(1 : i64) : i64
-      %11 = llvm.getelementptr %0[0, %10, %9, %9] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr, !llvm.array<3 x array<3 x array<3 x i32>>>
+      %11 = llvm.getelementptr %arg0[0, %10, %9, %9] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr, !llvm.array<3 x array<3 x array<3 x i32>>>
       %12 = llvm.load %11 : !llvm.ptr -> i32
-      %13 = llvm.getelementptr %1[0, %10, %9, %9] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr, !llvm.array<3 x array<3 x array<3 x i32>>>
+      %13 = llvm.getelementptr %arg1[0, %10, %9, %9] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr, !llvm.array<3 x array<3 x array<3 x i32>>>
       llvm.store %12, %13 : i32, !llvm.ptr
       omp.terminator
     }
diff --git a/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir b/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir
index c0c8640bb30bdad..875d04f584ca96f 100644
--- a/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-device.mlir
@@ -6,9 +6,10 @@ module attributes {omp.is_target_device = true} {
     %1 = llvm.mlir.addressof @_QFEsp : !llvm.ptr
     %2 = omp.map_info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "sp"}
     %3 = omp.map_info var_ptr(%0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr {name = "i"}
-    omp.target map_entries(%2, %3 : !llvm.ptr, !llvm.ptr) {
-      %4 = llvm.load %0 : !llvm.ptr -> i32
-      llvm.store %4, %1 : i32, !llvm.ptr
+    omp.target map_entries(%2 -> %arg0, %3 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+      ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+      %4 = llvm.load %arg1 : !llvm.ptr -> i32
+      llvm.store %4, %arg0 : i32, !llvm.ptr
       omp.terminator
     }
     llvm.return
@@ -32,7 +33,7 @@ module attributes {omp.is_target_device = true} {
 // CHECK: store ptr %[[ARG_BYCOPY]], ptr %[[ALLOCA_BYCOPY]], align 8
 
 // CHECK: user_code.entry:                                  ; preds = %entry
-// CHECK: %[[LOAD_BYREF:.*]] = load ptr, ptr %[[ALLOCA_BYREF]], align 8 
+// CHECK: %[[LOAD_BYREF:.*]] = load ptr, ptr %[[ALLOCA_BYREF]], align 8
 // CHECK: br label %omp.target
 
 // CHECK: omp.target:                                       ; preds = %user_code.entry
diff --git a/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-host.mlir b/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-host.mlir
index ca5dad8b4fc9a85..c8fb4e232f06f50 100644
--- a/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-host.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-byref-bycopy-generation-host.mlir
@@ -6,9 +6,10 @@ module attributes {omp.is_target_device = false} {
     %1 = llvm.mlir.addressof @_QFEsp : !llvm.ptr
     %2 = omp.map_info var_ptr(%1 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "sp"}
     %3 = omp.map_info var_ptr(%0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr {name = "i"}
-    omp.target map_entries(%2, %3 : !llvm.ptr, !llvm.ptr) {
-      %4 = llvm.load %0 : !llvm.ptr -> i32
-      llvm.store %4, %1 : i32, !llvm.ptr
+    omp.target map_entries(%2 -> %arg0, %3 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+      ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+      %4 = llvm.load %arg1 : !llvm.ptr -> i32
+      llvm.store %4, %arg0 : i32, !llvm.ptr
       omp.terminator
     }
     llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir b/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir
index 24795cf70c009e7..cf08761981fb3a8 100644
--- a/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-declare-target-llvm-device.mlir
@@ -1,10 +1,10 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 // This test the generation of additional load operations for declare target link variables
-// inside of target op regions when lowering to IR for device. Unfortunately as the host file is not 
+// inside of target op regions when lowering to IR for device. Unfortunately as the host file is not
 // passed as a module attribute, we miss out on the metadata and entryinfo.
 //
-// Unfortunately, only so much can be tested as the device side is dependent on a *.bc 
+// Unfortunately, only so much can be tested as the device side is dependent on a *.bc
 // file created by the host and appended as an attribute to the module.
 
 module attributes {omp.is_target_device = true} {
@@ -13,18 +13,19 @@ module attributes {omp.is_target_device = true} {
     %0 = llvm.mlir.constant(0 : i32) : i32
     llvm.return %0 : i32
   }
-                                                            
+
   llvm.func @_QQmain() attributes {} {
     %0 = llvm.mlir.addressof @_QMtest_0Esp : !llvm.ptr
-  
+
   // CHECK-DAG:   omp.target:                                       ; preds = %user_code.entry
   // CHECK-DAG: %[[V:.*]] = load ptr, ptr @_QMtest_0Esp_decl_tgt_ref_ptr, align 8
   // CHECK-DAG: store i32 1, ptr %[[V]], align 4
   // CHECK-DAG: br label %omp.region.cont
     %map = omp.map_info var_ptr(%0 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-    omp.target   map_entries(%map : !llvm.ptr) {
+    omp.target   map_entries(%map -> %arg0 : !llvm.ptr) {
+      ^bb0(%arg0: !llvm.ptr):
       %1 = llvm.mlir.constant(1 : i32) : i32
-      llvm.store %1, %0 : i32, !llvm.ptr
+      llvm.store %1, %arg0 : i32, !llvm.ptr
       omp.terminator
     }
 
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index bd399ad935259c0..78bab6ece73e6b0 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -15,11 +15,12 @@ module attributes {omp.is_target_device = true} {
     %map1 = omp.map_info var_ptr(%3 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
     %map2 = omp.map_info var_ptr(%5 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
     %map3 = omp.map_info var_ptr(%7 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-    omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
-      %8 = llvm.load %3 : !llvm.ptr -> i32
-      %9 = llvm.load %5 : !llvm.ptr -> i32
+    omp.target map_entries(%map1 -> %arg0, %map2 -> %arg1, %map3 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr):
+      %8 = llvm.load %arg0 : !llvm.ptr -> i32
+      %9 = llvm.load %arg1 : !llvm.ptr -> i32
       %10 = llvm.add %8, %9  : i32
-      llvm.store %10, %7 : i32, !llvm.ptr
+      llvm.store %10, %arg2 : i32, !llvm.ptr
       omp.terminator
     }
     llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
index 2cd0331087ec01c..6fa039f522e2061 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm-target-device.mlir
@@ -3,10 +3,11 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
 module attributes {omp.is_target_device = true} {
-  llvm.func @writeindex_omp_outline_0_(%arg0: !llvm.ptr, %arg1: !llvm.ptr) attributes {omp.outline_parent_name = "writeindex_"} {
-    %0 = omp.map_info var_ptr(%arg0 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-    %1 = omp.map_info var_ptr(%arg1 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-    omp.target   map_entries(%0, %1 : !llvm.ptr, !llvm.ptr) {
+  llvm.func @writeindex_omp_outline_0_(%val0: !llvm.ptr, %val1: !llvm.ptr) attributes {omp.outline_parent_name = "writeindex_"} {
+    %0 = omp.map_info var_ptr(%val0 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+    %1 = omp.map_info var_ptr(%val1 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
+    omp.target   map_entries(%0 -> %arg0, %1 -> %arg1 : !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
       %2 = llvm.mlir.constant(20 : i32) : i32
       %3 = llvm.mlir.constant(10 : i32) : i32
       llvm.store %3, %arg0 : i32, !llvm.ptr
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
index 4e89b8585c75254..b861dd7a7d315fc 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-llvm.mlir
@@ -15,11 +15,12 @@ module attributes {omp.is_target_device = false} {
     %map1 = omp.map_info var_ptr(%3 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
     %map2 = omp.map_info var_ptr(%5 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
     %map3 = omp.map_info var_ptr(%7 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-    omp.target map_entries(%map1, %map2, %map3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
-      %8 = llvm.load %3 : !llvm.ptr -> i32
-      %9 = llvm.load %5 : !llvm.ptr -> i32
+    omp.target map_entries(%map1 -> %arg0, %map2 -> %arg1, %map3 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr):
+      %8 = llvm.load %arg0 : !llvm.ptr -> i32
+      %9 = llvm.load %arg1 : !llvm.ptr -> i32
       %10 = llvm.add %8, %9  : i32
-      llvm.store %10, %7 : i32, !llvm.ptr
+      llvm.store %10, %arg2 : i32, !llvm.ptr
       omp.terminator
     }
     llvm.return
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
index 1d8799ecd446f03..c80ea1f0a47be75 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-parallel-llvm.mlir
@@ -15,12 +15,13 @@ module attributes {omp.is_target_device = false} {
     %map1 = omp.map_info var_ptr(%3 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
     %map2 = omp.map_info var_ptr(%5 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
     %map3 = omp.map_info var_ptr(%7 : !llvm.ptr, i32)   map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
-    omp.target map_entries( %map1, %map2, %map3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
+    omp.target map_entries( %map1 -> %arg0, %map2 -> %arg1, %map3 -> %arg2 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
+    ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr):
       omp.parallel {
-        %8 = llvm.load %3 : !llvm.ptr -> i32
-        %9 = llvm.load %5 : !llvm.ptr -> i32
+        %8 = llvm.load %arg0 : !llvm.ptr -> i32
+        %9 = llvm.load %arg1 : !llvm.ptr -> i32
         %10 = llvm.add %8, %9  : i32
-        llvm.store %10, %7 : i32, !llvm.ptr
+        llvm.store %10, %arg2 : i32, !llvm.ptr
         omp.terminator
         }
       omp.terminator



More information about the Mlir-commits mailing list