[Mlir-commits] [mlir] f4f53f8 - [mlir][openacc] Use new reduction design in acc.parallel

Valentin Clement llvmlistbot at llvm.org
Wed May 24 10:38:30 PDT 2023


Author: Valentin Clement
Date: 2023-05-24T10:38:25-07:00
New Revision: f4f53f8b90184df8120bc841de49be5707964876

URL: https://github.com/llvm/llvm-project/commit/f4f53f8b90184df8120bc841de49be5707964876
DIFF: https://github.com/llvm/llvm-project/commit/f4f53f8b90184df8120bc841de49be5707964876.diff

LOG: [mlir][openacc] Use new reduction design in acc.parallel

After D150818, the reduction clause is represented
with an acc.reduction.recipe operation and an operand.
This patch updates the acc.parallel op to use the new design.

Reviewed By: razvanlupusoru, jeanPerier

Differential Revision: https://reviews.llvm.org/D151146

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
    mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
    mlir/test/Dialect/OpenACC/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 2edc8bb969af1..fd4e9edb541fc 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -635,8 +635,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
                        Optional<I1>:$ifCond,
                        Optional<I1>:$selfCond,
                        UnitAttr:$selfAttr,
-                       OptionalAttr<OpenACC_ReductionOperatorAttr>:$reductionOp,
                        Variadic<AnyType>:$reductionOperands,
+                       OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
                        Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
                        OptionalAttr<SymbolRefArrayAttr>:$privatizations,
                        Variadic<AnyType>:$gangFirstPrivateOperands,
@@ -661,14 +661,16 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
             type($gangFirstPrivateOperands) `)`
       | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
       | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
