[Mlir-commits] [mlir] [acc] Add attribute for combined constructs (PR #80319)

Razvan Lupusoru llvmlistbot at llvm.org
Wed Mar 6 14:05:09 PST 2024


https://github.com/razvanlupusoru updated https://github.com/llvm/llvm-project/pull/80319

>From 4433dba4aabb01164c050d72cec9c62e8bf5375e Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 1 Feb 2024 10:33:08 -0800
Subject: [PATCH 1/5] [acc] Add attribute for combined constructs

Combined constructs are decomposed into separate operations (for example,
`parallel loop` becomes an `acc.parallel` op containing an `acc.loop` op).
However, this conflicts with the `acc` dialect's goal of being able to
regenerate clauses that are semantically equivalent to the user's original
intent. Thus, add an attribute to keep track of the combined constructs.
---
 mlir/include/mlir/Dialect/OpenACC/OpenACC.h   | 10 ++-
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 42 ++++++++++---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 61 +++++++++++++++++++
 mlir/test/Dialect/OpenACC/invalid.mlir        | 30 +++++++++
 mlir/test/Dialect/OpenACC/ops.mlir            | 46 +++++++++++++-
 5 files changed, 175 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index bb3b9617c24edb..941682e6840a06 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -122,15 +122,19 @@ mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp);
 
 /// Used to obtain the attribute name for declare.
 static constexpr StringLiteral getDeclareAttrName() {
-  return StringLiteral("acc.declare");
+  return DeclareAttr::name;
 }
 
 static constexpr StringLiteral getDeclareActionAttrName() {
-  return StringLiteral("acc.declare_action");
+  return DeclareActionAttr::name;
 }
 
 static constexpr StringLiteral getRoutineInfoAttrName() {
-  return StringLiteral("acc.routine_info");
+  return RoutineInfoAttr::name;
+}
+
+static constexpr StringLiteral getCombinedConstructsAttrName() {
+  return CombinedConstructsTypeAttr::name;
 }
 
 struct RuntimeCounters
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 9398cbfdacee46..24acc66bf94972 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -218,6 +218,24 @@ def GangArgTypeArrayAttr :
   let constBuilderCall = ?;
 }
 
+// Combined constructs enumerations
+def OpenACC_KernelsLoop    : I32EnumAttrCase<"KernelsLoop", 1, "kernels_loop">;
+def OpenACC_ParallelLoop   : I32EnumAttrCase<"ParallelLoop", 2, "parallel_loop">;
+def OpenACC_SerialLoop     : I32EnumAttrCase<"SerialLoop", 3, "serial_loop">;
+
+def OpenACC_CombinedConstructsType : I32EnumAttr<"CombinedConstructsType",
+    "Differentiate between combined constructs",
+    [OpenACC_KernelsLoop, OpenACC_ParallelLoop, OpenACC_SerialLoop]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::acc";
+}
+
+def OpenACC_CombinedConstructsAttr : EnumAttr<OpenACC_Dialect,
+                                       OpenACC_CombinedConstructsType,
+                                       "combined_constructs"> {
+  let assemblyFormat = [{ ```<` $value `>` }];
+}
+
 // Define a resource for the OpenACC runtime counters.
 def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;
 
@@ -928,7 +946,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
       Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-      OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      OptionalAttr<DefaultValueAttr>:$defaultAttr,
