[Mlir-commits] [mlir] inscan reduction and scan op mlir support (PR #114737)

Anchu Rajendran S llvmlistbot at llvm.org
Sun Nov 3 21:06:17 PST 2024


https://github.com/anchuraj created https://github.com/llvm/llvm-project/pull/114737

The scan directive specifies scan reductions within a worksharing-loop, worksharing-loop SIMD, or simd directive whose reduction clause carries the `InScan` modifier. This change adds MLIR support for the scan directive and for the reduction modifier.

>From a9363752e05efc133a290e37d0a008421ede9ab8 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Sat, 2 Nov 2024 21:57:20 -0500
Subject: [PATCH] Changes for inscan reduction and scan op

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  57 ++++++++
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |  21 +++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  32 ++++-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |   1 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 128 +++++++++++++-----
 mlir/test/Dialect/OpenMP/ops.mlir             |  23 ++++
 6 files changed, 223 insertions(+), 39 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 886554f66afffc..b45d89463639c5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -283,6 +283,34 @@ class OpenMP_DoacrossClauseSkip<
 
 def OpenMP_DoacrossClause : OpenMP_DoacrossClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [5.6.2] `exclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_ExclusiveClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    Variadic<AnyType>:$exclusive_vars
+  );
+
+  let optAssemblyFormat = [{
+    `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)`
+  }];
+
+  let description = [{
+    The exclusive clause is used on a separating directive that separates a
+    structured block into two structured block sequences. If it
+    is specified, the input phase excludes the preceding structured block 
+    sequence and instead includes the following structured block sequence, 
+    while the scan phase includes the preceding structured block sequence.
+  }];
+}
+
+def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [10.5.1] `filter` clause
 //===----------------------------------------------------------------------===//
@@ -393,6 +421,34 @@ class OpenMP_HasDeviceAddrClauseSkip<
 
 def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [5.6.1] `inclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_InclusiveClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    Variadic<AnyType>:$inclusive_vars
+  );
+
+  let optAssemblyFormat = [{
+    `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)`
+  }];
+
+  let description = [{
+    The inclusive clause is used on a separating directive that separates a
+    structured block into two structured block sequences. If it is specified,
+    the input phase includes the preceding structured block sequence and the
+    scan phase includes the following structured block sequence.
+  }];
+}
+
+def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>;
+
+
 //===----------------------------------------------------------------------===//
 // V5.2: [15.1.2] `hint` clause
 //===----------------------------------------------------------------------===//
@@ -983,6 +1039,7 @@ class OpenMP_ReductionClauseSkip<
   ];
 
   let arguments = (ins
+    OptionalAttr<ReductionModifierAttr>:$reduction_mod,
     Variadic<OpenMP_PointerLikeType>:$reduction_vars,
     OptionalAttr<DenseBoolArrayAttr>:$reduction_byref,
     OptionalAttr<SymbolRefArrayAttr>:$reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index b1a9e3330522b2..23086556bbb2f5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -178,6 +178,27 @@ def OrderModifier
 def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
                                     "order_mod">;
 
+//===----------------------------------------------------------------------===//
+// reduction_modifier enum.
+//===----------------------------------------------------------------------===//
+
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+
+def ReductionModifier : OpenMP_I32EnumAttr<
+    "ReductionModifier",
+    "reduction modifier", [
+      ReductionModifierInScan,
+      ReductionModifierTask,
+      ReductionModifierDefault
+    ]>;
+
+def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
+                                                  "reduction_modifier"> {
+  let assemblyFormat = "`(` $value `)`";
+}
+
 //===----------------------------------------------------------------------===//
 // sched_mod enum.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 626539cb7bde42..a03f18a816c39e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -170,7 +170,7 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -215,7 +215,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -274,7 +274,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -422,7 +422,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -476,7 +476,7 @@ def SimdOp : OpenMP_Op<"simd", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -680,7 +680,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
     custom<InReductionPrivateReductionRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms, $reduction_vars,
+        type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
         type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
@@ -1560,6 +1560,26 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
   let hasVerifier = 1;
 }
 
+def ScanOp : OpenMP_Op<"scan", [
+    AttrSizedOperandSegments, RecipeInterface, IsolatedFromAbove
+  ], clauses = [
+  OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+  let summary = "scan directive";
+  let description = [{
+    The scan directive specifies a scan reduction. It must be enclosed
+    within a parent directive on which a reduction clause with the `InScan`
+    modifier is specified. The scan directive separates the code in the
+    region enclosed by the parent into an input phase and a scan
+    phase.
+  }] # clausesDescription;
+
+  let builders = [
+    OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.19.5.7 declare reduction Directive
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index aa241b91d758ca..233739e1d6d917 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -451,6 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
+        /* reduction_mod = */ nullptr,
         /* reduction_vars = */ llvm::SmallVector<Value>{},
         /* reduction_byref = */ DenseBoolArrayAttr{},
         /* reduction_syms = */ ArrayAttr{});
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e1df647d6a3c71..0ad7fe2c2cf243 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -491,16 +491,27 @@ struct PrivateParseArgs {
                    SmallVectorImpl<Type> &types, ArrayAttr &syms)
       : vars(vars), types(types), syms(syms) {}
 };
