[Mlir-commits] [mlir] 7a9ef0f - Adding masked operation to OpenMP Dialect (#96022)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 4 16:06:05 PDT 2024


Author: Anchu Rajendran S
Date: 2024-07-04T16:06:01-07:00
New Revision: 7a9ef0f2688805d0e7ea22f91eb3608e8cab6fd4

URL: https://github.com/llvm/llvm-project/commit/7a9ef0f2688805d0e7ea22f91eb3608e8cab6fd4
DIFF: https://github.com/llvm/llvm-project/commit/7a9ef0f2688805d0e7ea22f91eb3608e8cab6fd4.diff

LOG: Adding masked operation to OpenMP Dialect (#96022)

Add MLIR op support for `omp masked`. The masked construct, introduced in
the OpenMP 5.2 standard, allows a region to be executed only by the threads
specified by the programmer. This is achieved with the `filter` clause,
which specifies the thread id that is expected to execute the region.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index e4955fec80b4f..0eefe06055b7d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -81,6 +81,10 @@ struct DoacrossClauseOps {
   IntegerAttr doacrossNumLoopsAttr;
 };
 
+struct FilterClauseOps {
+  Value filteredThreadIdVar;
+};
+
 struct FinalClauseOps {
   Value finalVar;
 };
@@ -254,8 +258,7 @@ using DistributeClauseOps =
 
 using LoopNestClauseOps = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
 
-// TODO `filter` clause.
-using MaskedClauseOps = detail::Clauses<>;
+using MaskedClauseOps = detail::Clauses<FilterClauseOps>;
 
 using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>;
 

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1fa6edb28a288..99150bc5dff39 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1204,4 +1204,32 @@ class OpenMP_UseDevicePtrClauseSkip<
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [10.5.1] `filter` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_FilterClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
+                    description, extraClassDeclaration> {
+  let arguments = (ins
+    Optional<IntLikeType>:$filtered_thread_id
+  );
+
+  let assemblyFormat = [{
+    `filter` `(` $filtered_thread_id `:` type($filtered_thread_id) `)`
+  }];
+
+  let description = [{
+    If `filter` is specified, the masked construct masks the execution of
+    the region to only the thread id filtered. Other threads executing the
+    parallel region are not expected to execute the region specified within
+    the `masked` directive. If `filter` is not specified, master thread is
+    expected to execute the region enclosed within `masked` directive.
+  }];
+}
+
+def OpenMP_FilterClause : OpenMP_FilterClauseSkip<>;
+
 #endif // OPENMP_CLAUSES

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..1a1ca5e71b3e2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1577,4 +1577,21 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// [Spec 5.2] 10.5 masked Construct
+//===----------------------------------------------------------------------===//
+def MaskedOp : OpenMP_Op<"masked", clauses = [
+    OpenMP_FilterClause
+  ], singleRegion = 1> {
+  let summary = "masked construct";
+  let description = [{
+    Masked construct allows to specify a structured block to be executed by a subset of 
+    threads of the current team.
+  }] # clausesDescription;
+
+  let builders = [
+    OpBuilder<(ins CArg<"const MaskedClauseOps &">:$clauses)>
+  ];
+}
+
 #endif // OPENMP_OPS

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index abbd857dad67a..23f291bfc2232 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2578,6 +2578,15 @@ LogicalResult PrivateClauseOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Spec 5.2: Masked construct (10.5)
+//===----------------------------------------------------------------------===//
+
+void MaskedOp::build(OpBuilder &builder, OperationState &state,
+                     const MaskedClauseOps &clauses) {
+  MaskedOp::build(builder, state, clauses.filteredThreadIdVar);
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..6a04b9ead746c 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2358,3 +2358,21 @@ func.func @byref_in_private(%arg0: index) {
 
   return
 }
+
+// -----
+func.func @masked_arg_type_mismatch(%arg0: f32) {
+  // expected-error @below {{'omp.masked' op operand #0 must be integer or index, but got 'f32'}}
+  "omp.masked"(%arg0) ({
+      omp.terminator
+    }) : (f32) -> ()
+  return
+}
+
+// -----
+func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
+  // expected-error @below {{'omp.masked' op operand group starting at #0 requires 0 or 1 element, but found 2}}
+  "omp.masked"(%arg0, %arg1) ({
+      omp.terminator
+    }) : (i32, i32) -> ()
+  return
+}

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..d6b655dd20ef8 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -16,6 +16,20 @@ func.func @omp_master() -> () {
   return
 }
 
+// CHECK-LABEL: omp_masked
+func.func @omp_masked(%filtered_thread_id : i32) -> () {
+  // CHECK: omp.masked filter(%{{.*}} : i32)
+  "omp.masked" (%filtered_thread_id) ({
+    omp.terminator
+  }) : (i32) -> ()
+
+  // CHECK: omp.masked
+  "omp.masked" () ({
+    omp.terminator
+  }) : () -> ()
+  return
+}
+
 func.func @omp_taskwait() -> () {
   // CHECK: omp.taskwait
   omp.taskwait


        


More information about the Mlir-commits mailing list