[Mlir-commits] [mlir] [acc] Consistency between acc.loop and acc compute ops (PR #114549)

Razvan Lupusoru llvmlistbot at llvm.org
Fri Nov 1 08:25:43 PDT 2024


https://github.com/razvanlupusoru created https://github.com/llvm/llvm-project/pull/114549

- GangPrivate and GangFirstPrivate renamed to just Private and Firstprivate respectively. This makes compute ops consistent with the loop op (and also with the acc spec wording for the clause).
- Added getBody to all compute ops
- Enabled verification of firstprivate operands against their firstprivate recipes in the acc.parallel and acc.serial verifiers

>From e1fd00394ebdd38791ac15424f1f9a2672268e7b Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Fri, 1 Nov 2024 07:51:15 -0700
Subject: [PATCH] [acc] Consistency between acc.loop and acc compute ops

- GangPrivate and GangFirstPrivate renamed to just Private and Firstprivate
respectively. This makes compute ops consistent with the loop op (and
also with the acc spec wording for the clause).
- Added getBody to all compute ops
- Enabled verification of firstprivate operands against their firstprivate
recipes in the acc.parallel and acc.serial verifiers
---
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 34 ++++++++++++-------
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 20 +++++++----
 .../OpenACC/Transforms/LegalizeDataValues.cpp |  4 +--
 3 files changed, 38 insertions(+), 20 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index d9f38259c0ace0..e305e2fbde5b17 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1114,9 +1114,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
       UnitAttr:$selfAttr,
       Variadic<AnyType>:$reductionOperands,
       OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
-      Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
-      Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$firstprivateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       OptionalAttr<DefaultValueAttr>:$defaultAttr,
@@ -1134,8 +1134,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
       CArg<"mlir::Value", "{}">:$ifCond,
       CArg<"mlir::Value", "{}">:$selfCond,
       CArg<"mlir::ValueRange", "{}">:$reductionOperands,
-      CArg<"mlir::ValueRange", "{}">:$gangPrivateOperands,
-      CArg<"mlir::ValueRange", "{}">:$gangFirstPrivateOperands,
+      CArg<"mlir::ValueRange", "{}">:$privateOperands,
+      CArg<"mlir::ValueRange", "{}">:$firstprivateOperands,
       CArg<"mlir::ValueRange", "{}">:$dataClauseOperands)>];
 
   let extraClassDeclaration = [{
@@ -1145,6 +1145,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
 
+    /// Used to retrieve the block inside the op's region.
+    Block &getBody() { return getRegion().front(); }
+
     /// Return true if the op has the async attribute for the
     /// mlir::acc::DeviceType::None device_type.
     bool hasAsyncOnly();
@@ -1202,15 +1205,15 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
-      | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
-            type($gangFirstPrivateOperands), $firstprivatizations)
+      | `firstprivate` `(` custom<SymOperandList>($firstprivateOperands,
+            type($firstprivateOperands), $firstprivatizations)
         `)`
       | `num_gangs` `(` custom<NumGangs>($numGangs,
             type($numGangs), $numGangsDeviceType, $numGangsSegments) `)`
       | `num_workers` `(` custom<DeviceTypeOperands>($numWorkers,
             type($numWorkers), $numWorkersDeviceType) `)`
       | `private` `(` custom<SymOperandList>(
-            $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
+            $privateOperands, type($privateOperands), $privatizations)
         `)`
       | `vector_length` `(` custom<DeviceTypeOperands>($vectorLength,
             type($vectorLength), $vectorLengthDeviceType) `)`
@@ -1271,9 +1274,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
       UnitAttr:$selfAttr,
       Variadic<AnyType>:$reductionOperands,
       OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
-      Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
-      Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
+      Variadic<OpenACC_PointerLikeTypeInterface>:$firstprivateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       OptionalAttr<DefaultValueAttr>:$defaultAttr,
@@ -1288,6 +1291,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
 
+    /// Used to retrieve the block inside the op's region.
+    Block &getBody() { return getRegion().front(); }
+
     /// Return true if the op has the async attribute for the
     /// mlir::acc::DeviceType::None device_type.
     bool hasAsyncOnly();
@@ -1326,11 +1332,11 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
         `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
-      | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
-            type($gangFirstPrivateOperands), $firstprivatizations)
+      | `firstprivate` `(` custom<SymOperandList>($firstprivateOperands,
+            type($firstprivateOperands), $firstprivatizations)
         `)`
       | `private` `(` custom<SymOperandList>(
-            $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
+            $privateOperands, type($privateOperands), $privatizations)
         `)`
       | `wait` `` custom<WaitClause>($waitOperands, type($waitOperands),
           $waitOperandsDeviceType, $waitOperandsSegments, $hasWaitDevnum,
@@ -1410,6 +1416,9 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
 
+    /// Used to retrieve the block inside the op's region.
+    Block &getBody() { return getRegion().front(); }
+
     /// Return true if the op has the async attribute for the
     /// mlir::acc::DeviceType::None device_type.
     bool hasAsyncOnly();
@@ -1824,6 +1833,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
     /// The i-th data operand passed.
     Value getDataOperand(unsigned i);
 
+    /// Used to retrieve the block inside the op's region.
     Block &getBody() { return getLoopRegions().front()->front(); }
 
     /// Return true if the op has the auto attribute for the
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 919a0853fb6049..280260e0485bb5 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -730,8 +730,8 @@ checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
 }
 
 unsigned ParallelOp::getNumDataOperands() {
-  return getReductionOperands().size() + getGangPrivateOperands().size() +
-         getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
+  return getReductionOperands().size() + getPrivateOperands().size() +
+         getFirstprivateOperands().size() + getDataClauseOperands().size();
 }
 
 Value ParallelOp::getDataOperand(unsigned i) {
@@ -783,9 +783,13 @@ static LogicalResult verifyDeviceTypeAndSegmentCountMatch(
 
 LogicalResult acc::ParallelOp::verify() {
   if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
-          *this, getPrivatizations(), getGangPrivateOperands(), "private",
+          *this, getPrivatizations(), getPrivateOperands(), "private",
           "privatizations", /*checkOperandType=*/false)))
     return failure();
+  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
+          *this, getFirstprivatizations(), getFirstprivateOperands(),
+          "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+    return failure();
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
           "reductions", false)))
@@ -1361,8 +1365,8 @@ printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
 //===----------------------------------------------------------------------===//
 
 unsigned SerialOp::getNumDataOperands() {
-  return getReductionOperands().size() + getGangPrivateOperands().size() +
-         getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
+  return getReductionOperands().size() + getPrivateOperands().size() +
+         getFirstprivateOperands().size() + getDataClauseOperands().size();
 }
 
 Value SerialOp::getDataOperand(unsigned i) {
@@ -1420,9 +1424,13 @@ mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
 
 LogicalResult acc::SerialOp::verify() {
   if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
-          *this, getPrivatizations(), getGangPrivateOperands(), "private",
+          *this, getPrivatizations(), getPrivateOperands(), "private",
           "privatizations", /*checkOperandType=*/false)))
     return failure();
+  if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
+          *this, getFirstprivatizations(), getFirstprivateOperands(),
+          "firstprivate", "firstprivatizations", /*checkOperandType=*/false)))
+    return failure();
   if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
           "reductions", false)))
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
index 4038e333adb8b6..026b309ce4969d 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/LegalizeDataValues.cpp
@@ -83,8 +83,8 @@ static void collectAndReplaceInRegion(Op &op, bool hostToDevice) {
                   !std::is_same_v<Op, acc::DataOp> &&
                   !std::is_same_v<Op, acc::DeclareOp>) {
       collectPtrs(op.getReductionOperands(), values, hostToDevice);
-      collectPtrs(op.getGangPrivateOperands(), values, hostToDevice);
-      collectPtrs(op.getGangFirstPrivateOperands(), values, hostToDevice);
+      collectPtrs(op.getPrivateOperands(), values, hostToDevice);
+      collectPtrs(op.getFirstprivateOperands(), values, hostToDevice);
     }
   }
 



More information about the Mlir-commits mailing list