+
+static ReductionModifierAttr nullReductionMod = nullptr;
 struct ReductionParseArgs {
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   SmallVectorImpl<Type> &types;
   DenseBoolArrayAttr &byref;
   ArrayAttr &syms;
+  ReductionModifierAttr &reductionMod;
+  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+                     SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
+                     ArrayAttr &syms, ReductionModifierAttr &redMod)
+      : vars(vars), types(types), byref(byref), syms(syms),
+        reductionMod(redMod) {}
   ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                      SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
                      ArrayAttr &syms)
-      : vars(vars), types(types), byref(byref), syms(syms) {}
+      : vars(vars), types(types), byref(byref), syms(syms),
+        reductionMod(nullReductionMod) {}
 };
+
+// Holds the optional parse arguments of each region-argument clause.
 struct AllRegionParseArgs {
   std::optional<ReductionParseArgs> inReductionArgs;
   std::optional<MapParseArgs> mapArgs;
@@ -517,7 +528,8 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
     SmallVectorImpl<Type> &types,
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
-    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr,
+    ReductionModifierAttr &reductionMod = nullReductionMod) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<bool> isByRefVec;
   unsigned regionArgOffset = regionPrivateArgs.size();
@@ -525,6 +537,16 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseLParen())
     return failure();
 
+  StringRef enumStr;
+  if (succeeded(parser.parseOptionalKeyword("type"))) {
+    if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
+        parser.parseComma())
+      return failure();
+    std::optional<ReductionModifier> enumValue =
+        symbolizeReductionModifier(enumStr);
+    reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+  }
+
   if (parser.parseCommaSeparatedList([&]() {
         if (byref)
           isByRefVec.push_back(
@@ -615,15 +637,14 @@ static ParseResult parseBlockArgClause(
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
     if (!reductionArgs)
       return failure();
-
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
-            &reductionArgs->syms, &reductionArgs->byref)))
+            &reductionArgs->syms, &reductionArgs->byref,
+            reductionArgs->reductionMod)))
       return failure();
   }
   return success();
 }
-
 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
                                        AllRegionParseArgs args) {
   llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
@@ -704,6 +725,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
@@ -712,7 +734,7 @@ static ParseResult parseInReductionPrivateReductionRegion(
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+                             reductionSyms, reductionMod);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -729,13 +751,14 @@ static ParseResult parsePrivateReductionRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+                             reductionSyms, reductionMod);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -784,9 +807,12 @@ struct ReductionPrintArgs {
   TypeRange types;
   DenseBoolArrayAttr byref;
   ArrayAttr syms;
+  ReductionModifierAttr reductionMod;
   ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
-                     ArrayAttr syms)
-      : vars(vars), types(types), byref(byref), syms(syms) {}
+                     ArrayAttr syms,
+                     ReductionModifierAttr reductionMod = nullReductionMod)
+      : vars(vars), types(types), byref(byref), syms(syms),
+        reductionMod(reductionMod) {}
 };
 struct AllRegionPrintArgs {
   std::optional<ReductionPrintArgs> inReductionArgs;
@@ -799,17 +825,21 @@ struct AllRegionPrintArgs {
 };
 } // namespace
 
-static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
-                                      StringRef clauseName,
-                                      ValueRange argsSubrange,
-                                      ValueRange operands, TypeRange types,
-                                      ArrayAttr symbols = nullptr,
-                                      DenseBoolArrayAttr byref = nullptr) {
+static void printClauseWithRegionArgs(
+    OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
+    ValueRange argsSubrange, ValueRange operands, TypeRange types,
+    ArrayAttr symbols = nullptr, DenseBoolArrayAttr byref = nullptr,
+    ReductionModifierAttr reductionMod = nullptr) {
   if (argsSubrange.empty())
     return;
 
   p << clauseName << "(";
 
+  if (reductionMod) {
+    p << "type: " << stringifyReductionModifier(reductionMod.getValue())
+      << ", ";
+  }
+
   if (!symbols) {
     llvm::SmallVector<Attribute> values(operands.size(), nullptr);
     symbols = ArrayAttr::get(ctx, values);
@@ -859,7 +889,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
   if (reductionArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               reductionArgs->vars, reductionArgs->types,
-                              reductionArgs->syms, reductionArgs->byref);
+                              reductionArgs->syms, reductionArgs->byref,
+                              reductionArgs->reductionMod);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -916,14 +947,15 @@ static void printInReductionPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
+    ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
+    ValueRange reductionVars, TypeRange reductionTypes,
     DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+                             reductionSyms, reductionMod);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -937,13 +969,14 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 static void printPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
-    TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
+    TypeRange privateTypes, ArrayAttr privateSyms,
+    ReductionModifierAttr reductionMod, ValueRange reductionVars,
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+                             reductionSyms, reductionMod);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1700,7 +1733,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
-                    /*reduction_vars=*/ValueRange(),
+                    /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
                     /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
   state.addAttributes(attributes);
 }
