[Mlir-commits] [mlir] [MLIR][omp] Add omp.workshare op (PR #101443)

Ivan R. Ivanov llvmlistbot at llvm.org
Sat Oct 19 10:23:00 PDT 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/101443

>From b0d76b0f7428f423b07c9ea2d191bd5444e158cd Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:09:09 +0900
Subject: [PATCH 1/4] [MLIR][omp] Add omp.workshare op

Add custom omp loop wrapper

Add recursive memory effects trait to workshare

Remove stray include

Remove omp.workshare verifier

Add assembly format for wrapper and add test

Add verification and descriptions
---
 .../Dialect/OpenMP/OpenMPClauseOperands.h     | 73 ++++++++++++++++++-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 43 +++++++++++
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 23 ++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         | 42 +++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 69 ++++++++++++++++++
 5 files changed, 246 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 1247a871f93c6d..2223636c0d7a0f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -44,10 +44,75 @@ struct DeviceTypeClauseOps {
 // TODO: Add `indirect` clause.
 using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
 
-/// omp.target_enter_data, omp.target_exit_data and omp.target_update take the
-/// same clauses, so we give the structure to be shared by all of them a
-/// representative name.
-using TargetEnterExitUpdateDataOperands = TargetEnterDataOperands;
+using DistributeOperands =
+    detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
+                    PrivateClauseOps>;
+
+using LoopNestOperands = detail::Clauses<LoopRelatedOps>;
+
+using MaskedOperands = detail::Clauses<FilterClauseOps>;
+
+using OrderedOperands = detail::Clauses<DoacrossClauseOps>;
+
+using OrderedRegionOperands = detail::Clauses<ParallelizationLevelClauseOps>;
+
+using ParallelOperands =
+    detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
+                    PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
+
+using SectionsOperands = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
+                                         PrivateClauseOps, ReductionClauseOps>;
+
+using SimdOperands =
+    detail::Clauses<AlignedClauseOps, IfClauseOps, LinearClauseOps,
+                    NontemporalClauseOps, OrderClauseOps, PrivateClauseOps,
+                    ReductionClauseOps, SafelenClauseOps, SimdlenClauseOps>;
+
+using SingleOperands = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
+                                       NowaitClauseOps, PrivateClauseOps>;
+
+// TODO `defaultmap`, `uses_allocators` clauses.
+using TargetOperands =
+    detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
+                    HasDeviceAddrClauseOps, IfClauseOps, InReductionClauseOps,
+                    IsDevicePtrClauseOps, MapClauseOps, NowaitClauseOps,
+                    PrivateClauseOps, ThreadLimitClauseOps>;
+
+using TargetDataOperands =
+    detail::Clauses<DeviceClauseOps, IfClauseOps, MapClauseOps,
+                    UseDeviceAddrClauseOps, UseDevicePtrClauseOps>;
+
+using TargetEnterExitUpdateDataOperands =
+    detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
+                    NowaitClauseOps>;
+
+// TODO `affinity`, `detach` clauses.
+using TaskOperands =
+    detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
+                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
+                    PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
+
+using TaskgroupOperands =
+    detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
+
+using TaskloopOperands =
+    detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
+                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
+                    NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
+                    PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
+
+using TaskwaitOperands = detail::Clauses<DependClauseOps, NowaitClauseOps>;
+
+using TeamsOperands =
+    detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
+                    PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
+
+using WorkshareOperands = detail::Clauses<NowaitClauseOps>;
+
+using WsloopOperands =
+    detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
+                    OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
+                    ReductionClauseOps, ScheduleClauseOps>;
 
 } // namespace omp
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 45313200d4f0b9..71908b6112e3e6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -313,6 +313,49 @@ def SingleOp : OpenMP_Op<"single", traits = [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// 2.8.3 Workshare Construct
+//===----------------------------------------------------------------------===//
+
+def WorkshareOp : OpenMP_Op<"workshare", traits = [
+    RecursiveMemoryEffects,
+  ], clauses = [
+    OpenMP_NowaitClause,
+  ], singleRegion = true> {
+  let summary = "workshare directive";
+  let description = [{
+    The workshare construct divides the execution of the enclosed structured
+    block into separate units of work, and causes the threads of the team to
+    share the work such that each unit is executed only once by one thread, in
+    the context of its implicit task.
+
+    This operation is used for the intermediate representation of the workshare
+    block before the work gets divided between the threads. See the flang
+    LowerWorkshare pass for details.
+  }] # clausesDescription;
+
+  let builders = [
+    OpBuilder<(ins CArg<"const WorkshareOperands &">:$clauses)>
+  ];
+}
+
+def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
+    DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    RecursiveMemoryEffects, SingleBlock
+  ], singleRegion = true> {
+  let summary = "contains loop nests to be parallelized by workshare";
+  let description = [{
+    This operation wraps a loop nest that is marked for dividing into units of
+    work by an encompassing omp.workshare operation.
+  }];
+
+  let builders = [
+    OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
+  ];
+  let assemblyFormat = "$region attr-dict";
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Loop Nest
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e1df647d6a3c71..cbb8fe36eb24e2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1921,6 +1921,29 @@ LogicalResult SingleOp::verify() {
                                   getCopyprivateSyms());
 }
 
+//===----------------------------------------------------------------------===//
+// WorkshareOp
+//===----------------------------------------------------------------------===//
+
+void WorkshareOp::build(OpBuilder &builder, OperationState &state,
+                        const WorkshareOperands &clauses) {
+  WorkshareOp::build(builder, state, clauses.nowait);
+}
+
+//===----------------------------------------------------------------------===//
+// WorkshareLoopWrapperOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkshareLoopWrapperOp::verify() {
+  if (!isWrapper())
+    return emitOpError() << "must be a loop wrapper";
+  if (getNestedWrapper())
+    return emitError() << "nested wrappers not supported";
+  if (!(*this)->getParentOfType<WorkshareOp>())
+    return emitError() << "must be nested in an omp.workshare";
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LoopWrapperInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index fd89ec31c64a60..e8482574408c0c 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2577,3 +2577,45 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)
   } {omp.composite}
   return
 }