+      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -989,7 +1008,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
 
   let assemblyFormat = [{
     oilist(
-        `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+        `combined` `(` custom<CombinedConstructs>($combined) `)`
+      | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1059,7 +1079,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
       Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-      OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      OptionalAttr<DefaultValueAttr>:$defaultAttr,
+      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1101,7 +1122,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
 
   let assemblyFormat = [{
     oilist(
-        `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+        `combined` `(` custom<CombinedConstructs>($combined) `)`
+      | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1168,7 +1190,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
       Optional<I1>:$selfCond,
       UnitAttr:$selfAttr,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
-      OptionalAttr<DefaultValueAttr>:$defaultAttr);
+      OptionalAttr<DefaultValueAttr>:$defaultAttr,
+      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1229,7 +1252,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
 
   let assemblyFormat = [{
     oilist(
-        `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+      `combined` `(` custom<CombinedConstructs>($combined) `)`
+      | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `num_gangs` `(` custom<NumGangs>($numGangs,
@@ -1550,7 +1574,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
       Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
       Variadic<AnyType>:$reductionOperands,
-      OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
+      OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
+      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined
   );
 
   let results = (outs Variadic<AnyType>:$results);
@@ -1642,7 +1667,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
+        `combined` `(` custom<CombinedConstructs>($combined) `)`
+      | `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
             $gangOperandsSegments, $gang)
       | `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index ae5da686f8595a..a020e6d34aba91 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -815,6 +815,11 @@ LogicalResult acc::ParallelOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
     return failure();
 
+  if (getCombined().has_value() &&
+      getCombined().value() != acc::CombinedConstructsType::ParallelLoop) {
+    return emitError("unexpected combined constructs attribute");
+  }
+
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
@@ -1285,6 +1290,45 @@ static void printDeviceTypeOperandsWithKeywordOnly(
   p << ")";
 }
 
+static ParseResult
+parseCombinedConstructs(mlir::OpAsmParser &parser,
+                        mlir::acc::CombinedConstructsTypeAttr &attr) {
+  // Just parsing first keyword we know which type of combined construct it is.
+  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
+    attr = mlir::acc::CombinedConstructsTypeAttr::get(
+        parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
+  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
+    attr = mlir::acc::CombinedConstructsTypeAttr::get(
+        parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
+  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
+    attr = mlir::acc::CombinedConstructsTypeAttr::get(
+        parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
+  } else {
+    parser.emitError(parser.getCurrentLocation(),
+                     "expected compute construct name for combined constructs");
+    return failure();
+  }
+
+  // Ensure that the `loop` wording follows the compute construct.
+  return parser.parseKeyword("loop");
+}
+
+static void
+printCombinedConstructs(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                        mlir::acc::CombinedConstructsTypeAttr attr) {
+  switch (attr.getValue()) {
+  case mlir::acc::CombinedConstructsType::KernelsLoop:
+    p << "kernels loop";
+    break;
+  case mlir::acc::CombinedConstructsType::ParallelLoop:
+    p << "parallel loop";
+    break;
+  case mlir::acc::CombinedConstructsType::SerialLoop:
+    p << "serial loop";
+    break;
+  };
+}
+
 //===----------------------------------------------------------------------===//
 // SerialOp
 //===----------------------------------------------------------------------===//
@@ -1370,6 +1414,11 @@ LogicalResult acc::SerialOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
     return failure();
 
+  if (getCombined().has_value() &&
+      getCombined().value() != acc::CombinedConstructsType::SerialLoop) {
+    return emitError("unexpected combined constructs attribute");
+  }
+
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }
 
@@ -1497,6 +1546,11 @@ LogicalResult acc::KernelsOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
     return failure();
 
+  if (getCombined().has_value() &&
+      getCombined().value() != acc::CombinedConstructsType::KernelsLoop) {
+    return emitError("unexpected combined constructs attribute");
+  }
+
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
 }
 
@@ -1854,6 +1908,13 @@ LogicalResult acc::LoopOp::verify() {
           "reductions", false)))
     return failure();
 
+  if (getCombined().has_value() &&
+      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
+       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
+       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
+    return emitError("unexpected combined constructs attribute");
+  }
+
   // Check non-empty body().
   if (getRegion().empty())
     return emitError("expected non-empty body.");
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index 16df33eec642ce..48cbddae071bab 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -738,3 +738,33 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
     acc.terminator
   }
 }
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{expected compute construct name for combined constructs}}
+  acc.parallel combined() {
+  }
+
+  return
+}
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{expected 'loop'}}
+  acc.parallel combined(parallel) {
+  }
+
+  return
+}
+
+// -----
+
+func.func @acc_combined() {
+  // expected-error @below {{unexpected combined constructs attribute}}
+  acc.parallel combined(kernels loop) {
+  }
+
+  return
+}
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 4e6ed8645cdbce..a10b603e8a07bd 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1846,9 +1846,49 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {
 
 // -----
 
-%c2 = arith.constant 2 : i32
-%c1 = arith.constant 1 : i32
-acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+// CHECK-LABEL: func.func @acc_num_gangs
+func.func @acc_num_gangs() {
+  %c2 = arith.constant 2 : i32
+  %c1 = arith.constant 1 : i32
+  acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
+  }
+
+  return
 }
 
 // CHECK: acc.parallel num_gangs({%c2{{.*}} : i32} [#acc.device_type<default>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])
+
+// -----
+
+// CHECK-LABEL: func.func @acc_combined
+func.func @acc_combined() {
+  acc.parallel combined(parallel loop) {
+    acc.loop combined(parallel loop) {
+      acc.yield
+    }
+    acc.terminator
+  }
+
+  acc.kernels combined(kernels loop) {
+    acc.loop combined(kernels loop) {
+      acc.yield
+    }
+    acc.terminator
+  }
+
+  acc.serial combined(serial loop) {
+    acc.loop combined(serial loop) {
+      acc.yield
+    }
+    acc.terminator
+  }
+
+  return
+}
+
+// CHECK: acc.parallel combined(parallel loop)
+// CHECK: acc.loop combined(parallel loop)
+// CHECK: acc.kernels combined(kernels loop)
+// CHECK: acc.loop combined(kernels loop)
+// CHECK: acc.serial combined(serial loop)
+// CHECK: acc.loop combined(serial loop)

>From c99ede10b7b578564760ba8fe3d55aaaf1e062e2 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 6 Mar 2024 11:36:28 -0800
Subject: [PATCH 2/5] Revert unrelated changes to attribute name helpers

---
 mlir/include/mlir/Dialect/OpenACC/OpenACC.h | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
index 941682e6840a06..0c8e0b45878206 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACC.h
@@ -122,15 +122,15 @@ mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp);
 
 /// Used to obtain the attribute name for declare.
 static constexpr StringLiteral getDeclareAttrName() {
-  return DeclareAttr::name;
+  return StringLiteral("acc.declare");
 }
 
 static constexpr StringLiteral getDeclareActionAttrName() {
-  return DeclareActionAttr::name;
+  return StringLiteral("acc.declare_action");
 }
 
 static constexpr StringLiteral getRoutineInfoAttrName() {
-  return RoutineInfoAttr::name;
+  return StringLiteral("acc.routine_info");
 }
 
 static constexpr StringLiteral getCombinedConstructsAttrName() {

>From f7eae7ca3ce1ee768e46de012a4b04f3924ad10b Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 6 Mar 2024 11:38:33 -0800
Subject: [PATCH 3/5] Make optional-value checks consistent with the rest of the file

---
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 517ab7eb651ddb..0dff9e3618b908 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -815,8 +815,8 @@ LogicalResult acc::ParallelOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
     return failure();
 
-  if (getCombined().has_value() &&
-      getCombined().value() != acc::CombinedConstructsType::ParallelLoop) {
+  if (getCombined() &&
+      *getCombined() != acc::CombinedConstructsType::ParallelLoop) {
     return emitError("unexpected combined constructs attribute");
   }
 
@@ -1412,8 +1412,8 @@ LogicalResult acc::SerialOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
     return failure();
 
-  if (getCombined().has_value() &&
-      getCombined().value() != acc::CombinedConstructsType::SerialLoop) {
+  if (getCombined() &&
+      *getCombined() != acc::CombinedConstructsType::SerialLoop) {
     return emitError("unexpected combined constructs attribute");
   }
 
@@ -1544,8 +1544,8 @@ LogicalResult acc::KernelsOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
     return failure();
 
-  if (getCombined().has_value() &&
-      getCombined().value() != acc::CombinedConstructsType::KernelsLoop) {
+  if (getCombined() &&
+      *getCombined() != acc::CombinedConstructsType::KernelsLoop) {
     return emitError("unexpected combined constructs attribute");
   }
 

>From 8ee90ce1e47db1882bd63e134ee99c7011db31fd Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 6 Mar 2024 13:26:54 -0800
Subject: [PATCH 4/5] Use a unit attribute for `combined` on compute
 constructs. Simplify the asm printer and reduce printing verbosity

---
 .../mlir/Dialect/OpenACC/OpenACCOps.td        | 20 +++++------
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 35 +++++--------------
 mlir/test/Dialect/OpenACC/invalid.mlir        | 18 +++++++---
 mlir/test/Dialect/OpenACC/ops.mlir            | 24 ++++++-------
 4 files changed, 45 insertions(+), 52 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 14484ddb45caf9..0df0be3a61046f 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -952,7 +952,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       OptionalAttr<DefaultValueAttr>:$defaultAttr,
-      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
+      UnitAttr:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1012,9 +1012,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
   }];
 
   let assemblyFormat = [{
+    ( `combined` `(` `loop` `)` $combined^)?
     oilist(
-        `combined` `(` custom<CombinedConstructs>($combined) `)`
-      | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+        `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1089,7 +1089,7 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
       OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       OptionalAttr<DefaultValueAttr>:$defaultAttr,
-      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
+      UnitAttr:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1130,9 +1130,9 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
   }];
 
   let assemblyFormat = [{
+    ( `combined` `(` `loop` `)` $combined^)?
     oilist(
-        `combined` `(` custom<CombinedConstructs>($combined) `)`
-      | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+        `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `firstprivate` `(` custom<SymOperandList>($gangFirstPrivateOperands,
@@ -1205,7 +1205,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
       UnitAttr:$selfAttr,
       Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
       OptionalAttr<DefaultValueAttr>:$defaultAttr,
-      OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined);
+      UnitAttr:$combined);
 
   let regions = (region AnyRegion:$region);
 
@@ -1265,9 +1265,9 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
   }];
 
   let assemblyFormat = [{
+    ( `combined` `(` `loop` `)` $combined^)?
     oilist(
-      `combined` `(` custom<CombinedConstructs>($combined) `)`
-      | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+        `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
       | `async` `(` custom<DeviceTypeOperands>($asyncOperands,
             type($asyncOperands), $asyncOperandsDeviceType) `)`
       | `num_gangs` `(` custom<NumGangs>($numGangs,
@@ -1691,7 +1691,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
     oilist(
-        `combined` `(` custom<CombinedConstructs>($combined) `)`
+        `combined` `(` custom<CombinedConstructsLoop>($combined) `)`
       | `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
             $gangOperandsSegments, $gang)
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 0dff9e3618b908..11ade13bc7a26a 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -815,11 +815,6 @@ LogicalResult acc::ParallelOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*this)))
     return failure();
 
-  if (getCombined() &&
-      *getCombined() != acc::CombinedConstructsType::ParallelLoop) {
-    return emitError("unexpected combined constructs attribute");
-  }
-
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
@@ -1289,8 +1284,8 @@ static void printDeviceTypeOperandsWithKeywordOnly(
 }
 
 static ParseResult
-parseCombinedConstructs(mlir::OpAsmParser &parser,
-                        mlir::acc::CombinedConstructsTypeAttr &attr) {
+parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
+                            mlir::acc::CombinedConstructsTypeAttr &attr) {
   // Just parsing first keyword we know which type of combined construct it is.
   if (succeeded(parser.parseOptionalKeyword("kernels"))) {
     attr = mlir::acc::CombinedConstructsTypeAttr::get(
@@ -1303,26 +1298,24 @@ parseCombinedConstructs(mlir::OpAsmParser &parser,
         parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
   } else {
     parser.emitError(parser.getCurrentLocation(),
-                     "expected compute construct name for combined constructs");
+                     "expected compute construct name");
     return failure();
   }
-
-  // Ensure that the `loop` wording follows the compute construct.
-  return parser.parseKeyword("loop");
+  return success();
 }
 
 static void
-printCombinedConstructs(mlir::OpAsmPrinter &p, mlir::Operation *op,
-                        mlir::acc::CombinedConstructsTypeAttr attr) {
+printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                            mlir::acc::CombinedConstructsTypeAttr attr) {
   switch (attr.getValue()) {
   case mlir::acc::CombinedConstructsType::KernelsLoop:
-    p << "kernels loop";
+    p << "kernels";
     break;
   case mlir::acc::CombinedConstructsType::ParallelLoop:
-    p << "parallel loop";
+    p << "parallel";
     break;
   case mlir::acc::CombinedConstructsType::SerialLoop:
-    p << "serial loop";
+    p << "serial";
     break;
   };
 }
@@ -1412,11 +1405,6 @@ LogicalResult acc::SerialOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*this)))
     return failure();
 
-  if (getCombined() &&
-      *getCombined() != acc::CombinedConstructsType::SerialLoop) {
-    return emitError("unexpected combined constructs attribute");
-  }
-
   return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
 }
 
@@ -1544,11 +1532,6 @@ LogicalResult acc::KernelsOp::verify() {
   if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*this)))
     return failure();
 
-  if (getCombined() &&
-      *getCombined() != acc::CombinedConstructsType::KernelsLoop) {
-    return emitError("unexpected combined constructs attribute");
-  }
-
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
 }
 
diff --git a/mlir/test/Dialect/OpenACC/invalid.mlir b/mlir/test/Dialect/OpenACC/invalid.mlir
index ca311634ce655e..ec5430420524ce 100644
--- a/mlir/test/Dialect/OpenACC/invalid.mlir
+++ b/mlir/test/Dialect/OpenACC/invalid.mlir
@@ -742,7 +742,7 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
 // -----
 
 func.func @acc_combined() {
-  // expected-error @below {{expected compute construct name for combined constructs}}
+  // expected-error @below {{expected 'loop'}}
   acc.parallel combined() {
   }
 
@@ -751,9 +751,19 @@ func.func @acc_combined() {
 
 // -----
 
+func.func @acc_combined() {
+  // expected-error @below {{expected compute construct name}}
+  acc.loop combined(loop) {
+  }
+
+  return
+}
+
+// -----
+
 func.func @acc_combined() {
   // expected-error @below {{expected 'loop'}}
-  acc.parallel combined(parallel) {
+  acc.parallel combined(parallel loop) {
   }
 
   return
@@ -762,8 +772,8 @@ func.func @acc_combined() {
 // -----
 
 func.func @acc_combined() {
-  // expected-error @below {{unexpected combined constructs attribute}}
-  acc.parallel combined(kernels loop) {
+  // expected-error @below {{expected ')'}}
+  acc.loop combined(parallel loop) {
   }
 
   return
diff --git a/mlir/test/Dialect/OpenACC/ops.mlir b/mlir/test/Dialect/OpenACC/ops.mlir
index 6d9c2eb1e55061..2ef2178cb2b63a 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -1862,22 +1862,22 @@ func.func @acc_num_gangs() {
 
 // CHECK-LABEL: func.func @acc_combined
 func.func @acc_combined() {
-  acc.parallel combined(parallel loop) {
-    acc.loop combined(parallel loop) {
+  acc.parallel combined(loop) {
+    acc.loop combined(parallel) {
       acc.yield
     }
     acc.terminator
   }
 
-  acc.kernels combined(kernels loop) {
-    acc.loop combined(kernels loop) {
+  acc.kernels combined(loop) {
+    acc.loop combined(kernels) {
       acc.yield
     }
     acc.terminator
   }
 
-  acc.serial combined(serial loop) {
-    acc.loop combined(serial loop) {
+  acc.serial combined(loop) {
+    acc.loop combined(serial) {
       acc.yield
     }
     acc.terminator
@@ -1886,9 +1886,9 @@ func.func @acc_combined() {
   return
 }
 
-// CHECK: acc.parallel combined(parallel loop)
-// CHECK: acc.loop combined(parallel loop)
-// CHECK: acc.kernels combined(kernels loop)
-// CHECK: acc.loop combined(kernels loop)
-// CHECK: acc.serial combined(serial loop)
-// CHECK: acc.loop combined(serial loop)
+// CHECK: acc.parallel combined(loop)
+// CHECK: acc.loop combined(parallel)
+// CHECK: acc.kernels combined(loop)
+// CHECK: acc.loop combined(kernels)
+// CHECK: acc.serial combined(loop)
+// CHECK: acc.loop combined(serial)

>From 3dc717a94b59f7b9f5eb2a52ebf74ecaff6d0165 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Wed, 6 Mar 2024 14:02:58 -0800
Subject: [PATCH 5/5] Move combined parsing/printing out of oilist for acc.loop

---
 .../mlir/Dialect/OpenACC/OpenACCOps.td        |  4 +-
 mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp       | 57 +++++++++++--------
 2 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 0df0be3a61046f..b5ad46361fa698 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1690,9 +1690,9 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
 
   let hasCustomAssemblyFormat = 1;
   let assemblyFormat = [{
+    custom<CombinedConstructsLoop>($combined)
     oilist(
-        `combined` `(` custom<CombinedConstructsLoop>($combined) `)`
-      | `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
+        `gang` `` custom<GangClause>($gangOperands, type($gangOperands),
             $gangOperandsArgType, $gangOperandsDeviceType,
             $gangOperandsSegments, $gang)
       | `worker` `` custom<DeviceTypeOperandsWithKeywordOnly>(
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 11ade13bc7a26a..c09a3403f9a3e3 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1286,20 +1286,25 @@ static void printDeviceTypeOperandsWithKeywordOnly(
 static ParseResult
 parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
                             mlir::acc::CombinedConstructsTypeAttr &attr) {
-  // Just parsing first keyword we know which type of combined construct it is.
-  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
-    attr = mlir::acc::CombinedConstructsTypeAttr::get(
-        parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
-  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
-    attr = mlir::acc::CombinedConstructsTypeAttr::get(
-        parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
-  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
-    attr = mlir::acc::CombinedConstructsTypeAttr::get(
-        parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
-  } else {
-    parser.emitError(parser.getCurrentLocation(),
-                     "expected compute construct name");
-    return failure();
+  if (succeeded(parser.parseOptionalKeyword("combined"))) {
+    if (parser.parseLParen())
+      return failure();
+    if (succeeded(parser.parseOptionalKeyword("kernels"))) {
+      attr = mlir::acc::CombinedConstructsTypeAttr::get(
+          parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
+    } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
+      attr = mlir::acc::CombinedConstructsTypeAttr::get(
+          parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
+    } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
+      attr = mlir::acc::CombinedConstructsTypeAttr::get(
+          parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
+    } else {
+      parser.emitError(parser.getCurrentLocation(),
+                       "expected compute construct name");
+      return failure();
+    }
+    if (parser.parseRParen())
+      return failure();
   }
   return success();
 }
@@ -1307,17 +1312,19 @@ parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
 static void
 printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
                             mlir::acc::CombinedConstructsTypeAttr attr) {
-  switch (attr.getValue()) {
-  case mlir::acc::CombinedConstructsType::KernelsLoop:
-    p << "kernels";
-    break;
-  case mlir::acc::CombinedConstructsType::ParallelLoop:
-    p << "parallel";
-    break;
-  case mlir::acc::CombinedConstructsType::SerialLoop:
-    p << "serial";
-    break;
-  };
+  if (attr) {
+    switch (attr.getValue()) {
+    case mlir::acc::CombinedConstructsType::KernelsLoop:
+      p << "combined(kernels)";
+      break;
+    case mlir::acc::CombinedConstructsType::ParallelLoop:
+      p << "combined(parallel)";
+      break;
+    case mlir::acc::CombinedConstructsType::SerialLoop:
+      p << "combined(serial)";
+      break;
+    };
+  }
 }
 
 //===----------------------------------------------------------------------===//



More information about the Mlir-commits mailing list