[llvm] 30bd11f - [MLIR][OpenMP] Fixed the missing inclusive clause in omp.wsloop and fix order clause

Shraiysh Vaishay via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 28 01:48:17 PDT 2021


Author: Shraiysh Vaishay
Date: 2021-10-28T14:18:05+05:30
New Revision: 30bd11fab47f75e43ba9d0133978d964eef819ca

URL: https://github.com/llvm/llvm-project/commit/30bd11fab47f75e43ba9d0133978d964eef819ca
DIFF: https://github.com/llvm/llvm-project/commit/30bd11fab47f75e43ba9d0133978d964eef819ca.diff

LOG: [MLIR][OpenMP] Fixed the missing inclusive clause in omp.wsloop and fix order clause

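This patch adds the inclusive clause (which was missed in the previous
reorganization - https://reviews.llvm.org/D110903) to the omp.wsloop operation.
`inclusive` is now parsed as part of the loop bounds, between the upper bounds
and `step`, rather than as a trailing clause. Tests validating it are added.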

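It also fixes the order clause, which previously did not accept any value. It now
accepts "concurrent", as specified in the OpenMP standard, and stores it in the
`order_val` attribute. Finally, it adds a missing space to the "is not a valid
clause" diagnostic.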

Reviewed By: kiranchandramohan, peixin, clementval

Differential Revision: https://reviews.llvm.org/D112198

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMP.td
    llvm/unittests/Frontend/OpenMPParsingTest.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/test/Dialect/OpenMP/invalid.mlir
    mlir/test/Dialect/OpenMP/ops.mlir

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index fffd8d75f6a71..5fd30412721be 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -285,11 +285,13 @@ def OMPC_NonTemporal : Clause<"nontemporal"> {
   let isValueList = true; 
 }
 
-def OMP_ORDER_concurrent : ClauseVal<"default",2,0> { let isDefault = 1; }
+def OMP_ORDER_concurrent : ClauseVal<"concurrent",1,1> {}
+def OMP_ORDER_unknown : ClauseVal<"unknown",2,0> { let isDefault = 1; }
 def OMPC_Order : Clause<"order"> {
   let clangClass = "OMPOrderClause";
   let enumClauseValue = "OrderKind";
   let allowedClauseValues = [
+    OMP_ORDER_unknown,
     OMP_ORDER_concurrent
   ];
 }

diff  --git a/llvm/unittests/Frontend/OpenMPParsingTest.cpp b/llvm/unittests/Frontend/OpenMPParsingTest.cpp
index ea06b34658c53..227e08c511d13 100644
--- a/llvm/unittests/Frontend/OpenMPParsingTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPParsingTest.cpp
@@ -55,8 +55,9 @@ TEST(OpenMPParsingTest, isAllowedClauseForDirective) {
 }
 
 TEST(OpenMPParsingTest, getOrderKind) {
-  EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_concurrent);
-  EXPECT_EQ(getOrderKind("default"), OMP_ORDER_concurrent);
+  EXPECT_EQ(getOrderKind("foobar"), OMP_ORDER_unknown);
+  EXPECT_EQ(getOrderKind("unknown"), OMP_ORDER_unknown);
+  EXPECT_EQ(getOrderKind("concurrent"), OMP_ORDER_concurrent);
 }
 
 TEST(OpenMPParsingTest, getProcBindKind) {

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 14e899489f6fc..e85a4b722aced 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -492,7 +492,6 @@ enum ClauseType {
   collapseClause,
   orderClause,
   orderedClause,
-  inclusiveClause,
   memoryOrderClause,
   hintClause,
   COUNT
@@ -577,8 +576,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
     // segments
     if (clause == defaultClause || clause == procBindClause ||
         clause == nowaitClause || clause == collapseClause ||
-        clause == orderClause || clause == orderedClause ||
-        clause == inclusiveClause)
+        clause == orderClause || clause == orderedClause)
       continue;
 
     pos[clause] = currPos++;
@@ -596,7 +594,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
                           bool allowRepeat = false) -> ParseResult {
     if (!llvm::is_contained(clauses, clause))
       return parser.emitError(parser.getCurrentLocation())
-             << clauseKeyword << "is not a valid clause for the " << opName
+             << clauseKeyword << " is not a valid clause for the " << opName
              << " operation";
     if (done[clause] && !allowRepeat)
       return parser.emitError(parser.getCurrentLocation())
@@ -717,12 +715,7 @@ static ParseResult parseClauses(OpAsmParser &parser, OperationState &result,
           parser.parseKeyword(&order) || parser.parseRParen())
         return failure();
       auto attr = parser.getBuilder().getStringAttr(order);
-      result.addAttribute("order", attr);
-    } else if (clauseKeyword == "inclusive") {
-      if (checkAllowed(inclusiveClause))
-        return failure();
-      auto attr = UnitAttr::get(parser.getBuilder().getContext());
-      result.addAttribute("inclusive", attr);
+      result.addAttribute("order_val", attr);
     } else if (clauseKeyword == "memory_order") {
       StringRef memoryOrder;
       if (checkAllowed(memoryOrderClause) || parser.parseLParen() ||
@@ -875,11 +868,11 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
 ///
 /// wsloop ::= `omp.wsloop` loop-control clause-list
 /// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
-/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
+/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
 /// steps := `step` `(`ssa-id-list`)`
 /// clause-list ::= clause clause-list | empty
 /// clause ::= private | firstprivate | lastprivate | linear | schedule |
-//             collapse | nowait | ordered | order | inclusive | reduction
+//             collapse | nowait | ordered | order | reduction
 static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
 
   // Parse an opening `(` followed by induction variables followed by `)`
@@ -906,6 +899,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
       parser.resolveOperands(upper, loopVarType, result.operands))
     return failure();
 
+  if (succeeded(parser.parseOptionalKeyword("inclusive"))) {
+    auto attr = UnitAttr::get(parser.getBuilder().getContext());
+    result.addAttribute("inclusive", attr);
+  }
+
   // Parse step values.
   SmallVector<OpAsmParser::OperandType> steps;
   if (parser.parseKeyword("step") ||
@@ -936,7 +934,11 @@ static ParseResult parseWsLoopOp(OpAsmParser &parser, OperationState &result) {
 static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
   auto args = op.getRegion().front().getArguments();
   p << " (" << args << ") : " << args[0].getType() << " = (" << op.lowerBound()
-    << ") to (" << op.upperBound() << ") step (" << op.step() << ") ";
+    << ") to (" << op.upperBound() << ") ";
+  if (op.inclusive()) {
+    p << "inclusive ";
+  }
+  p << "step (" << op.step() << ") ";
 
   printDataVars(p, op.private_vars(), "private");
   printDataVars(p, op.firstprivate_vars(), "firstprivate");
@@ -962,15 +964,14 @@ static void printWsLoopOp(OpAsmPrinter &p, WsLoopOp op) {
   if (auto ordered = op.ordered_val())
     p << "ordered(" << ordered << ") ";
 
+  if (auto order = op.order_val())
+    p << "order(" << order << ") ";
+
   if (!op.reduction_vars().empty()) {
     p << "reduction(";
     printReductionVarList(p, op.reductions(), op.reduction_vars());
   }
 
-  if (op.inclusive()) {
-    p << "inclusive ";
-  }
-
   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
 }
 

diff  --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e57ddfcbfec86..36eee320af555 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -69,7 +69,62 @@ func @copyin_once(%n : memref<i32>) {
 }
 
 // -----
- 
+
+func @lastprivate_not_allowed(%n : memref<i32>) {
+  // expected-error at +1 {{lastprivate is not a valid clause for the omp.parallel operation}}
+  omp.parallel lastprivate(%n : memref<i32>) {}
+  return
+}
+
+// -----
+
+func @nowait_not_allowed(%n : memref<i32>) {
+  // expected-error at +1 {{nowait is not a valid clause for the omp.parallel operation}}
+  omp.parallel nowait {}
+  return
+}
+
+// -----
+
+func @linear_not_allowed(%data_var : memref<i32>, %linear_var : i32) {
+  // expected-error at +1 {{linear is not a valid clause for the omp.parallel operation}}
+  omp.parallel linear(%data_var = %linear_var : memref<i32>)  {}
+  return
+}
+
+// -----
+
+func @schedule_not_allowed() {
+  // expected-error at +1 {{schedule is not a valid clause for the omp.parallel operation}}
+  omp.parallel schedule(static) {}
+  return
+}
+
+// -----
+
+func @collapse_not_allowed() {
+  // expected-error at +1 {{collapse is not a valid clause for the omp.parallel operation}}
+  omp.parallel collapse(3) {}
+  return
+}
+
+// -----
+
+func @order_not_allowed() {
+  // expected-error at +1 {{order is not a valid clause for the omp.parallel operation}}
+  omp.parallel order(concurrent) {}
+  return
+}
+
+// -----
+
+func @ordered_not_allowed() {
+  // expected-error at +1 {{ordered is not a valid clause for the omp.parallel operation}}
+  omp.parallel ordered(2) {}
+}
+
+// -----
+
 func @default_once() {
   // expected-error at +1 {{at most one default clause can appear on the omp.parallel operation}}
   omp.parallel default(private) default(firstprivate) {
@@ -90,6 +145,78 @@ func @proc_bind_once() {
 
 // -----
 
+func @inclusive_not_a_clause(%lb : index, %ub : index, %step : index) {
+  // expected-error @below {{inclusive is not a valid clause}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait inclusive {
+    omp.yield
+  }
+}
+
+// -----
+
+func @order_value(%lb : index, %ub : index, %step : index) {
+  // expected-error @below {{attribute 'order_val' failed to satisfy constraint: OrderKind Clause}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(default) {
+    omp.yield
+  }
+}
+
+// -----
+
+func @shared_not_allowed(%lb : index, %ub : index, %step : index, %var : memref<i32>) {
+  // expected-error @below {{shared is not a valid clause for the omp.wsloop operation}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) shared(%var) {
+    omp.yield
+  }
+}
+
+// -----
+
+func @copyin(%lb : index, %ub : index, %step : index, %var : memref<i32>) {
+  // expected-error @below {{copyin is not a valid clause for the omp.wsloop operation}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) copyin(%var) {
+    omp.yield
+  }
+}
+
+// -----
+
+func @if_not_allowed(%lb : index, %ub : index, %step : index, %bool_var : i1) {
+  // expected-error @below {{if is not a valid clause for the omp.wsloop operation}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) if(%bool_var: i1) {
+    omp.yield
+  }
+}
+
+// -----
+
+func @num_threads_not_allowed(%lb : index, %ub : index, %step : index, %int_var : i32) {
+  // expected-error @below {{num_threads is not a valid clause for the omp.wsloop operation}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) num_threads(%int_var: i32) {
+    omp.yield
+  }
+}
+
+// -----
+
+func @default_not_allowed(%lb : index, %ub : index, %step : index) {
+  // expected-error @below {{default is not a valid clause for the omp.wsloop operation}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) default(private) {
+    omp.yield
+  }
+}
+
+// -----
+
+func @proc_bind_not_allowed(%lb : index, %ub : index, %step : index) {
+  // expected-error @below {{proc_bind is not a valid clause for the omp.wsloop operation}}
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) proc_bind(close) {
+    omp.yield
+  }
+}
+
+// -----
+
 // expected-error @below {{op expects initializer region with one argument of the reduction type}}
 omp.reduction.declare @add_f32 : f64
 init {

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 0d7c7af74579b..4d0801de0cd44 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -123,7 +123,27 @@ func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_threads :
    omp.terminator
  }
 
- return
+  // CHECK: omp.parallel default(private)
+  omp.parallel default(private) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel default(firstprivate)
+  omp.parallel default(firstprivate) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel default(shared)
+  omp.parallel default(shared) {
+    omp.terminator
+  }
+
+  // CHECK: omp.parallel default(none)
+  omp.parallel default(none) {
+    omp.terminator
+  }
+
+  return
 }
 
 // CHECK-LABEL: omp_wsloop
@@ -207,6 +227,21 @@ func @omp_wsloop_pretty(%lb : index, %ub : index, %step : index,
     omp.yield
   }
 
+  // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) inclusive step (%{{.*}})
+  omp.wsloop (%iv) : index = (%lb) to (%ub) inclusive step (%step) {
+    omp.yield
+  }
+
+  // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) nowait {
+    omp.yield
+  }
+
+  // CHECK: omp.wsloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}}) nowait order(concurrent)
+  omp.wsloop (%iv) : index = (%lb) to (%ub) step (%step) order(concurrent) nowait {
+    omp.yield
+  }
+
   return
 }
 


        


More information about the llvm-commits mailing list