[Mlir-commits] [mlir] [OpenMP][MLIR] Add `private` clause to `omp.target` (PR #91202)

Kareem Ergawy llvmlistbot at llvm.org
Wed May 8 00:21:43 PDT 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/91202

>From d53d96bc8e110e26172dce0eca6dd7ebbae3d9a1 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Mon, 6 May 2024 08:12:55 -0500
Subject: [PATCH 1/3] [OpenMP][MLIR] Add `private` clause to `omp.target`

Starts the effort to support delayed privatization for `omp.target`.
This PR extends the `omp.target` MLIR op with a `private` clause, similar
to what we currently have for `omp.parallel`, in order to model
privatized variables.
---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  6 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 55 +++++++++++++++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         |  2 +-
 mlir/test/Dialect/OpenMP/ops.mlir             | 40 +++++++++++++-
 4 files changed, 96 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a40676d071e62..a641588eaa8d4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1787,7 +1787,10 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
                        UnitAttr:$nowait,
                        Variadic<OpenMP_PointerLikeType>:$is_device_ptr,
                        Variadic<OpenMP_PointerLikeType>:$has_device_addr,
-                       Variadic<AnyType>:$map_operands);
+                       Variadic<AnyType>:$map_operands,
+                       Variadic<AnyType>:$private_vars,
+                       OptionalAttr<SymbolRefArrayAttr>:$privatizers);
+
   let regions = (region AnyRegion:$region);
 
   let builders = [
@@ -1802,6 +1805,7 @@ def TargetOp : OpenMP_Op<"target", [IsolatedFromAbove, MapClauseOwningOpInterfac
     | `is_device_ptr` `(` $is_device_ptr `:` type($is_device_ptr) `)`
     | `has_device_addr` `(` $has_device_addr `:` type($has_device_addr) `)`
     | `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
+    | `private` `(` custom<PrivateList>($private_vars, type($private_vars), $privatizers) `)`
     | `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
     ) $region attr-dict
   }];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0799090cdea98..cedcc40864d66 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,14 +469,18 @@ ParseResult parseClauseWithRegionArgs(
 static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
                                       ValueRange argsSubrange,
                                       StringRef clauseName, ValueRange operands,
-                                      TypeRange types, ArrayAttr symbols) {
-  p << clauseName << "(";
+                                      TypeRange types, ArrayAttr symbols,
+                                      bool printPrefixSuffix = true) {
+  if (printPrefixSuffix)
+    p << clauseName << "(";
+
   llvm::interleaveComma(
       llvm::zip_equal(symbols, operands, argsSubrange, types), p, [&p](auto t) {
         auto [sym, op, arg, type] = t;
         p << sym << " " << op << " -> " << arg << " : " << type;
       });
-  p << ") ";
+  if (printPrefixSuffix)
+    p << ") ";
 }
 
 static ParseResult parseParallelRegion(
@@ -1048,6 +1052,48 @@ static void printMapEntries(OpAsmPrinter &p, Operation *op,
   }
 }
 
+static ParseResult parsePrivateList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
+    SmallVectorImpl<Type> &privateOperandTypes, ArrayAttr &privatizerSymbols) {
+  SmallVector<SymbolRefAttr> privateSymRefs;
+  SmallVector<OpAsmParser::Argument> regionPrivateArgs;
+
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(privateSymRefs.emplace_back()) ||
+            parser.parseOperand(privateOperands.emplace_back()) ||
+            parser.parseArrow() ||
+            parser.parseArgument(regionPrivateArgs.emplace_back()) ||
+            parser.parseColonType(privateOperandTypes.emplace_back()))
+          return failure();
+        return success();
+      })))
+    return failure();
+
+  SmallVector<Attribute> privateSymAttrs(privateSymRefs.begin(),
+                                         privateSymRefs.end());
+  privatizerSymbols = ArrayAttr::get(parser.getContext(), privateSymAttrs);
+
+  return success();
+}
+
+static void printPrivateList(OpAsmPrinter &p, Operation *op,
+                             ValueRange privateVarOperands,
+                             TypeRange privateVarTypes,
+                             ArrayAttr privatizerSymbols) {
+  auto targetOp = mlir::dyn_cast<mlir::omp::TargetOp>(op);
+  assert(targetOp);
+
+  auto &region = op->getRegion(0);
+  auto *argsBegin = region.front().getArguments().begin();
+  MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
+                               argsBegin + targetOp.getMapOperands().size() +
+                                   privateVarTypes.size());
+  printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
+                            privateVarTypes, privatizerSymbols,
+                            /*printPrefixSuffix=*/false);
+}
+
 static void printCaptureType(OpAsmPrinter &p, Operation *op,
                              VariableCaptureKindAttr mapCaptureType) {
   std::string typeCapStr;
@@ -1262,7 +1308,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
       builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
       makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
       clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
-      clauses.mapVars);
+      clauses.mapVars, clauses.privateVars,
+      ArrayAttr::get(builder.getContext(), clauses.privatizers));
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 511e7d396c687..138c2c9d418dc 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2087,7 +2087,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 60fc10f9d64b7..f0b76c117a456 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0, 0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2550,3 +2550,41 @@ func.func @parallel_op_reduction_and_private(%priv_var: !llvm.ptr, %priv_var2: !
   }
   return
 }