+
+// -----
+func.func @nested_wrapper(%idx : index) {
+  omp.workshare {
+    // expected-error @below {{nested wrappers not supported}}
+    omp.workshare.loop_wrapper {
+      omp.simd {
+        omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+          omp.yield
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+func.func @not_wrapper() {
+  omp.workshare {
+    // expected-error @below {{must be a loop wrapper}}
+    omp.workshare.loop_wrapper {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+func.func @missing_workshare(%idx : index) {
+  // expected-error @below {{must be nested in an omp.workshare}}
+  omp.workshare.loop_wrapper {
+    omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+      omp.yield
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6f11b451fa00a3..d839db9d7583e3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2749,3 +2749,72 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
 
   return
 }
+
+// CHECK-LABEL: func @omp_workshare
+func.func @omp_workshare() {
+  // CHECK: omp.workshare {
+  omp.workshare {
+    "test.payload"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_nowait
+func.func @omp_workshare_nowait() {
+  // CHECK: omp.workshare nowait {
+  omp.workshare nowait {
+    "test.payload"() : () -> ()
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare_multiple_blocks
+func.func @omp_workshare_multiple_blocks() {
+  // CHECK: omp.workshare {
+  omp.workshare {
+    cf.br ^bb2
+    ^bb2:
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare.loop_wrapper
+func.func @omp_workshare.loop_wrapper(%idx : index) {
+  // CHECK-NEXT: omp.workshare {
+  omp.workshare {
+    // CHECK-NEXT: omp.workshare.loop_wrapper
+    omp.workshare.loop_wrapper {
+      // CHECK-NEXT: omp.loop_nest
+      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+        omp.yield
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func @omp_workshare.loop_wrapper_attrs
+func.func @omp_workshare.loop_wrapper_attrs(%idx : index) {
+  // CHECK-NEXT: omp.workshare {
+  omp.workshare {
+    // CHECK-NEXT: omp.workshare.loop_wrapper {
+    omp.workshare.loop_wrapper {
+      // CHECK-NEXT: omp.loop_nest
+      omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+        omp.yield
+      }
+      omp.terminator
+    // CHECK: } {attr_in_dict}
+    } {attr_in_dict}
+    omp.terminator
+  }
+  return
+}

>From 88d6d03f1ca0e9769eb94e0c620e7357dec2e3e1 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Thu, 22 Aug 2024 18:05:31 +0900
Subject: [PATCH 2/4] wrong replace

---
 mlir/test/Dialect/OpenMP/ops.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d839db9d7583e3..c4d4d5055706e6 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2784,8 +2784,8 @@ func.func @omp_workshare_multiple_blocks() {
   return
 }
 
-// CHECK-LABEL: func @omp_workshare.loop_wrapper
-func.func @omp_workshare.loop_wrapper(%idx : index) {
+// CHECK-LABEL: func @omp_workshare_loop_wrapper
+func.func @omp_workshare_loop_wrapper(%idx : index) {
   // CHECK-NEXT: omp.workshare {
   omp.workshare {
     // CHECK-NEXT: omp.workshare.loop_wrapper
@@ -2801,8 +2801,8 @@ func.func @omp_workshare.loop_wrapper(%idx : index) {
   return
 }
 
-// CHECK-LABEL: func @omp_workshare.loop_wrapper_attrs
-func.func @omp_workshare.loop_wrapper_attrs(%idx : index) {
+// CHECK-LABEL: func @omp_workshare_loop_wrapper_attrs
+func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
   // CHECK-NEXT: omp.workshare {
   omp.workshare {
     // CHECK-NEXT: omp.workshare.loop_wrapper {

>From f8d7b474aeea8123f4cfdae805b14bdfa569db85 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 20 Oct 2024 01:33:13 +0900
Subject: [PATCH 3/4] Fix wsloopwrapperop

---
 .../Dialect/OpenMP/OpenMPClauseOperands.h     | 73 +------------------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  4 -
 mlir/test/Dialect/OpenMP/ops.mlir             |  2 -
 4 files changed, 5 insertions(+), 76 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 2223636c0d7a0f..1247a871f93c6d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -44,75 +44,10 @@ struct DeviceTypeClauseOps {
 // TODO: Add `indirect` clause.
 using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
 
-using DistributeOperands =
-    detail::Clauses<AllocateClauseOps, DistScheduleClauseOps, OrderClauseOps,
-                    PrivateClauseOps>;
-
-using LoopNestOperands = detail::Clauses<LoopRelatedOps>;
-
-using MaskedOperands = detail::Clauses<FilterClauseOps>;
-
-using OrderedOperands = detail::Clauses<DoacrossClauseOps>;
-
-using OrderedRegionOperands = detail::Clauses<ParallelizationLevelClauseOps>;
-
-using ParallelOperands =
-    detail::Clauses<AllocateClauseOps, IfClauseOps, NumThreadsClauseOps,
-                    PrivateClauseOps, ProcBindClauseOps, ReductionClauseOps>;
-
-using SectionsOperands = detail::Clauses<AllocateClauseOps, NowaitClauseOps,
-                                         PrivateClauseOps, ReductionClauseOps>;
-
-using SimdOperands =
-    detail::Clauses<AlignedClauseOps, IfClauseOps, LinearClauseOps,
-                    NontemporalClauseOps, OrderClauseOps, PrivateClauseOps,
-                    ReductionClauseOps, SafelenClauseOps, SimdlenClauseOps>;
-
-using SingleOperands = detail::Clauses<AllocateClauseOps, CopyprivateClauseOps,
-                                       NowaitClauseOps, PrivateClauseOps>;
-
-// TODO `defaultmap`, `uses_allocators` clauses.
-using TargetOperands =
-    detail::Clauses<AllocateClauseOps, DependClauseOps, DeviceClauseOps,
-                    HasDeviceAddrClauseOps, IfClauseOps, InReductionClauseOps,
-                    IsDevicePtrClauseOps, MapClauseOps, NowaitClauseOps,
-                    PrivateClauseOps, ThreadLimitClauseOps>;
-
-using TargetDataOperands =
-    detail::Clauses<DeviceClauseOps, IfClauseOps, MapClauseOps,
-                    UseDeviceAddrClauseOps, UseDevicePtrClauseOps>;
-
-using TargetEnterExitUpdateDataOperands =
-    detail::Clauses<DependClauseOps, DeviceClauseOps, IfClauseOps, MapClauseOps,
-                    NowaitClauseOps>;
-
-// TODO `affinity`, `detach` clauses.
-using TaskOperands =
-    detail::Clauses<AllocateClauseOps, DependClauseOps, FinalClauseOps,
-                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
-                    PriorityClauseOps, PrivateClauseOps, UntiedClauseOps>;
-
-using TaskgroupOperands =
-    detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
-
-using TaskloopOperands =
-    detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
-                    IfClauseOps, InReductionClauseOps, MergeableClauseOps,
-                    NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
-                    PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
-
-using TaskwaitOperands = detail::Clauses<DependClauseOps, NowaitClauseOps>;
-
-using TeamsOperands =
-    detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
-                    PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
-
-using WorkshareOperands = detail::Clauses<NowaitClauseOps>;
-
-using WsloopOperands =
-    detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
-                    OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
-                    ReductionClauseOps, ScheduleClauseOps>;
+/// omp.target_enter_data, omp.target_exit_data and omp.target_update take the
+/// same clauses, so we give the structure to be shared by all of them a
+/// representative name.
+using TargetEnterExitUpdateDataOperands = TargetEnterDataOperands;
 
 } // namespace omp
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 71908b6112e3e6..c8e2c9ed4fd614 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -340,7 +340,7 @@ def WorkshareOp : OpenMP_Op<"workshare", traits = [
 }
 
 def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
-    DeclareOpInterfaceMethods<LoopWrapperInterface>,
+    DeclareOpInterfaceMethods<LoopWrapperInterface>, NoTerminator,
     RecursiveMemoryEffects, SingleBlock
   ], singleRegion = true> {
   let summary = "contains loop nests to be parallelized by workshare";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index cbb8fe36eb24e2..e849f68e9c8380 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1935,10 +1935,6 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,
 //===----------------------------------------------------------------------===//
 
 LogicalResult WorkshareLoopWrapperOp::verify() {
-  if (!isWrapper())
-    return emitOpError() << "must be a loop wrapper";
-  if (getNestedWrapper())
-    return emitError() << "nested wrappers not supported";
   if (!(*this)->getParentOfType<WorkshareOp>())
     return emitError() << "must be nested in an omp.workshare";
   return success();
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c4d4d5055706e6..0532b71307a616 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2794,7 +2794,6 @@ func.func @omp_workshare_loop_wrapper(%idx : index) {
       omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
         omp.yield
       }
-      omp.terminator
     }
     omp.terminator
   }
@@ -2811,7 +2810,6 @@ func.func @omp_workshare_loop_wrapper_attrs(%idx : index) {
       omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
         omp.yield
       }
-      omp.terminator
     // CHECK: } {attr_in_dict}
     } {attr_in_dict}
     omp.terminator

>From d001eec514d0e7104a2279165e76a0c1694174a1 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 20 Oct 2024 02:22:08 +0900
Subject: [PATCH 4/4] Fix op tests

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp |  2 ++
 mlir/test/Dialect/OpenMP/invalid.mlir        | 11 ++++-------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e849f68e9c8380..43c1ec66e1ae31 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1937,6 +1937,8 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,
 LogicalResult WorkshareLoopWrapperOp::verify() {
   if (!(*this)->getParentOfType<WorkshareOp>())
     return emitError() << "must be nested in an omp.workshare";
+  if (getNestedWrapper())
+    return emitError() << "cannot be composite";
   return success();
 }
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e8482574408c0c..d56629e76b09c3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2581,15 +2581,13 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)
 // -----
 func.func @nested_wrapper(%idx : index) {
   omp.workshare {
-    // expected-error @below {{nested wrappers not supported}}
+    // expected-error @below {{cannot be composite}}
     omp.workshare.loop_wrapper {
       omp.simd {
         omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
           omp.yield
         }
-        omp.terminator
-      }
-      omp.terminator
+      } {omp.composite}
     }
     omp.terminator
   }
@@ -2599,9 +2597,9 @@ func.func @nested_wrapper(%idx : index) {
 // -----
 func.func @not_wrapper() {
   omp.workshare {
-    // expected-error @below {{must be a loop wrapper}}
+    // expected-error @below {{op nested in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
     omp.workshare.loop_wrapper {
-      omp.terminator
+      %0 = arith.constant 0 : index
     }
     omp.terminator
   }
@@ -2615,7 +2613,6 @@ func.func @missing_workshare(%idx : index) {
     omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
       omp.yield
     }
-    omp.terminator
   }
   return
 }



More information about the Mlir-commits mailing list