@@ -1711,7 +1744,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.procBindKind, clauses.reductionVars,
+                    clauses.procBindKind, clauses.reductionMod,
+                    clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -1810,12 +1844,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
                     const TeamsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms.
-  TeamsOp::build(
-      builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
-      /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-      makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
+  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+                 /*private_vars=*/{}, /*private_syms=*/nullptr,
+                 clauses.reductionMod, clauses.reductionVars,
+                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                 makeArrayAttr(ctx, clauses.reductionSyms),
+                 clauses.threadLimit);
 }
 
 LogicalResult TeamsOp::verify() {
@@ -1872,7 +1907,8 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.nowait, /*private_vars=*/{},
-                    /*private_syms=*/nullptr, clauses.reductionVars,
+                    /*private_syms=*/nullptr, clauses.reductionMod,
+                    clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -1958,7 +1994,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
         /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
         /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
-        /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
+        nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
         /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
         /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
         /*schedule_simd=*/false);
@@ -1975,7 +2011,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
       /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
       clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
       clauses.ordered, /*private_vars=*/{}, /*private_syms=*/nullptr,
-      clauses.reductionVars,
+      clauses.reductionMod, clauses.reductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
       makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
       clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
@@ -2025,7 +2061,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 /*private_vars=*/{}, /*private_syms=*/nullptr,
-                clauses.reductionVars,
+                clauses.reductionMod, clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
                 clauses.simdlen);
@@ -2257,7 +2293,7 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
       makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
       makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
       clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
-      /*private_syms=*/nullptr, clauses.reductionVars,
+      /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
       makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
 }
@@ -2840,6 +2876,32 @@ void MaskedOp::build(OpBuilder &builder, OperationState &state,
   MaskedOp::build(builder, state, clauses.filteredThreadId);
 }
 
+//===----------------------------------------------------------------------===//
+// Spec 5.2: Scan construct (5.6)
+//===----------------------------------------------------------------------===//
+
+void ScanOp::build(OpBuilder &builder, OperationState &state,
+                   const ScanOperands &clauses) {
+  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
+}
+
+/// Accepts the op only if an enclosing omp.wsloop or omp.simd is `InScan`.
+LogicalResult ScanOp::verify() {
+  // `reduction_mod` is optional, so the parent's attribute may be null.
+  auto isInScan = [](ReductionModifierAttr mod) {
+    return mod && mod.getValue() == ReductionModifier::InScan;
+  };
+  // Test both loop kinds: a non-InScan wsloop must not mask an InScan simd.
+  if (auto wsloopOp = (*this)->getParentOfType<WsloopOp>();
+      wsloopOp && isInScan(wsloopOp.getReductionModAttr()))
+    return success();
+  if (auto simdOp = (*this)->getParentOfType<SimdOp>();
+      simdOp && isInScan(simdOp.getReductionModAttr()))
+    return success();
+  return emitError("scan must be nested in an omp.wsloop or omp.simd whose "
+                   "reduction clause has the `InScan` modifier");
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 6f11b451fa00a3..8b445608fdca7a 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -900,6 +900,29 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
   return
 }
 
+// CHECK-LABEL: func @wsloop_inscan_reduction
+func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
+  %c1 = arith.constant 1 : i32
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+  // CHECK: reduction(type: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+  omp.wsloop reduction(type:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+       // CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr)
+       omp.scan inclusive(%prv : !llvm.ptr)
+       omp.yield
+    }
+  }
+  // CHECK: reduction(type: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+  omp.wsloop reduction(type:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+       // CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr)
+       omp.scan exclusive(%prv : !llvm.ptr)
+       omp.yield
+    }
+  }
+  return
+}
+
 // CHECK-LABEL: func @wsloop_reduction_byref
 func.func @wsloop_reduction_byref(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32



More information about the Mlir-commits mailing list