+
+// CHECK-LABEL: omp_target_private
+func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
+  %mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
+
+  // CHECK: omp.target
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  // CHECK: omp.target
+
+  // CHECK-SAME: map_entries(
+  // CHECK-SAME:   %[[MAP1_VAR:[^[:space:]]+]] -> %[[MAP1_ARG:[^[:space:]]+]],
+  // CHECK-SAME:   %[[MAP2_VAR:[^[:space:]]+]] -> %[[MAP2_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : memref<?xi32>, memref<?xi32>
+  // CHECK-SAME: )
+
+  // CHECK-SAME: private(
+  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   : !llvm.ptr
+  // CHECK-SAME: )
+  omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
+  // CHECK: ^bb0(%[[MAP1_ARG]]: memref<?xi32>, %[[MAP2_ARG]]: memref<?xi32>
+  // CHECK-SAME: , %[[PRIV_ARG]]: !llvm.ptr):
+  ^bb0(%arg0: memref<?xi32>, %arg1: memref<?xi32>, %priv_arg: !llvm.ptr):
+    omp.terminator
+  }
+
+  return
+}

>From d6f265b626969f51efc5f596ab8f82bbf0113674 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Tue, 7 May 2024 03:35:55 -0500
Subject: [PATCH 2/3] review comments

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cedcc40864d66..432a0b0600e03 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -469,9 +469,8 @@ ParseResult parseClauseWithRegionArgs(
 static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
                                       ValueRange argsSubrange,
                                       StringRef clauseName, ValueRange operands,
-                                      TypeRange types, ArrayAttr symbols,
-                                      bool printPrefixSuffix = true) {
-  if (printPrefixSuffix)
+                                      TypeRange types, ArrayAttr symbols) {
+  if (!clauseName.empty())
     p << clauseName << "(";
 
   llvm::interleaveComma(
@@ -479,7 +478,8 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, Operation *op,
         auto [sym, op, arg, type] = t;
         p << sym << " " << op << " -> " << arg << " : " << type;
       });
-  if (printPrefixSuffix)
+
+  if (!clauseName.empty())
     p << ") ";
 }
 
@@ -1089,9 +1089,9 @@ static void printPrivateList(OpAsmPrinter &p, Operation *op,
   MutableArrayRef argsSubrange(argsBegin + targetOp.getMapOperands().size(),
                                argsBegin + targetOp.getMapOperands().size() +
                                    privateVarTypes.size());
-  printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands,
-                            privateVarTypes, privatizerSymbols,
-                            /*printPrefixSuffix=*/false);
+  printClauseWithRegionArgs(
+      p, op, argsSubrange, /*clauseName=*/llvm::StringRef{}, privateVarOperands,
+      privateVarTypes, privatizerSymbols);
 }
 
 static void printCaptureType(OpAsmPrinter &p, Operation *op,
@@ -1302,14 +1302,14 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                      const TargetClauseOps &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
-  // inReductionDeclSymbols, privateVars, privatizers, reductionVars,
-  // reductionByRefAttr, reductionDeclSymbols.
+  // inReductionDeclSymbols, reductionVars, reductionByRefAttr,
+  // reductionDeclSymbols.
   TargetOp::build(
       builder, state, clauses.ifVar, clauses.deviceVar, clauses.threadLimitVar,
       makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
       clauses.nowaitAttr, clauses.isDevicePtrVars, clauses.hasDeviceAddrVars,
       clauses.mapVars, clauses.privateVars,
-      ArrayAttr::get(builder.getContext(), clauses.privatizers));
+      makeArrayAttr(ctx, clauses.privatizers));
 }
 
 LogicalResult TargetOp::verify() {

>From 42e78cf7228aeced33d94b1d603ef955bb439bf3 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Wed, 8 May 2024 02:21:29 -0500
Subject: [PATCH 3/3] review comments, Pranav

---
 mlir/test/Dialect/OpenMP/ops.mlir | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index f0b76c117a456..828c9d2c3b84f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -737,7 +737,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%if_cond, %device, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0, 0>} : ( i1, si32, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>} : ( i1, si32, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2558,7 +2558,7 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
 
   // CHECK: omp.target
   // CHECK-SAME: private(
-  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
   // CHECK-SAME:   : !llvm.ptr
   // CHECK-SAME: )
   omp.target private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {
@@ -2570,13 +2570,13 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
   // CHECK: omp.target
 
   // CHECK-SAME: map_entries(
-  // CHECK-SAME:   %[[MAP1_VAR:[^[:space:]]+]] -> %[[MAP1_ARG:[^[:space:]]+]],
-  // CHECK-SAME:   %[[MAP2_VAR:[^[:space:]]+]] -> %[[MAP2_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]],
+  // CHECK-SAME:   %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]]
   // CHECK-SAME:   : memref<?xi32>, memref<?xi32>
   // CHECK-SAME: )
 
   // CHECK-SAME: private(
-  // CHECK-SAME:   @x.privatizer %[[PRIV_VAR:[^[:space:]]+]] -> %[[PRIV_ARG:[^[:space:]]+]]
+  // CHECK-SAME:   @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]]
   // CHECK-SAME:   : !llvm.ptr
   // CHECK-SAME: )
   omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg : !llvm.ptr) {



More information about the Mlir-commits mailing list