[Mlir-commits] [mlir] Adding masked operation to OpenMP Dialect (PR #96022)

Anchu Rajendran S llvmlistbot at llvm.org
Thu Jul 4 07:57:24 PDT 2024


https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/96022

>From 90a7d37fa5563c1e8b1a1dde9dc898ab83858379 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Sat, 15 Jun 2024 00:03:59 -0500
Subject: [PATCH 1/4] Adding masked operations to OpenMP Dialect

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 41 +++++++++++++++++++
 mlir/test/Dialect/OpenMP/ops.mlir             | 15 +++++++
 2 files changed, 56 insertions(+)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 99e14cd1b7b48..b446b5fcc8576 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1577,4 +1577,45 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
   let hasRegionVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// 2.19.5.4 reduction clause
+//===----------------------------------------------------------------------===//
+
+def ReductionOp : OpenMP_Op<"reduction"> {
+  let summary = "reduction construct";
+  let description = [{
+    Indicates the value that is produced by the current reduction-participating
+    entity for a reduction requested in some ancestor. The reduction is
+    identified by the accumulator, but the value of the accumulator may not be
+    updated immediately.
+  }];
+
+  let arguments= (ins AnyType:$operand, OpenMP_PointerLikeType:$accumulator);
+  let assemblyFormat = [{
+    $operand `,` $accumulator attr-dict `:` type($operand) `,` type($accumulator)
+  }];
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// [Spec 5.2] 10.5 masked Construct
+//===----------------------------------------------------------------------===//
+def MaskedOp : OpenMP_Op<"masked"> {
+  let summary = "masked construct";
+  let description = [{
+    Masked construct allows to specify a structured block to be executed by a subset of 
+    threads of the current team. Filter clause allows to select the threads expected to
+    execute the region
+  }];
+
+  let arguments = (ins Optional<I32>:$filteredThreadId);
+  let regions = (region AnyRegion:$region);
+
+  let assemblyFormat = [{
+    oilist(
+      `filter` `(` $filteredThreadId `:` type($filteredThreadId) `)`
+    ) $region attr-dict
+  }];
+}
+
 #endif // OPENMP_OPS
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index eb283840aa7ee..da981f73f7d6b 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -16,6 +16,21 @@ func.func @omp_master() -> () {
   return
 }
 
+// CHECK-LABEL: omp_masked
+func.func @omp_masked(%filtered_thread_id : i32) -> () {
+
+
+  // CHECK: omp.masked filter(%{{.*}} : i32)
+  "omp.masked" (%filtered_thread_id) ({
+    omp.terminator
+  }) : (i32) -> ()
+
+  // CHECK: omp.masked
+  "omp.masked" () ({
+    omp.terminator
+  }) : () -> ()
+  return
+}
 func.func @omp_taskwait() -> () {
   // CHECK: omp.taskwait
   omp.taskwait

>From 21263f27bbea2235b28610f5af43e4d85f350419 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 20 Jun 2024 00:30:20 -0500
Subject: [PATCH 2/4] R2: Adding some formatting changes

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 20 -------------------
 mlir/test/Dialect/OpenMP/ops.mlir             |  3 +--
 2 files changed, 1 insertion(+), 22 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index b446b5fcc8576..5432740f569d6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1577,26 +1577,6 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
   let hasRegionVerifier = 1;
 }
 
-//===----------------------------------------------------------------------===//
-// 2.19.5.4 reduction clause
-//===----------------------------------------------------------------------===//
-
-def ReductionOp : OpenMP_Op<"reduction"> {
-  let summary = "reduction construct";
-  let description = [{
-    Indicates the value that is produced by the current reduction-participating
-    entity for a reduction requested in some ancestor. The reduction is
-    identified by the accumulator, but the value of the accumulator may not be
-    updated immediately.
-  }];
-
-  let arguments= (ins AnyType:$operand, OpenMP_PointerLikeType:$accumulator);
-  let assemblyFormat = [{
-    $operand `,` $accumulator attr-dict `:` type($operand) `,` type($accumulator)
-  }];
-  let hasVerifier = 1;
-}
-
 //===----------------------------------------------------------------------===//
 // [Spec 5.2] 10.5 masked Construct
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index da981f73f7d6b..d6b655dd20ef8 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -18,8 +18,6 @@ func.func @omp_master() -> () {
 
 // CHECK-LABEL: omp_masked
 func.func @omp_masked(%filtered_thread_id : i32) -> () {
-
-
   // CHECK: omp.masked filter(%{{.*}} : i32)
   "omp.masked" (%filtered_thread_id) ({
     omp.terminator
@@ -31,6 +29,7 @@ func.func @omp_masked(%filtered_thread_id : i32) -> () {
   }) : () -> ()
   return
 }
+
 func.func @omp_taskwait() -> () {
   // CHECK: omp.taskwait
   omp.taskwait

>From 8271c8e375e6b7f7cabcadcd98daf4170decdb63 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 2 Jul 2024 13:44:24 -0500
Subject: [PATCH 3/4] R3: Updating the op definition according to new clause
 definitions

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 28 +++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 18 ++++++------
 mlir/test/Dialect/OpenMP/invalid.mlir         | 18 ++++++++++++
 3 files changed, 54 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1fa6edb28a288..99150bc5dff39 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1204,4 +1204,32 @@ class OpenMP_UseDevicePtrClauseSkip<
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [10.5.1] `filter` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_FilterClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause</*isRequired=*/false, traits, arguments, assemblyFormat,
+                    description, extraClassDeclaration> {
+  let arguments = (ins
+    Optional<IntLikeType>:$filtered_thread_id
+  );
+
+  let assemblyFormat = [{
+    `filter` `(` $filtered_thread_id `:` type($filtered_thread_id) `)`
+  }];
+
+  let description = [{
+    If `filter` is specified, the masked construct restricts the execution of
+    the region to only the thread with the filtered id. Other threads executing
+    the parallel region are not expected to execute the region specified within
+    the `masked` directive. If `filter` is not specified, the master thread is
+    expected to execute the region enclosed within the `masked` directive.
+  }];
+}
+
+def OpenMP_FilterClause : OpenMP_FilterClauseSkip<>;
+
 #endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5432740f569d6..b47dee23564d1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1580,22 +1580,20 @@ def DeclareReductionOp : OpenMP_Op<"declare_reduction", [IsolatedFromAbove,
 //===----------------------------------------------------------------------===//
 // [Spec 5.2] 10.5 masked Construct
 //===----------------------------------------------------------------------===//
-def MaskedOp : OpenMP_Op<"masked"> {
+def MaskedOp : OpenMP_Op<"masked", clauses = [
+    OpenMP_FilterClause
+  ], singleRegion = 1> {
   let summary = "masked construct";
   let description = [{
     Masked construct allows to specify a structured block to be executed by a subset of 
-    threads of the current team. Filter clause allows to select the threads expected to
-    execute the region
-  }];
+    threads of the current team.
+  }] # clausesDescription;
 
-  let arguments = (ins Optional<I32>:$filteredThreadId);
   let regions = (region AnyRegion:$region);
 
-  let assemblyFormat = [{
-    oilist(
-      `filter` `(` $filteredThreadId `:` type($filteredThreadId) `)`
-    ) $region attr-dict
-  }];
+  let builders = [
+    OpBuilder<(ins CArg<"const MaskedClauseOps &">:$clauses)>
+  ];
 }
 
 #endif // OPENMP_OPS
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2915963f704d3..6a04b9ead746c 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2358,3 +2358,21 @@ func.func @byref_in_private(%arg0: index) {
 
   return
 }
+
+// -----
+func.func @masked_arg_type_mismatch(%arg0: f32) {
+  // expected-error @below {{'omp.masked' op operand #0 must be integer or index, but got 'f32'}}
+  "omp.masked"(%arg0) ({
+      omp.terminator
+    }) : (f32) -> ()
+  return
+}
+
+// -----
+func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {
+  // expected-error @below {{'omp.masked' op operand group starting at #0 requires 0 or 1 element, but found 2}}
+  "omp.masked"(%arg0, %arg1) ({
+      omp.terminator
+    }) : (i32, i32) -> ()
+  return
+}

>From 5df2f927104eb8f2c5513b16b3e8b350f0b070c1 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 4 Jul 2024 09:55:48 -0500
Subject: [PATCH 4/4] R4: Adding new build operation for Masked, removing
 redundant code

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h | 7 +++++--
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td           | 2 --
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp            | 9 +++++++++
 3 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index e4955fec80b4f..0eefe06055b7d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -81,6 +81,10 @@ struct DoacrossClauseOps {
   IntegerAttr doacrossNumLoopsAttr;
 };
 
+struct FilterClauseOps {
+  Value filteredThreadIdVar;
+};
+
 struct FinalClauseOps {
   Value finalVar;
 };
@@ -254,8 +258,7 @@ using DistributeClauseOps =
 
 using LoopNestClauseOps = detail::Clauses<CollapseClauseOps, LoopRelatedOps>;
 
-// TODO `filter` clause.
-using MaskedClauseOps = detail::Clauses<>;
+using MaskedClauseOps = detail::Clauses<FilterClauseOps>;
 
 using OrderedOpClauseOps = detail::Clauses<DoacrossClauseOps>;
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index b47dee23564d1..1a1ca5e71b3e2 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1589,8 +1589,6 @@ def MaskedOp : OpenMP_Op<"masked", clauses = [
     threads of the current team.
   }] # clausesDescription;
 
-  let regions = (region AnyRegion:$region);
-
   let builders = [
     OpBuilder<(ins CArg<"const MaskedClauseOps &">:$clauses)>
   ];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index abbd857dad67a..05dde88f3a48b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2578,6 +2578,15 @@ LogicalResult PrivateClauseOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Spec 5.2: Masked construct (10.5)
+//===----------------------------------------------------------------------===//
+
+void MaskedOp::build(OpBuilder &builder, OperationState &state,
+                              const MaskedClauseOps &clauses) {
+  MaskedOp::build(builder, state, clauses.filteredThreadIdVar);
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 



More information about the Mlir-commits mailing list