[Mlir-commits] [mlir] [flang] [mlir][flang][openacc] Device type support on acc routine op (PR #78375)
Valentin Clement バレンタイン クレメン
llvmlistbot at llvm.org
Wed Jan 17 14:54:47 PST 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/78375
>From 74ca692f5ed54842d8f5b7b661765b82f92c64d6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 11 Jan 2024 13:07:31 -0800
Subject: [PATCH 1/2] [mlir][openacc] Device type support on acc routine op
---
flang/lib/Lower/OpenACC.cpp | 138 +++++++--
flang/test/Lower/OpenACC/acc-routine.f90 | 18 +-
.../mlir/Dialect/OpenACC/OpenACCOps.td | 58 +++-
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 287 ++++++++++++++++--
mlir/test/Dialect/OpenACC/ops.mlir | 4 +-
.../Dialect/OpenACC/OpenACCOpsTest.cpp | 58 ++++
6 files changed, 495 insertions(+), 68 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index db9ed72bc87257..fd89d27db74dc0 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -3469,6 +3469,72 @@ static void genACC(Fortran::lower::AbstractConverter &converter,
llvm_unreachable("unsupported declarative directive");
}
+/// Return true if the list of `#acc.device_type` attributes `arrayAttr`
+/// contains `deviceType`.
+static bool hasDeviceType(llvm::ArrayRef<mlir::Attribute> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  for (mlir::Attribute attr : arrayAttr) {
+    // Every entry is expected to be a DeviceTypeAttr; `cast` asserts this
+    // instead of dereferencing a null `dyn_cast` result.
+    auto deviceTypeAttr = mlir::cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+  return false;
+}
+
+/// Return the value of the attribute paired with `deviceType`, or
+/// std::nullopt when `deviceType` is not listed. `attributes` and
+/// `deviceTypes` are parallel lists: `attributes[i]` applies to
+/// `deviceTypes[i]`. Supported instantiations are
+/// (llvm::StringRef, mlir::StringAttr) and (int64_t, mlir::IntegerAttr).
+template <typename RetTy, typename AttrTy>
+static std::optional<RetTy>
+getAttributeValueByDeviceType(llvm::ArrayRef<mlir::Attribute> attributes,
+                              llvm::ArrayRef<mlir::Attribute> deviceTypes,
+                              mlir::acc::DeviceType deviceType) {
+  // Reject unsupported attribute kinds at compile time instead of silently
+  // returning std::nullopt for them.
+  static_assert(std::is_same_v<AttrTy, mlir::StringAttr> ||
+                    std::is_same_v<AttrTy, mlir::IntegerAttr>,
+                "unsupported attribute type");
+  assert(attributes.size() == deviceTypes.size() &&
+         "expect same number of attributes");
+  for (auto it : llvm::enumerate(deviceTypes)) {
+    auto deviceTypeAttr = mlir::cast<mlir::acc::DeviceTypeAttr>(it.value());
+    if (deviceTypeAttr.getValue() != deviceType)
+      continue;
+    if constexpr (std::is_same_v<mlir::StringAttr, AttrTy>)
+      return mlir::cast<mlir::StringAttr>(attributes[it.index()]).getValue();
+    else
+      return mlir::cast<mlir::IntegerAttr>(attributes[it.index()]).getInt();
+  }
+  return std::nullopt;
+}
+
+/// Return true if `op` already carries exactly the device_type-specific
+/// routine information collected from the current `acc routine` clauses:
+/// bind names, bare gang, gang dim values, seq, worker and vector.
+/// The bind-name pair and the gang-dim pair are parallel lists (value at
+/// index i applies to the device_type at index i). The remaining lists hold
+/// the device_types on which the corresponding clause was given.
+static bool compareDeviceTypeInfo(
+    mlir::acc::RoutineOp op,
+    llvm::SmallVector<mlir::Attribute> &bindNameArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &bindNameDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &gangDimDeviceTypeArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &seqArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &workerArrayAttr,
+    llvm::SmallVector<mlir::Attribute> &vectorArrayAttr) {
+  // getMaxEnumValForDeviceType() returns the largest DeviceType value, so the
+  // upper bound is inclusive; `!=` would never compare the last device_type.
+  for (uint32_t dtypeInt = 0;
+       dtypeInt <= mlir::acc::getMaxEnumValForDeviceType(); ++dtypeInt) {
+    auto dtype = static_cast<mlir::acc::DeviceType>(dtypeInt);
+    if (op.getBindNameValue(dtype) !=
+        getAttributeValueByDeviceType<llvm::StringRef, mlir::StringAttr>(
+            bindNameArrayAttr, bindNameDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasGang(dtype) != hasDeviceType(gangArrayAttr, dtype))
+      return false;
+    if (op.getGangDimValue(dtype) !=
+        getAttributeValueByDeviceType<int64_t, mlir::IntegerAttr>(
+            gangDimArrayAttr, gangDimDeviceTypeArrayAttr, dtype))
+      return false;
+    if (op.hasSeq(dtype) != hasDeviceType(seqArrayAttr, dtype))
+      return false;
+    if (op.hasWorker(dtype) != hasDeviceType(workerArrayAttr, dtype))
+      return false;
+    if (op.hasVector(dtype) != hasDeviceType(vectorArrayAttr, dtype))
+      return false;
+  }
+  return true;
+}
+
static void attachRoutineInfo(mlir::func::FuncOp func,
mlir::SymbolRefAttr routineAttr) {
llvm::SmallVector<mlir::SymbolRefAttr> routines;
@@ -3518,17 +3584,23 @@ void Fortran::lower::genOpenACCRoutineConstruct(
funcName = funcOp.getName();
}
}
- bool hasSeq = false, hasGang = false, hasWorker = false, hasVector = false,
- hasNohost = false;
- std::optional<std::string> bindName = std::nullopt;
- std::optional<int64_t> gangDim = std::nullopt;
+ bool hasNohost = false;
+
+ llvm::SmallVector<mlir::Attribute> seqDeviceTypes, vectorDeviceTypes,
+ workerDeviceTypes, bindNameDeviceTypes, bindNames, gangDeviceTypes,
+ gangDimDeviceTypes, gangDimValues;
+
+ // device_type attribute is set to `none` until a device_type clause is
+ // encountered.
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), mlir::acc::DeviceType::None);
for (const Fortran::parser::AccClause &clause : clauses.v) {
if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
- hasSeq = true;
+ seqDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto *gangClause =
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
- hasGang = true;
+
if (gangClause->v) {
const Fortran::parser::AccGangArgList &x = *gangClause->v;
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
@@ -3539,21 +3611,27 @@ void Fortran::lower::genOpenACCRoutineConstruct(
if (!dimValue)
mlir::emitError(loc,
"dim value must be a constant positive integer");
- gangDim = *dimValue;
+ gangDimValues.push_back(
+ builder.getIntegerAttr(builder.getI64Type(), *dimValue));
+ gangDimDeviceTypes.push_back(crtDeviceTypeAttr);
}
}
+ } else {
+ gangDeviceTypes.push_back(crtDeviceTypeAttr);
}
} else if (std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
- hasVector = true;
+ vectorDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
- hasWorker = true;
+ workerDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (std::get_if<Fortran::parser::AccClause::Nohost>(&clause.u)) {
hasNohost = true;
} else if (const auto *bindClause =
std::get_if<Fortran::parser::AccClause::Bind>(&clause.u)) {
if (const auto *name =
std::get_if<Fortran::parser::Name>(&bindClause->v.u)) {
- bindName = converter.mangleName(*name->symbol);
+ bindNames.push_back(
+ builder.getStringAttr(converter.mangleName(*name->symbol)));
+ bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
} else if (const auto charExpr =
std::get_if<Fortran::parser::ScalarDefaultCharExpr>(
&bindClause->v.u)) {
@@ -3562,8 +3640,18 @@ void Fortran::lower::genOpenACCRoutineConstruct(
*charExpr);
if (!name)
mlir::emitError(loc, "Could not retrieve the bind name");
- bindName = *name;
+ bindNames.push_back(builder.getStringAttr(*name));
+ bindNameDeviceTypes.push_back(crtDeviceTypeAttr);
}
+ } else if (const auto *deviceTypeClause =
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
+ &clause.u)) {
+ const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
+ deviceTypeClause->v;
+ assert(deviceTypeExprList.v.size() == 1 &&
+ "expect only one device_type expr");
+ crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+ builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
}
}
@@ -3575,12 +3663,11 @@ void Fortran::lower::genOpenACCRoutineConstruct(
if (routineOp.getFuncName().str().compare(funcName) == 0) {
// If the routine is already specified with the same clauses, just skip
// the operation creation.
- if (routineOp.getBindName() == bindName &&
- routineOp.getGang() == hasGang &&
- routineOp.getWorker() == hasWorker &&
- routineOp.getVector() == hasVector && routineOp.getSeq() == hasSeq &&
- routineOp.getNohost() == hasNohost &&
- routineOp.getGangDim() == gangDim)
+ if (compareDeviceTypeInfo(routineOp, bindNames, bindNameDeviceTypes,
+ gangDeviceTypes, gangDimValues,
+ gangDimDeviceTypes, seqDeviceTypes,
+ workerDeviceTypes, vectorDeviceTypes) &&
+ routineOp.getNohost() == hasNohost)
return;
mlir::emitError(loc, "Routine already specified with different clauses");
}
@@ -3588,10 +3675,19 @@ void Fortran::lower::genOpenACCRoutineConstruct(
modBuilder.create<mlir::acc::RoutineOp>(
loc, routineOpName.str(), funcName,
- bindName ? builder.getStringAttr(*bindName) : mlir::StringAttr{}, hasGang,
- hasWorker, hasVector, hasSeq, hasNohost, /*implicit=*/false,
- gangDim ? builder.getIntegerAttr(builder.getIntegerType(32), *gangDim)
- : mlir::IntegerAttr{});
+ bindNames.empty() ? nullptr : builder.getArrayAttr(bindNames),
+ bindNameDeviceTypes.empty() ? nullptr
+ : builder.getArrayAttr(bindNameDeviceTypes),
+ workerDeviceTypes.empty() ? nullptr
+ : builder.getArrayAttr(workerDeviceTypes),
+ vectorDeviceTypes.empty() ? nullptr
+ : builder.getArrayAttr(vectorDeviceTypes),
+ seqDeviceTypes.empty() ? nullptr : builder.getArrayAttr(seqDeviceTypes),
+ hasNohost, /*implicit=*/false,
+ gangDeviceTypes.empty() ? nullptr : builder.getArrayAttr(gangDeviceTypes),
+ gangDimValues.empty() ? nullptr : builder.getArrayAttr(gangDimValues),
+ gangDimDeviceTypes.empty() ? nullptr
+ : builder.getArrayAttr(gangDimDeviceTypes));
if (funcOp)
attachRoutineInfo(funcOp, builder.getSymbolRefAttr(routineOpName.str()));
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 8b94279503334a..8e9e65da32cd19 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -2,12 +2,14 @@
! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
-
+! CHECK: acc.routine @acc_routine_16 func(@_QPacc_routine18) bind("_QPacc_routine17" [#acc.device_type<host>], "_QPacc_routine16" [#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_15 func(@_QPacc_routine17) worker ([#acc.device_type<host>]) vector ([#acc.device_type<multicore>])
+! CHECK: acc.routine @acc_routine_14 func(@_QPacc_routine16) gang([#acc.device_type<nvidia>]) seq ([#acc.device_type<host>])
! CHECK: acc.routine @acc_routine_10 func(@_QPacc_routine11) seq
! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
-! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim = 1 : i32)
+! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(1 : i64)
! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
@@ -106,3 +108,15 @@ subroutine acc_routine14()
subroutine acc_routine15()
!$acc routine bind(acc_routine16)
end subroutine
+
+! device_type-specific parallelism on one routine: `seq` for host and `gang`
+! for nvidia (`dtype` is the short spelling of `device_type`).
+subroutine acc_routine16()
+ !$acc routine device_type(host) seq dtype(nvidia) gang
+end subroutine
+
+! device_type-specific `worker` (host) and `vector` (multicore).
+subroutine acc_routine17()
+ !$acc routine device_type(host) worker dtype(multicore) vector
+end subroutine
+
+! A different bind name per device_type: host binds to acc_routine17 and
+! multicore binds to acc_routine16.
+subroutine acc_routine18()
+ !$acc routine device_type(host) bind(acc_routine17) dtype(multicore) bind(acc_routine16)
+end subroutine
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 24f129d92805c0..7344ab2852b9ce 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1994,27 +1994,63 @@ def OpenACC_RoutineOp : OpenACC_Op<"routine", [IsolatedFromAbove]> {
let arguments = (ins SymbolNameAttr:$sym_name,
SymbolNameAttr:$func_name,
- OptionalAttr<StrAttr>:$bind_name,
- UnitAttr:$gang,
- UnitAttr:$worker,
- UnitAttr:$vector,
- UnitAttr:$seq,
+ // `bindName` and `bindNameDeviceType` are parallel arrays: the name at
+ // index i is used for the device_type at index i.
+ OptionalAttr<StrArrayAttr>:$bindName,
+ OptionalAttr<DeviceTypeArrayAttr>:$bindNameDeviceType,
+ // Each list below holds the device_types on which the clause is set;
+ // `#acc.device_type<none>` stands for the clause without device_type.
+ OptionalAttr<DeviceTypeArrayAttr>:$worker,
+ OptionalAttr<DeviceTypeArrayAttr>:$vector,
+ OptionalAttr<DeviceTypeArrayAttr>:$seq,
UnitAttr:$nohost,
UnitAttr:$implicit,
- OptionalAttr<APIntAttr>:$gangDim);
+ // `gang` lists the device_types with a bare `gang`; `gangDim` and
+ // `gangDimDeviceType` are parallel arrays giving the `dim` value to use
+ // for each device_type.
+ OptionalAttr<DeviceTypeArrayAttr>:$gang,
+ OptionalAttr<I64ArrayAttr>:$gangDim,
+ OptionalAttr<DeviceTypeArrayAttr>:$gangDimDeviceType);
let extraClassDeclaration = [{
static StringRef getGangDimKeyword() { return "dim"; }
+
+ /// Return true if the op has the worker attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasWorker();
+ /// Return true if the op has the worker attribute for the given
+ /// device_type.
+ bool hasWorker(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the vector attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasVector();
+ /// Return true if the op has the vector attribute for the given
+ /// device_type.
+ bool hasVector(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the seq attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasSeq();
+ /// Return true if the op has the seq attribute for the given
+ /// device_type.
+ bool hasSeq(mlir::acc::DeviceType deviceType);
+
+ /// Return true if the op has the gang attribute for the
+ /// mlir::acc::DeviceType::None device_type.
+ bool hasGang();
+ /// Return true if the op has the gang attribute for the given
+ /// device_type.
+ bool hasGang(mlir::acc::DeviceType deviceType);
+
+ /// Return the gang `dim` value for the mlir::acc::DeviceType::None
+ /// device_type, or std::nullopt if none is set.
+ std::optional<int64_t> getGangDimValue();
+ /// Return the gang `dim` value for the given device_type, or std::nullopt
+ /// if none is set.
+ std::optional<int64_t> getGangDimValue(mlir::acc::DeviceType deviceType);
+
+ /// Return the bind name for the mlir::acc::DeviceType::None device_type,
+ /// or std::nullopt if none is set.
+ std::optional<llvm::StringRef> getBindNameValue();
+ /// Return the bind name for the given device_type, or std::nullopt if none
+ /// is set.
+ std::optional<llvm::StringRef> getBindNameValue(mlir::acc::DeviceType deviceType);
}];
let assemblyFormat = [{
$sym_name `func` `(` $func_name `)`
oilist (
- `bind` `(` $bind_name `)`
- | `gang` `` custom<RoutineGangClause>($gang, $gangDim)
- | `worker` $worker
- | `vector` $vector
- | `seq` $seq
+ `bind` `(` custom<BindName>($bindName, $bindNameDeviceType) `)`
+ | `gang` `` custom<RoutineGangClause>($gang, $gangDim, $gangDimDeviceType)
+ | `worker` custom<DeviceTypeArrayAttr>($worker)
+ | `vector` custom<DeviceTypeArrayAttr>($vector)
+ | `seq` custom<DeviceTypeArrayAttr>($seq)
| `nohost` $nohost
| `implicit` $implicit
) attr-dict-with-keyword
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index bf3264b5da9802..82e614cb7572f6 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1046,7 +1046,7 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
return success();
}
-bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
+static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
return true;
return false;
@@ -2131,55 +2131,278 @@ LogicalResult acc::DeclareOp::verify() {
// RoutineOp
//===----------------------------------------------------------------------===//
+/// Return true if `arrayAttr` is present, non-empty and contains a
+/// `#acc.device_type` attribute equal to `deviceType`. Entries that are not
+/// DeviceTypeAttr are skipped rather than dereferenced as null.
+static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
+                          mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(arrayAttr))
+    return false;
+
+  for (mlir::Attribute attr : *arrayAttr) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
+    if (deviceTypeAttr && deviceTypeAttr.getValue() == deviceType)
+      return true;
+  }
+
+  return false;
+}
+
+/// Count how many of `gang` (a bare gang or a gang `dim`), `worker`,
+/// `vector` and `seq` are set on `op` for the device_type `dtype`.
+static unsigned getParallelismForDeviceType(acc::RoutineOp op,
+                                            acc::DeviceType dtype) {
+  const bool gangIsSet =
+      op.hasGang(dtype) || op.getGangDimValue(dtype).has_value();
+  const bool clauseIsSet[] = {gangIsSet, op.hasWorker(dtype),
+                              op.hasVector(dtype), op.hasSeq(dtype)};
+  unsigned count = 0;
+  for (bool isSet : clauseIsSet)
+    count += isSet ? 1u : 0u;
+  return count;
+}
+
+/// Verify that at most one of `gang`, `worker`, `vector`, `seq` is set for the
+/// default (`none`) device_type and for each specific device_type, and that a
+/// device_type-specific one is not combined with a default one.
LogicalResult acc::RoutineOp::verify() {
- int parallelism = 0;
- parallelism += getGang() ? 1 : 0;
- parallelism += getWorker() ? 1 : 0;
- parallelism += getVector() ? 1 : 0;
- parallelism += getSeq() ? 1 : 0;
+  unsigned baseParallelism =
+      getParallelismForDeviceType(*this, acc::DeviceType::None);
- if (parallelism > 1)
+  if (baseParallelism > 1)
return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
"be present at the same time";
+  // getMaxEnumValForDeviceType() returns the largest DeviceType value, so the
+  // upper bound is inclusive; `!=` would never verify the last device_type.
+  for (uint32_t dtypeInt = 0; dtypeInt <= acc::getMaxEnumValForDeviceType();
+       ++dtypeInt) {
+    auto dtype = static_cast<acc::DeviceType>(dtypeInt);
+    if (dtype == acc::DeviceType::None)
+      continue;
+    unsigned parallelism = getParallelismForDeviceType(*this, dtype);
+
+    if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
+      return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
+                            "be present at the same time";
+  }
+
return success();
}
-static ParseResult parseRoutineGangClause(OpAsmParser &parser, UnitAttr &gang,
- IntegerAttr &gangDim) {
- // Since gang clause exists, ensure that unit attribute is set.
- gang = UnitAttr::get(parser.getBuilder().getContext());
+/// Parse the body of `bind(...)`: a comma-separated list of
+/// `<name-attr> [ '[' <device_type-attr> ']' ]` entries. `bindName` and
+/// `deviceTypes` are filled as parallel arrays; an entry without a bracketed
+/// device_type is associated with `#acc.device_type<none>`.
+static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName,
+ mlir::ArrayAttr &deviceTypes) {
+ llvm::SmallVector<mlir::Attribute> bindNameAttrs;
+ llvm::SmallVector<mlir::Attribute> deviceTypeAttrs;
- // Next, look for dim on gang. Don't initialize `gangDim` yet since
- // we leave it without attribute if there is no `dim` specifier.
- if (succeeded(parser.parseOptionalLParen())) {
- // Look for syntax that looks like `dim = 1 : i32`.
- // Thus first look for `dim =`
- if (failed(parser.parseKeyword(RoutineOp::getGangDimKeyword())) ||
- failed(parser.parseEqual()))
- return failure();
+ if (failed(parser.parseCommaSeparatedList([&]() {
+ if (parser.parseAttribute(bindNameAttrs.emplace_back()))
+ return failure();
+ // Without a bracketed device_type the name applies to `none`.
+ if (failed(parser.parseOptionalLSquare())) {
+ deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+ parser.getContext(), mlir::acc::DeviceType::None));
+ } else {
+ if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
+ parser.parseRSquare())
+ return failure();
+ }
+ return success();
+ })))
+ return failure();
- int64_t dimValue;
- Type valueType;
- // Now look for `1 : i32`
- if (failed(parser.parseInteger(dimValue)) ||
- failed(parser.parseColonType(valueType)))
- return failure();
+ bindName = ArrayAttr::get(parser.getContext(), bindNameAttrs);
+ deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
+
+ return success();
+}
+
+/// Print the body of `bind(...)`: each bind name followed by its device_type
+/// (via printSingleDeviceType), comma-separated.
+static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          std::optional<mlir::ArrayAttr> bindName,
+                          std::optional<mlir::ArrayAttr> deviceTypes) {
+  // As in hasDeviceTypeValues, either optional may be empty or wrap a null
+  // ArrayAttr (e.g. an op built in generic form with only one of the two
+  // arrays); iterating it would dereference null.
+  if (!bindName || !*bindName || !deviceTypes || !*deviceTypes)
+    return;
+  llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
+                        [&](const auto &pair) {
+                          p << std::get<0>(pair);
+                          printSingleDeviceType(p, std::get<1>(pair));
+                        });
+}
+
+/// Parse the optional parenthesized part of the `gang` clause. Accepted forms:
+///   - nothing: a bare `gang`, recorded as `#acc.device_type<none>`;
+///   - `([dtypes])`: bare gang on the listed device_types;
+///   - `(dims)` or `([dtypes], dims)`: dims is a comma-separated list of gang
+///     dim values, each optionally followed by `[dtype]` (default `none`).
+/// `gangDim` and `gangDimDeviceTypes` are filled as parallel arrays.
+static ParseResult parseRoutineGangClause(OpAsmParser &parser,
+                                          mlir::ArrayAttr &gang,
+                                          mlir::ArrayAttr &gangDim,
+                                          mlir::ArrayAttr &gangDimDeviceTypes) {
- gangDim = IntegerAttr::get(valueType, dimValue);
+  llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
+      gangDimDeviceTypeAttrs;
+  bool needCommaBeforeOperands = false;
- if (failed(parser.parseRParen()))
+  // Gang keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+    return success();
+  }
+
+  // Parse keyword only attributes
+  if (succeeded(parser.parseOptionalLSquare())) {
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (parser.parseAttribute(gangAttrs.emplace_back()))
+            return failure();
+          return success();
+        })))
+      return failure();
+    if (parser.parseRSquare())
return failure();
+    needCommaBeforeOperands = true;
+  }
+
+  // `gang([...])` carries device_types but no dim values. This is what
+  // printRoutineGangClause emits when no gang dim is set, so the closing paren
+  // may directly follow the bracketed list.
+  if (needCommaBeforeOperands && succeeded(parser.parseOptionalRParen())) {
+    gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+    return success();
+  }
+
+  if (needCommaBeforeOperands && failed(parser.parseComma()))
+    return failure();
+
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(gangDimAttrs.emplace_back()))
+          return failure();
+        if (succeeded(parser.parseOptionalLSquare())) {
+          if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
+              parser.parseRSquare())
+            return failure();
+        } else {
+          gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
+              parser.getContext(), mlir::acc::DeviceType::None));
+        }
+        return success();
+      })))
+    return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  gang = ArrayAttr::get(parser.getContext(), gangAttrs);
+  gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
+  gangDimDeviceTypes =
+      ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
+
+  return success();
+}
+
+/// Print the parenthesized part of the `gang` clause. Nothing is printed for
+/// a bare `gang` (only the `none` device_type and no dim values). Otherwise
+/// the bare-gang device_types are printed first, followed by the dim values
+/// with their device_types. File-local like the matching parser.
+static void
+printRoutineGangClause(OpAsmPrinter &p, Operation *op,
+                       std::optional<mlir::ArrayAttr> gang,
+                       std::optional<mlir::ArrayAttr> gangDim,
+                       std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
+
+  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
+      gang->size() == 1) {
+    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
+    if (deviceTypeAttr &&
+        deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "(";
+
+  printDeviceTypes(p, gang);
+
+  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
+    p << ", ";
+
+  if (hasDeviceTypeValues(gangDimDeviceTypes))
+    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
+                          [&](const auto &pair) {
+                            p << std::get<0>(pair);
+                            printSingleDeviceType(p, std::get<1>(pair));
+                          });
+
+  p << ")";
+}
+
+/// Parse an optional `([#acc.device_type<...>, ...])` suffix of a keyword
+/// clause (`worker`, `vector`, `seq`). A bare keyword is recorded as
+/// `#acc.device_type<none>`.
+static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
+                                            mlir::ArrayAttr &deviceTypes) {
+  llvm::SmallVector<mlir::Attribute> attributes;
+  // Keyword only
+  if (failed(parser.parseOptionalLParen())) {
+    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
+        parser.getContext(), mlir::acc::DeviceType::None));
+    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
+    return success();
}
+  // Once `(` is consumed, the bracketed list and the closing `)` are
+  // mandatory (the printer always emits `([...])`). An optional `[` here
+  // would leave the `)` unconsumed and produce an empty array.
+  if (parser.parseLSquare() ||
+      parser.parseCommaSeparatedList(
+          [&]() { return parser.parseAttribute(attributes.emplace_back()); }) ||
+      parser.parseRSquare() || parser.parseRParen())
+    return failure();
+  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
return success();
}
-void printRoutineGangClause(OpAsmPrinter &p, Operation *op, UnitAttr gang,
- IntegerAttr gangDim) {
- if (gangDim)
- p << "(" << RoutineOp::getGangDimKeyword() << " = " << gangDim.getValue()
- << " : " << gangDim.getType() << ")";
+/// Print the device_types of a keyword clause as `([#acc.device_type<a>, ...])`.
+/// Nothing is printed when the array is absent or empty, or when it holds only
+/// the `none` device_type (the bare keyword form).
+static void
+printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                         std::optional<mlir::ArrayAttr> deviceTypes) {
+  if (!hasDeviceTypeValues(deviceTypes))
+    return;
+
+  mlir::ArrayAttr entries = *deviceTypes;
+  if (entries.size() == 1) {
+    auto soleEntry = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(entries[0]);
+    if (soleEntry.getValue() == mlir::acc::DeviceType::None)
+      return;
+  }
+
+  p << "([";
+  llvm::interleaveComma(entries, p, [&](mlir::Attribute entry) {
+    p << mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(entry);
+  });
+  p << "])";
+}
+
+/// Return true if `worker` is set for the mlir::acc::DeviceType::None
+/// device_type.
+bool RoutineOp::hasWorker() {
+  return hasDeviceType(getWorker(), mlir::acc::DeviceType::None);
+}
+
+/// Return true if `worker` is set for `deviceType`.
+bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
+  auto workerDeviceTypes = getWorker();
+  return hasDeviceType(workerDeviceTypes, deviceType);
+}
+
+/// Return true if `vector` is set for the mlir::acc::DeviceType::None
+/// device_type.
+bool RoutineOp::hasVector() {
+  return hasDeviceType(getVector(), mlir::acc::DeviceType::None);
+}
+
+/// Return true if `vector` is set for `deviceType`.
+bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
+  auto vectorDeviceTypes = getVector();
+  return hasDeviceType(vectorDeviceTypes, deviceType);
+}
+
+/// Return true if `seq` is set for the mlir::acc::DeviceType::None
+/// device_type.
+bool RoutineOp::hasSeq() {
+  return hasDeviceType(getSeq(), mlir::acc::DeviceType::None);
+}
+
+/// Return true if `seq` is set for `deviceType`.
+bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
+  auto seqDeviceTypes = getSeq();
+  return hasDeviceType(seqDeviceTypes, deviceType);
+}
+
+/// Return the bind name for the mlir::acc::DeviceType::None device_type, if
+/// any.
+std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
+  return getBindNameValue(mlir::acc::DeviceType::None);
+}
+
+/// Return the bind name associated with `deviceType`, if any. `bindName` and
+/// `bindNameDeviceType` are parallel arrays: the name at index i applies to
+/// the device_type at index i.
+std::optional<llvm::StringRef>
+RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(getBindNameDeviceType()))
+    return std::nullopt;
+  // Guard against a malformed op whose name array is missing or shorter than
+  // its device_type array instead of indexing past its end.
+  auto names = getBindName();
+  if (!names || !*names)
+    return std::nullopt;
+  if (auto pos = findSegment(*getBindNameDeviceType(), deviceType)) {
+    if (*pos >= names->size())
+      return std::nullopt;
+    if (auto stringAttr = mlir::dyn_cast<mlir::StringAttr>((*names)[*pos]))
+      return stringAttr.getValue();
+  }
+  return std::nullopt;
+}
+
+/// Return true if a bare `gang` is set for the mlir::acc::DeviceType::None
+/// device_type.
+bool RoutineOp::hasGang() {
+  return hasDeviceType(getGang(), mlir::acc::DeviceType::None);
+}
+
+/// Return true if a bare `gang` is set for `deviceType`.
+bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
+  auto gangDeviceTypes = getGang();
+  return hasDeviceType(gangDeviceTypes, deviceType);
+}
+
+/// Return the gang `dim` value for the mlir::acc::DeviceType::None
+/// device_type, if any.
+std::optional<int64_t> RoutineOp::getGangDimValue() {
+  return getGangDimValue(mlir::acc::DeviceType::None);
+}
+
+/// Return the gang `dim` value associated with `deviceType`, if any.
+/// `gangDim` and `gangDimDeviceType` are parallel arrays.
+std::optional<int64_t>
+RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
+  if (!hasDeviceTypeValues(getGangDimDeviceType()))
+    return std::nullopt;
+  // Guard against a malformed op whose value array is missing or shorter than
+  // its device_type array instead of indexing past its end.
+  auto dims = getGangDim();
+  if (!dims || !*dims)
+    return std::nullopt;
+  if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
+    if (*pos >= dims->size())
+      return std::nullopt;
+    if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*dims)[*pos]))
+      return intAttr.getInt();
+  }
+  return std::nullopt;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 8fa37bc98294ce..77599fdd32e567 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1656,7 +1656,7 @@ acc.routine @acc_func_rout5 func(@acc_func) bind("acc_func_gpu_worker") worker
acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") implicit gang
acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32)
+acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
// CHECK-LABEL: func.func @acc_func(
// CHECK: attributes {acc.routine_info = #acc.routine_info<[@acc_func_rout1, @acc_func_rout2, @acc_func_rout3,
@@ -1669,7 +1669,7 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
// CHECK: acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
// CHECK: acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") gang implicit
// CHECK: acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim = 1 : i32)
+// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
// -----
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index d78d7b0fdf6769..474f887928992a 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -347,3 +347,61 @@ TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) {
}
op->removeVectorAttr();
}
+
+TEST_F(OpenACCOpsTest, routineOpTest) {
+  OwningOpRef<RoutineOp> op =
+      b.create<RoutineOp>(loc, TypeRange{}, ValueRange{});
+
+  // A freshly built op carries no parallelism, no gang dim and no bind name.
+  EXPECT_FALSE(op->hasSeq());
+  EXPECT_FALSE(op->hasVector());
+  EXPECT_FALSE(op->hasWorker());
+  EXPECT_FALSE(op->hasGang());
+  EXPECT_FALSE(op->getGangDimValue().has_value());
+  EXPECT_FALSE(op->getBindNameValue().has_value());
+
+  for (auto d : dtypes) {
+    EXPECT_FALSE(op->hasSeq(d));
+    EXPECT_FALSE(op->hasVector(d));
+    EXPECT_FALSE(op->hasWorker(d));
+    EXPECT_FALSE(op->hasGang(d));
+    EXPECT_FALSE(op->getGangDimValue(d).has_value());
+    EXPECT_FALSE(op->getBindNameValue(d).has_value());
+  }
+
+  // An attribute set only for the `none` device_type is visible through the
+  // no-argument accessor and not through any specific device_type.
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setSeqAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasSeq());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasSeq(d));
+  op->removeSeqAttr();
+
+  op->setVectorAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasVector());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasVector(d));
+  op->removeVectorAttr();
+
+  op->setWorkerAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasWorker());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasWorker(d));
+  op->removeWorkerAttr();
+
+  op->setGangAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasGang());
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->hasGang(d));
+  op->removeGangAttr();
+
+  op->setGangDimDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setGangDimAttr(b.getArrayAttr({b.getIntegerAttr(b.getI64Type(), 8)}));
+  EXPECT_TRUE(op->getGangDimValue().has_value());
+  EXPECT_EQ(op->getGangDimValue().value(), 8);
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->getGangDimValue(d).has_value());
+  op->removeGangDimDeviceTypeAttr();
+  op->removeGangDimAttr();
+
+  op->setBindNameDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+  op->setBindNameAttr(b.getArrayAttr({b.getStringAttr("fname")}));
+  EXPECT_TRUE(op->getBindNameValue().has_value());
+  EXPECT_EQ(op->getBindNameValue().value(), "fname");
+  for (auto d : dtypesWithoutNone)
+    EXPECT_FALSE(op->getBindNameValue(d).has_value());
+  op->removeBindNameDeviceTypeAttr();
+  op->removeBindNameAttr();
+}
>From 66f8415ecf19b9f7e2010dce8869f14d47db80c4 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Wed, 17 Jan 2024 14:54:30 -0800
Subject: [PATCH 2/2] Keep dim keyword for gang
---
flang/test/Lower/OpenACC/acc-routine.f90 | 2 +-
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 5 ++++-
mlir/test/Dialect/OpenACC/ops.mlir | 4 ++--
3 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/flang/test/Lower/OpenACC/acc-routine.f90 b/flang/test/Lower/OpenACC/acc-routine.f90
index 8e9e65da32cd19..2fe150e70b0cfb 100644
--- a/flang/test/Lower/OpenACC/acc-routine.f90
+++ b/flang/test/Lower/OpenACC/acc-routine.f90
@@ -9,7 +9,7 @@
! CHECK: acc.routine @acc_routine_9 func(@_QPacc_routine10) seq
! CHECK: acc.routine @acc_routine_8 func(@_QPacc_routine9) bind("_QPacc_routine9a")
! CHECK: acc.routine @acc_routine_7 func(@_QPacc_routine8) bind("routine8_")
-! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(1 : i64)
+! CHECK: acc.routine @acc_routine_6 func(@_QPacc_routine7) gang(dim: 1 : i64)
! CHECK: acc.routine @acc_routine_5 func(@_QPacc_routine6) nohost
! CHECK: acc.routine @acc_routine_4 func(@_QPacc_routine5) worker
! CHECK: acc.routine @acc_routine_3 func(@_QPacc_routine4) vector
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 82e614cb7572f6..5cfbe233a670f3 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2248,7 +2248,9 @@ static ParseResult parseRoutineGangClause(OpAsmParser &parser,
return failure();
if (failed(parser.parseCommaSeparatedList([&]() {
- if (parser.parseAttribute(gangDimAttrs.emplace_back()))
+ if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
+ parser.parseColon() ||
+ parser.parseAttribute(gangDimAttrs.emplace_back()))
return failure();
if (succeeded(parser.parseOptionalLSquare())) {
if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
@@ -2295,6 +2297,7 @@ void printRoutineGangClause(OpAsmPrinter &p, Operation *op,
if (hasDeviceTypeValues(gangDimDeviceTypes))
llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
[&](const auto &pair) {
+ p << acc::RoutineOp::getGangDimKeyword() << ": ";
p << std::get<0>(pair);
printSingleDeviceType(p, std::get<1>(pair));
});
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 77599fdd32e567..99b44183758d95 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1656,7 +1656,7 @@ acc.routine @acc_func_rout5 func(@acc_func) bind("acc_func_gpu_worker") worker
acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") implicit gang
acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
+acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64)
// CHECK-LABEL: func.func @acc_func(
// CHECK: attributes {acc.routine_info = #acc.routine_info<[@acc_func_rout1, @acc_func_rout2, @acc_func_rout3,
@@ -1669,7 +1669,7 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
// CHECK: acc.routine @acc_func_rout6 func(@acc_func) bind("acc_func_gpu_seq") seq
// CHECK: acc.routine @acc_func_rout7 func(@acc_func) bind("acc_func_gpu_imp_gang") gang implicit
// CHECK: acc.routine @acc_func_rout8 func(@acc_func) bind("acc_func_gpu_vector_nohost") vector nohost
-// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(1 : i64)
+// CHECK: acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(dim: 1 : i64)
// -----
More information about the Mlir-commits
mailing list