[Mlir-commits] [mlir] [acc] Add attribute for combined constructs (PR #80319)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Feb 1 10:34:58 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openacc
@llvm/pr-subscribers-openacc
Author: Razvan Lupusoru (razvanlupusoru)
<details>
<summary>Changes</summary>
Combined constructs are decomposed into separate operations. However, this does not adhere to the `acc` dialect's goal of being able to regenerate clauses that are semantically equivalent to the user's intent. Thus, add an attribute to keep track of the combined constructs.
---
Full diff: https://github.com/llvm/llvm-project/pull/80319.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACC.h (+7-3)
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+34-8)
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+61)
- (modified) mlir/test/Dialect/OpenACC/invalid.mlir (+30)
- (modified) mlir/test/Dialect/OpenACC/ops.mlir (+43-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index bb3b9617c24ed..941682e6840a0 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -122,15 +122,19 @@ mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp);
/// Used to obtain the attribute name for declare.
static constexpr StringLiteral getDeclareAttrName() {
- return StringLiteral("acc.declare");
+ return DeclareAttr::name;
}
static constexpr StringLiteral getDeclareActionAttrName() {
- return StringLiteral("acc.declare_action");
+ return DeclareActionAttr::name;
}
static constexpr StringLiteral getRoutineInfoAttrName() {
- return StringLiteral("acc.routine_info");
+ return RoutineInfoAttr::name;
+}
+
+static constexpr StringLiteral getCombinedConstructsAttrName() {
+ return CombinedConstructsTypeAttr::name;
}
struct RuntimeCounters
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 9398cbfdacee4..24acc66bf9497 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -218,6 +218,24 @@ def GangArgTypeArrayAttr :
let constBuilderCall = ?;
}
+// Combined constructs enumerations
+def OpenACC_KernelsLoop : I32EnumAttrCase<"KernelsLoop", 1, "kernels_loop">;
+def OpenACC_ParallelLoop : I32EnumAttrCase<"ParallelLoop", 2, "parallel_loop">;
+def OpenACC_SerialLoop : I32EnumAttrCase<"SerialLoop", 3, "serial_loop">;
+
+def OpenACC_CombinedConstructsType : I32EnumAttr<"CombinedConstructsType",
+ "Differentiate between combined constructs",
+ [OpenACC_KernelsLoop, OpenACC_ParallelLoop, OpenACC_SerialLoop]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::acc";
+}
+
+def OpenACC_CombinedConstructsAttr : EnumAttr<OpenACC_Dialect,
+ OpenACC_CombinedConstructsType,
+ "combined_constructs"> {
+ let assemblyFormat = [{ ```<` $value `>` }];
+}
+
// Define a resource for the OpenACC runtime counters.
def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
@@ -928,7 +946,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ OptionalAttr<DefaultValueAttr>:$defaultAttr,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
let regions = (region AnyRegion:$region);
@@ -989,7 +1008,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
let assemblyFormat = [{
oilist(
- `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1059,7 +1079,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ OptionalAttr<DefaultValueAttr>:$defaultAttr,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
let regions = (region AnyRegion:$region);
@@ -1101,7 +1122,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
let assemblyFormat = [{
oilist(
- `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1168,7 +1190,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
Optional<I1>:$selfCond,
UnitAttr:$selfAttr,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
- OptionalAttr<DefaultValueAttr>:$defaultAttr);
+ OptionalAttr<DefaultValueAttr>:$defaultAttr,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
let regions = (region AnyRegion:$region);
@@ -1229,7 +1252,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
let assemblyFormat = [{
oilist(
- `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
type($asyncOperands), $asyncOperandsDeviceType) `)`
| `num_gangs` `(` custom<NumGangs>($numGangs,
@@ -1550,7 +1574,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$reductionOperands,
- OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
+ OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+ OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined
);
let results = (outs Variadic<AnyType>:$results);
@@ -1642,7 +1667,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
oilist(
- `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
+ `combined` `(` custom<CombinedConstructs>($combined) `)`
+ | `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
$gangOperandsArgType, $gangOperandsDeviceType,
$gangOperandsSegments, $gang)
| `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index ae5da686f8595..a020e6d34aba9 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -815,6 +815,11 @@ LogicalResult acc::ParallelOp::verify() {
if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
return failure();
+ if (getCombined().has_value() &&
+ getCombined().value() != acc::CombinedConstructsType::ParallelLoop) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
@@ -1285,6 +1290,45 @@ static void printDeviceTypeOperandsWithKeywordOnly(
p << ")";
}
+static ParseResult
+parseCombinedConstructs(mlir::OpAsmParser &parser,
+ mlir::acc::CombinedConstructsTypeAttr &attr) {
+ // Just parsing first keyword we know which type of combined construct it is.
+ if (succeeded(parser.parseOptionalKeyword("kernels"))) {
+ attr = mlir::acc::CombinedConstructsTypeAttr::get(
+ parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
+ } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
+ attr = mlir::acc::CombinedConstructsTypeAttr::get(
+ parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
+ } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
+ attr = mlir::acc::CombinedConstructsTypeAttr::get(
+ parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
+ } else {
+ parser.emitError(parser.getCurrentLocation(),
+ "expected compute construct name for combined constructs");
+ return failure();
+ }
+
+ // Ensure that the `loop` wording follows the compute construct.
+ return parser.parseKeyword("loop");
+}
+
+static void
+printCombinedConstructs(mlir::OpAsmPrinter &p, mlir::Operation *op,
+ mlir::acc::CombinedConstructsTypeAttr attr) {
+ switch (attr.getValue()) {
+ case mlir::acc::CombinedConstructsType::KernelsLoop:
+ p << "kernels loop";
+ break;
+ case mlir::acc::CombinedConstructsType::ParallelLoop:
+ p << "parallel loop";
+ break;
+ case mlir::acc::CombinedConstructsType::SerialLoop:
+ p << "serial loop";
+ break;
+ };
+}
+
//===----------------------------------------------------------------------===//
// SerialOp
//===----------------------------------------------------------------------===//
@@ -1370,6 +1414,11 @@ LogicalResult acc::SerialOp::verify() {
if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
return failure();
+ if (getCombined().has_value() &&
+ getCombined().value() != acc::CombinedConstructsType::SerialLoop) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
@@ -1497,6 +1546,11 @@ LogicalResult acc::KernelsOp::verify() {
if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
return failure();
+ if (getCombined().has_value() &&
+ getCombined().value() != acc::CombinedConstructsType::KernelsLoop) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
@@ -1854,6 +1908,13 @@ LogicalResult acc::LoopOp::verify() {
"reductions", false)))
return failure();
+ if (getCombined().has_value() &&
+ (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
+ getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
+ getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
+ return emitError("unexpected combined constructs attribute");
+ }
+
// Check non-empty body().
if (getRegion().empty())
return emitError("expected non-empty body.");
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 16df33eec642c..48cbddae071ba 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -738,3 +738,33 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
acc.terminator
}
}
+
+// -----
+
+func.func @acc_combined() {
+ // expected-error @below {{expected compute construct name for combined constructs}}
+ acc.parallel combined() {
+ }
+
+ return
+}
+
+// -----
+
+func.func @acc_combined() {
+ // expected-error @below {{expected 'loop'}}
+ acc.parallel combined(parallel) {
+ }
+
+ return
+}
+
+// -----
+
+func.func @acc_combined() {
+ // expected-error @below {{unexpected combined constructs attribute}}
+ acc.parallel combined(kernels loop) {
+ }
+
+ return
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 4e6ed8645cdbc..a10b603e8a07b 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1846,9 +1846,49 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
// -----
-%c2 = arith.constant 2 : i32
-%c1 = arith.constant 1 : i32
-acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+// CHECK-LABEL: func.func @acc_num_gangs
+func.func @acc_num_gangs() {
+ %c2 = arith.constant 2 : i32
+ %c1 = arith.constant 1 : i32
+ acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+ }
+
+ return
}
// CHECK: acc.parallel num_gangs({%c2{{.*}} : i32} [#acc.device_type<default>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+// -----
+
+// CHECK-LABEL: func.func @acc_combined
+func.func @acc_combined() {
+ acc.parallel combined(parallel loop) {
+ acc.loop combined(parallel loop) {
+ acc.yield
+ }
+ acc.terminator
+ }
+
+ acc.kernels combined(kernels loop) {
+ acc.loop combined(kernels loop) {
+ acc.yield
+ }
+ acc.terminator
+ }
+
+ acc.serial combined(serial loop) {
+ acc.loop combined(serial loop) {
+ acc.yield
+ }
+ acc.terminator
+ }
+
+ return
+}
+
+// CHECK: acc.parallel combined(parallel loop)
+// CHECK: acc.loop combined(parallel loop)
+// CHECK: acc.kernels combined(kernels loop)
+// CHECK: acc.loop combined(kernels loop)
+// CHECK: acc.serial combined(serial loop)
+// CHECK: acc.loop combined(serial loop)
``````````
</details>
https://github.com/llvm/llvm-project/pull/80319
More information about the Mlir-commits
mailing list