[Mlir-commits] [mlir] f4f53f8 - [mlir][openacc] Use new reduction design in acc.parallel
Valentin Clement
llvmlistbot at llvm.org
Wed May 24 10:38:30 PDT 2023
Author: Valentin Clement
Date: 2023-05-24T10:38:25-07:00
New Revision: f4f53f8b90184df8120bc841de49be5707964876
URL: https://github.com/llvm/llvm-project/commit/f4f53f8b90184df8120bc841de49be5707964876
DIFF: https://github.com/llvm/llvm-project/commit/f4f53f8b90184df8120bc841de49be5707964876.diff
LOG: [mlir][openacc] Use new reduction design in acc.parallel
After D150818, the reduction clause is represented
with an acc.reduction.recipe operation and an operand.
This patch updates the acc.parallel op for the new design.
Reviewed By: razvanlupusoru, jeanPerier
Differential Revision: https://reviews.llvm.org/D151146
Added:
Modified:
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 2edc8bb969af1..fd4e9edb541fc 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -635,8 +635,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
Optional<I1>:$ifCond,
Optional<I1>:$selfCond,
UnitAttr:$selfAttr,
- OptionalAttr<OpenACC_ReductionOperatorAttr>:$reductionOp,
Variadic<AnyType>:$reductionOperands,
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$gangFirstPrivateOperands,
@@ -661,14 +661,16 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
type($gangFirstPrivateOperands) `)`
| `num_gangs` `(` $numGangs `:` type($numGangs) `)`
| `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
- | `private` `(` custom<PrivatizationList>(
+ | `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
| `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
| `self` `(` $selfCond `)`
| `if` `(` $ifCond `)`
- | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)`
+ | `reduction` `(` custom<SymOperandList>(
+ $reductionOperands, type($reductionOperands), $reductionRecipes)
+ `)`
)
$region attr-dict-with-keyword
}];
@@ -727,7 +729,7 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
| `async` `(` $async `:` type($async) `)`
| `firstprivate` `(` $gangFirstPrivateOperands `:`
type($gangFirstPrivateOperands) `)`
- | `private` `(` custom<PrivatizationList>(
+ | `private` `(` custom<SymOperandList>(
$gangPrivateOperands, type($gangPrivateOperands), $privatizations)
`)`
| `wait` `(` $waitOperands `:` type($waitOperands) `)`
@@ -1061,7 +1063,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
`gang` `` custom<GangClause>($gangNum, type($gangNum), $gangStatic, type($gangStatic), $hasGang)
| `worker` `` custom<WorkerClause>($workerNum, type($workerNum), $hasWorker)
| `vector` `` custom<VectorClause>($vectorLength, type($vectorLength), $hasVector)
- | `private` `(` custom<PrivatizationList>(
+ | `private` `(` custom<SymOperandList>(
$privateOperands, type($privateOperands), $privatizations)
`)`
| `tile` `(` $tileOperands `:` type($tileOperands) `)`
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index ac33c3a84e00b..5e5a00e141a47 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -443,13 +443,13 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
// Custom parser and printer verifier for private clause
//===----------------------------------------------------------------------===//
-static ParseResult parsePrivatizationList(
+static ParseResult parseSymOperandList(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
- llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &privatizationSymbols) {
- llvm::SmallVector<SymbolRefAttr> privatizationVec;
+ llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &symbols) {
+ llvm::SmallVector<SymbolRefAttr> attributes;
if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseAttribute(privatizationVec.emplace_back()) ||
+ if (parser.parseAttribute(attributes.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
@@ -457,22 +457,21 @@ static ParseResult parsePrivatizationList(
return success();
})))
return failure();
- llvm::SmallVector<mlir::Attribute> privatizations(privatizationVec.begin(),
- privatizationVec.end());
- privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations);
+ llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
+ attributes.end());
+ symbols = ArrayAttr::get(parser.getContext(), arrayAttr);
return success();
}
-static void
-printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::OperandRange privateOperands,
- mlir::TypeRange privateTypes,
- std::optional<mlir::ArrayAttr> privatizations) {
- for (unsigned i = 0, e = privatizations->size(); i < e; ++i) {
+static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::OperandRange operands,
+ mlir::TypeRange types,
+ std::optional<mlir::ArrayAttr> attributes) {
+ for (unsigned i = 0, e = attributes->size(); i < e; ++i) {
if (i != 0)
p << ", ";
- p << (*privatizations)[i] << " -> " << privateOperands[i] << " : "
- << privateOperands[i].getType();
+ p << (*attributes)[i] << " -> " << operands[i] << " : "
+ << operands[i].getType();
}
}
@@ -495,40 +494,43 @@ static LogicalResult checkDataOperands(Op op,
return success();
}
+template <typename Op>
static LogicalResult
-checkPrivatizationList(Operation *op,
- std::optional<mlir::ArrayAttr> privatizations,
- mlir::OperandRange privateOperands) {
- if (!privateOperands.empty()) {
- if (!privatizations || privatizations->size() != privateOperands.size())
- return op->emitOpError() << "expected as many privatizations symbol "
- "reference as private operands";
+checkSymOperandList(Operation *op, std::optional<mlir::ArrayAttr> attributes,
+ mlir::OperandRange operands, llvm::StringRef operandName,
+ llvm::StringRef symbolName) {
+ if (!operands.empty()) {
+ if (!attributes || attributes->size() != operands.size())
+ return op->emitOpError()
+ << "expected as many " << symbolName << " symbol reference as "
+ << operandName << " operands";
} else {
- if (privatizations)
- return op->emitOpError() << "unexpected privatizations symbol reference";
+ if (attributes)
+ return op->emitOpError()
+ << "unexpected " << symbolName << " symbol reference";
return success();
}
- llvm::DenseSet<Value> privates;
- for (auto args : llvm::zip(privateOperands, *privatizations)) {
- mlir::Value privateOperand = std::get<0>(args);
+ llvm::DenseSet<Value> set;
+ for (auto args : llvm::zip(operands, *attributes)) {
+ mlir::Value operand = std::get<0>(args);
- if (!privates.insert(privateOperand).second)
- return op->emitOpError() << "private operand appears more than once";
+ if (!set.insert(operand).second)
+ return op->emitOpError()
+ << operandName << " operand appears more than once";
- mlir::Type varType = privateOperand.getType();
+ mlir::Type varType = operand.getType();
auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
- auto decl =
- SymbolTable::lookupNearestSymbolFrom<PrivateRecipeOp>(op, symbolRef);
+ auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
if (!decl)
- return op->emitOpError() << "expected symbol reference " << symbolRef
- << " to point to a private declaration";
+ return op->emitOpError()
+ << "expected symbol reference " << symbolRef << " to point to a "
+ << operandName << " declaration";
if (decl.getType() && decl.getType() != varType)
return op->emitOpError()
- << "expected private (" << varType
- << ") to be the same type as private declaration ("
- << decl.getType() << ")";
+ << "expected private (" << varType << ") to be the same type as "
+ << operandName << " declaration (" << decl.getType() << ")";
}
return success();
@@ -550,8 +552,13 @@ Value ParallelOp::getDataOperand(unsigned i) {
}
LogicalResult acc::ParallelOp::verify() {
- if (failed(checkPrivatizationList(*this, getPrivatizations(),
- getGangPrivateOperands())))
+ if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
+ *this, getPrivatizations(), getGangPrivateOperands(), "private",
+ "privatizations")))
+ return failure();
+ if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
+ *this, getReductionRecipes(), getReductionOperands(), "reduction",
+ "reductions")))
return failure();
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
@@ -573,6 +580,7 @@ Value SerialOp::getDataOperand(unsigned i) {
}
LogicalResult acc::SerialOp::verify() {
+
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
@@ -729,8 +737,9 @@ LogicalResult acc::LoopOp::verify() {
if (getSeq() && (getHasGang() || getHasWorker() || getHasVector()))
return emitError("gang, worker or vector cannot appear with the seq attr");
- if (failed(checkPrivatizationList(*this, getPrivatizations(),
- getPrivateOperands())))
+ if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
+ *this, getPrivatizations(), getPrivateOperands(), "private",
+ "privatizations")))
return failure();
// Check non-empty body().
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 032e09dc4e3e1..5d9d59abbd14b 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1430,3 +1430,13 @@ acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
// CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i64
// CHECK: acc.yield %[[RES]] : i64
// CHECK: }
+
+func.func @acc_reduc_test(%a : i64) -> () {
+ acc.parallel reduction(@reduction_add_i64 -> %a : i64) {
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @acc_reduc_test(
+// CHECK-SAME: %[[ARG0:.*]]: i64)
+// CHECK: acc.parallel reduction(@reduction_add_i64 -> %[[ARG0]] : i64)
More information about the Mlir-commits
mailing list