-      | `private` `(` custom<PrivatizationList>(
+      | `private` `(` custom<SymOperandList>(
             $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
         `)`
       | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
       | `wait` `(` $waitOperands `:` type($waitOperands) `)`
       | `self` `(` $selfCond `)`
       | `if` `(` $ifCond `)`
-      | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)`
+      | `reduction` `(` custom<SymOperandList>(
+            $reductionOperands, type($reductionOperands), $reductionRecipes)
+        `)`
     )
     $region attr-dict-with-keyword
   }];
@@ -727,7 +729,7 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
       | `async` `(` $async `:` type($async) `)`
       | `firstprivate` `(` $gangFirstPrivateOperands `:`
             type($gangFirstPrivateOperands) `)`
-      | `private` `(` custom<PrivatizationList>(
+      | `private` `(` custom<SymOperandList>(
             $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
         `)`
       | `wait` `(` $waitOperands `:` type($waitOperands) `)`
@@ -1061,7 +1063,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
         `gang` `` custom<GangClause>($gangNum, type($gangNum), $gangStatic, type($gangStatic), $hasGang)
       | `worker` `` custom<WorkerClause>($workerNum, type($workerNum), $hasWorker)
       | `vector` `` custom<VectorClause>($vectorLength, type($vectorLength), $hasVector)
-      | `private` `(` custom<PrivatizationList>(
+      | `private` `(` custom<SymOperandList>(
             $privateOperands, type($privateOperands), $privatizations)
         `)`
       | `tile` `(` $tileOperands `:` type($tileOperands) `)`

diff  --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index ac33c3a84e00b..5e5a00e141a47 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -443,13 +443,13 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
 // Custom parser and printer verifier for private clause
 //===----------------------------------------------------------------------===//
 
-static ParseResult parsePrivatizationList(
+static ParseResult parseSymOperandList(
     mlir::OpAsmParser &parser,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
-    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &privatizationSymbols) {
-  llvm::SmallVector<SymbolRefAttr> privatizationVec;
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
+  llvm::SmallVector<SymbolRefAttr> attributes;
   if (failed(parser.parseCommaSeparatedList([&]() {
-        if (parser.parseAttribute(privatizationVec.emplace_back()) ||
+        if (parser.parseAttribute(attributes.emplace_back()) ||
             parser.parseArrow() ||
             parser.parseOperand(operands.emplace_back()) ||
             parser.parseColonType(types.emplace_back()))
@@ -457,22 +457,21 @@ static ParseResult parsePrivatizationList(
         return success();
       })))
     return failure();
-  llvm::SmallVector<mlir::Attribute> privatizations(privatizationVec.begin(),
-                                                    privatizationVec.end());
-  privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations);
+  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+                                               attributes.end());
+  symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
   return success();
 }
 
-static void
-printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op,
-                       mlir::OperandRange privateOperands,
-                       mlir::TypeRange privateTypes,
-                       std::optional<mlir::ArrayAttr> privatizations) {
-  for (unsigned i = 0, e = privatizations->size(); i < e; ++i) {
+static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                                mlir::OperandRange operands,
+                                mlir::TypeRange types,
+                                std::optional<mlir::ArrayAttr> attributes) {
+  for (unsigned i = 0, e = attributes->size(); i < e; ++i) {
     if (i != 0)
       p << ", ";
-    p << (*privatizations)[i] << " -> " << privateOperands[i] << " : "
-      << privateOperands[i].getType();
+    p << (*attributes)[i] << " -> " << operands[i] << " : "
+      << operands[i].getType();
   }
 }
 
@@ -495,40 +494,47 @@ static LogicalResult checkDataOperands(Op op,
   return success();
 }
 
+/// Verifies a clause given as a list of `symbol -> operand : type` entries:
+/// each operand must be unique and paired with a symbol that resolves to an
+/// `Op` recipe whose type, when set, matches the operand type.
+template <typename Op>
 static LogicalResult
-checkPrivatizationList(Operation *op,
-                       std::optional<mlir::ArrayAttr> privatizations,
-                       mlir::OperandRange privateOperands) {
-  if (!privateOperands.empty()) {
-    if (!privatizations || privatizations->size() != privateOperands.size())
-      return op->emitOpError() << "expected as many privatizations symbol "
-                                  "reference as private operands";
+checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
+                    mlir::OperandRange operands, llvm::StringRef operandName,
+                    llvm::StringRef symbolName) {
+  if (!operands.empty()) {
+    if (!attributes || attributes->size() != operands.size())
+      return op->emitOpError()
+             << "expected as many " << symbolName << " symbol reference as "
+             << operandName << " operands";
   } else {
-    if (privatizations)
-      return op->emitOpError() << "unexpected privatizations symbol reference";
+    if (attributes)
+      return op->emitOpError()
+             << "unexpected " << symbolName << " symbol reference";
     return success();
   }
 
-  llvm::DenseSet<Value> privates;
-  for (auto args : llvm::zip(privateOperands, *privatizations)) {
-    mlir::Value privateOperand = std::get<0>(args);
+  llvm::DenseSet<Value> set;
+  for (auto args : llvm::zip(operands, *attributes)) {
+    mlir::Value operand = std::get<0>(args);
 
-    if (!privates.insert(privateOperand).second)
-      return op->emitOpError() << "private operand appears more than once";
+    if (!set.insert(operand).second)
+      return op->emitOpError()
+             << operandName << " operand appears more than once";
 
-    mlir::Type varType = privateOperand.getType();
+    mlir::Type varType = operand.getType();
     auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
-    auto decl =
-        SymbolTable::lookupNearestSymbolFrom<PrivateRecipeOp>(op, symbolRef);
+    auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
     if (!decl)
-      return op->emitOpError() << "expected symbol reference " << symbolRef
-                               << " to point to a private declaration";
+      return op->emitOpError()
+             << "expected symbol reference " << symbolRef << " to point to a "
+             << operandName << " declaration";
 
     if (decl.getType() && decl.getType() != varType)
       return op->emitOpError()
-             << "expected private (" << varType
-             << ") to be the same type as private declaration ("
-             << decl.getType() << ")";
+             << "expected " << operandName << " (" << varType
+             << ") to be the same type as " << operandName << " declaration ("
+             << decl.getType() << ")";
   }
 
   return success();
@@ -550,8 +556,13 @@ Value ParallelOp::getDataOperand(unsigned i) {
 }
 
 LogicalResult acc::ParallelOp::verify() {
-  if (failed(checkPrivatizationList(*this, getPrivatizations(),
-                                    getGangPrivateOperands())))
+  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
+          *this, getPrivatizations(), getGangPrivateOperands(), "private",
+          "privatizations")))
+    return failure();
+  if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
+          *this, getReductionRecipes(), getReductionOperands(), "reduction",
+          "reductions")))
     return failure();
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
@@ -573,6 +584,10 @@ Value SerialOp::getDataOperand(unsigned i) {
 }
 
 LogicalResult acc::SerialOp::verify() {
+  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
+          *this, getPrivatizations(), getGangPrivateOperands(), "private",
+          "privatizations")))
+    return failure();
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }
 
@@ -729,8 +744,9 @@ LogicalResult acc::LoopOp::verify() {
   if (getSeq() && (getHasGang() || getHasWorker() || getHasVector()))
     return emitError("gang, worker or vector cannot appear with the seq attr");
 
-  if (failed(checkPrivatizationList(*this, getPrivatizations(),
-                                    getPrivateOperands())))
+  if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
+          *this, getPrivatizations(), getPrivateOperands(), "private",
+          "privatizations")))
     return failure();
 
   // Check non-empty body().

diff  --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 032e09dc4e3e1..5d9d59abbd14b 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1430,3 +1430,13 @@ acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
 // CHECK:         %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i64
 // CHECK:         acc.yield %[[RES]] : i64
 // CHECK:       }
+
+func.func @acc_reduc_test(%a : i64) -> () {
+  acc.parallel reduction(@reduction_add_i64 -> %a : i64) {
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @acc_reduc_test(
+// CHECK-SAME:    %[[ARG0:.*]]: i64)
+// CHECK:         acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64)


        


More information about the Mlir-commits mailing list