[Mlir-commits] [mlir] [mlir][OpenMP] inscan reduction modifier and scan op mlir support (PR #114737)

Anchu Rajendran S llvmlistbot at llvm.org
Tue Jan 21 13:11:57 PST 2025


https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/114737

>From 59ac266b0251b7336059603a385688c11fde8187 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Sat, 2 Nov 2024 21:57:20 -0500
Subject: [PATCH 1/6] Add inscan reduction modifier and scan op support

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  71 +++++
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |  21 ++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  42 ++-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |   1 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 255 ++++++++++++------
 mlir/test/Dialect/OpenMP/invalid.mlir         |  79 ++++++
 mlir/test/Dialect/OpenMP/ops.mlir             |  23 ++
 7 files changed, 400 insertions(+), 92 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8af054be322a55..56ecc15dfc8799 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -334,6 +334,40 @@ class OpenMP_DoacrossClauseSkip<
 
 def OpenMP_DoacrossClause : OpenMP_DoacrossClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `exclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_ExclusiveClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    Variadic<AnyType>:$exclusive_vars
+  );
+
+  let optAssemblyFormat = [{
+    `exclusive` `(` $exclusive_vars `:` type($exclusive_vars) `)`
+  }];
+
+  let extraClassDeclaration = [{
+    bool hasExclusiveVars() {
+      return getExclusiveVars().size()>0;
+    }
+  }];
+
+  let description = [{
+    The exclusive clause is used on a separating directive that separates a
+    structured block into two structured block sequences. If it is specified,
+    the input phase excludes the preceding structured block sequence and
+    instead includes the following structured block sequence, while the scan
+    phase includes the preceding structured block sequence.
+  }];
+}
+
+def OpenMP_ExclusiveClause : OpenMP_ExclusiveClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [10.5.1] `filter` clause
 //===----------------------------------------------------------------------===//
@@ -444,6 +478,40 @@ class OpenMP_HasDeviceAddrClauseSkip<
 
 def OpenMP_HasDeviceAddrClause : OpenMP_HasDeviceAddrClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V5.2: [5.4.7] `inclusive` clause
+//===----------------------------------------------------------------------===//
+
+class OpenMP_InclusiveClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let arguments = (ins
+    Variadic<AnyType>:$inclusive_vars
+  );
+
+  let optAssemblyFormat = [{
+    `inclusive` `(` $inclusive_vars `:` type($inclusive_vars) `)`
+  }];
+
+  let extraClassDeclaration = [{
+    bool hasInclusiveVars() {
+      return getInclusiveVars().size()>0;
+    }
+  }];
+
+  let description = [{
+    The inclusive clause is used on a separating directive that separates a
+    structured block into two structured block sequences. If it is specified,
+    the input phase includes the preceding structured block sequence and the
+    scan phase includes the following structured block sequence.
+  }];
+}
+
+def OpenMP_InclusiveClause : OpenMP_InclusiveClauseSkip<>;
+
+
 //===----------------------------------------------------------------------===//
 // V5.2: [15.1.2] `hint` clause
 //===----------------------------------------------------------------------===//
@@ -544,6 +612,7 @@ class OpenMP_InReductionClauseSkip<
   ];
 
   let arguments = (ins
+    OptionalAttr<ReductionModifierAttr>:$in_reduction_mod,
     Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
     OptionalAttr<DenseBoolArrayAttr>:$in_reduction_byref,
     OptionalAttr<SymbolRefArrayAttr>:$in_reduction_syms
@@ -1100,6 +1169,7 @@ class OpenMP_ReductionClauseSkip<
   ];
 
   let arguments = (ins
+    OptionalAttr<ReductionModifierAttr>:$reduction_mod,
     Variadic<OpenMP_PointerLikeType>:$reduction_vars,
     OptionalAttr<DenseBoolArrayAttr>:$reduction_byref,
     OptionalAttr<SymbolRefArrayAttr>:$reduction_syms
@@ -1230,6 +1300,7 @@ class OpenMP_TaskReductionClauseSkip<
   ];
 
   let arguments = (ins
+    OptionalAttr<ReductionModifierAttr>:$task_reduction_mod,
     Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
     OptionalAttr<DenseBoolArrayAttr>:$task_reduction_byref,
     OptionalAttr<SymbolRefArrayAttr>:$task_reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 2091c0c76dff72..25e08aa726af40 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -179,6 +179,27 @@ def OrderModifier
 def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
                                     "order_mod">;
 
+//===----------------------------------------------------------------------===//
+// reduction_modifier enum.
+//===----------------------------------------------------------------------===//
+
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+
+def ReductionModifier : OpenMP_I32EnumAttr<
+    "ReductionModifier",
+    "reduction modifier", [
+      ReductionModifierInScan,
+      ReductionModifierTask,
+      ReductionModifierDefault
+    ]>;
+
+def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
+                                                  "reduction_modifier"> {
+  let assemblyFormat = "`(` $value `)`";
+}
+
 //===----------------------------------------------------------------------===//
 // sched_mod enum.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index c5b88904367086..6c62c83398b1a7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -178,7 +178,7 @@ def ParallelOp : OpenMP_Op<"parallel", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -223,7 +223,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -282,7 +282,7 @@ def SectionsOp : OpenMP_Op<"sections", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -469,7 +469,7 @@ def LoopOp : OpenMP_Op<"loop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -521,7 +521,7 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -575,7 +575,7 @@ def SimdOp : OpenMP_Op<"simd", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<PrivateReductionRegion>($region, $private_vars, type($private_vars),
-        $private_syms, $reduction_vars, type($reduction_vars), $reduction_byref,
+        $private_syms, $reduction_mod, $reduction_vars, type($reduction_vars), $reduction_byref,
         $reduction_syms) attr-dict
   }];
 
@@ -702,7 +702,7 @@ def TaskOp
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<InReductionPrivateRegion>(
-        $region, $in_reduction_vars, type($in_reduction_vars),
+        $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
         type($private_vars), $private_syms) attr-dict
   }];
@@ -780,9 +780,9 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<InReductionPrivateReductionRegion>(
-        $region, $in_reduction_vars, type($in_reduction_vars),
+        $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
-        type($private_vars), $private_syms, $reduction_vars,
+        type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
         type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
   }];
 
@@ -827,7 +827,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<TaskReductionRegion>(
-        $region, $task_reduction_vars, type($task_reduction_vars),
+        $region, $task_reduction_mod, $task_reduction_vars, type($task_reduction_vars),
         $task_reduction_byref, $task_reduction_syms) attr-dict
   }];
 
@@ -1289,7 +1289,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<HostEvalInReductionMapPrivateRegion>(
-        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
+        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_mod, $in_reduction_vars,
         type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
         $map_vars, type($map_vars), $private_vars, type($private_vars),
         $private_syms, $private_maps) attr-dict
@@ -1706,6 +1706,26 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
   let hasVerifier = 1;
 }
 
+def ScanOp : OpenMP_Op<"scan", [
+    AttrSizedOperandSegments, RecipeInterface, IsolatedFromAbove
+  ], clauses = [
+  OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+  let summary = "scan directive";
+  let description = [{
+    The scan directive specifies a scan reduction. It must be enclosed
+    within a parent directive that also carries a reduction clause with the
+    `InScan` modifier. The scan directive splits the code in the region
+    enclosed by the parent into an input phase and a scan phase, as selected
+    by its `inclusive` or `exclusive` clause.
+  }] # clausesDescription;
+
+  let builders = [
+    OpBuilder<(ins CArg<"const ScanOperands &">:$clauses)>
+  ];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // 2.19.5.7 declare reduction Directive
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index aa241b91d758ca..233739e1d6d917 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -451,6 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
         /* proc_bind_kind = */ omp::ClauseProcBindKindAttr{},
+        /* reduction_mod = */ nullptr,
         /* reduction_vars = */ llvm::SmallVector<Value>{},
         /* reduction_byref = */ DenseBoolArrayAttr{},
         /* reduction_syms = */ ArrayAttr{});
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 5a619254a5ee14..2bc14f5abfc695 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -494,16 +494,22 @@ struct PrivateParseArgs {
                    DenseI64ArrayAttr *mapIndices = nullptr)
       : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
+
 struct ReductionParseArgs {
+  ReductionModifierAttr &reductionMod;
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   SmallVectorImpl<Type> &types;
   DenseBoolArrayAttr &byref;
   ArrayAttr &syms;
-  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+  ReductionParseArgs(ReductionModifierAttr &redMod,
+                     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                      SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
                      ArrayAttr &syms)
-      : vars(vars), types(types), byref(byref), syms(syms) {}
+      : reductionMod(redMod), vars(vars), types(types), byref(byref),
+        syms(syms) {}
 };
+
+// Parse arguments of each clause that can define entry block arguments.
 struct AllRegionParseArgs {
   std::optional<MapParseArgs> hostEvalArgs;
   std::optional<ReductionParseArgs> inReductionArgs;
@@ -522,7 +528,8 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<Type> &types,
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
-    DenseBoolArrayAttr *byref = nullptr) {
+    DenseBoolArrayAttr *byref = nullptr,
+    ReductionModifierAttr *reductionMod = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
@@ -531,6 +538,16 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseLParen())
     return failure();
 
+  StringRef enumStr;
+  if (succeeded(parser.parseOptionalKeyword("Id"))) {
+    if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
+        parser.parseComma())
+      return failure();
+    std::optional<ReductionModifier> enumValue =
+        symbolizeReductionModifier(enumStr);
+    *reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+  }
+
   if (parser.parseCommaSeparatedList([&]() {
         if (byref)
           isByRefVec.push_back(
@@ -635,16 +652,14 @@ static ParseResult parseBlockArgClause(
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
     if (!reductionArgs)
       return failure();
-
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
-            &reductionArgs->syms, /*mapIndices=*/nullptr,
-            &reductionArgs->byref)))
+            &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
+            &(reductionArgs->reductionMod))))
       return failure();
   }
   return success();
 }
-
 static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
                                        AllRegionParseArgs args) {
   llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
@@ -695,7 +710,7 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
 static ParseResult parseHostEvalInReductionMapPrivateRegion(
     OpAsmParser &parser, Region &region,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
-    SmallVectorImpl<Type> &hostEvalTypes,
+    SmallVectorImpl<Type> &hostEvalTypes, ReductionModifierAttr &inReductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -706,8 +721,9 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
     DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
+  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+                               inReductionTypes, inReductionByref,
+                               inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            &privateMaps);
@@ -715,35 +731,38 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
 }
 
 static ParseResult parseInReductionPrivateRegion(
-    OpAsmParser &parser, Region &region,
+    OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   AllRegionParseArgs args;
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
+  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+                               inReductionTypes, inReductionByref,
+                               inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
   return parseBlockArgRegion(parser, region, args);
 }
 
 static ParseResult parseInReductionPrivateReductionRegion(
-    OpAsmParser &parser, Region &region,
+    OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
+  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+                               inReductionTypes, inReductionByref,
+                               inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
-  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+                             reductionByref, reductionSyms);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -760,24 +779,27 @@ static ParseResult parsePrivateReductionRegion(
     OpAsmParser &parser, Region &region,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ReductionModifierAttr &reductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &reductionVars,
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
-  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+                             reductionByref, reductionSyms);
   return parseBlockArgRegion(parser, region, args);
 }
 
 static ParseResult parseTaskReductionRegion(
     OpAsmParser &parser, Region &region,
+    ReductionModifierAttr &taskReductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
     SmallVectorImpl<Type> &taskReductionTypes,
     DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
   AllRegionParseArgs args;
-  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
-                                 taskReductionByref, taskReductionSyms);
+  args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
+                                 taskReductionTypes, taskReductionByref,
+                                 taskReductionSyms);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -813,13 +835,15 @@ struct PrivatePrintArgs {
       : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
+  ReductionModifierAttr reductionMod;
   ValueRange vars;
   TypeRange types;
   DenseBoolArrayAttr byref;
   ArrayAttr syms;
-  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
-                     ArrayAttr syms)
-      : vars(vars), types(types), byref(byref), syms(syms) {}
+  ReductionPrintArgs(ReductionModifierAttr reductionMod, ValueRange vars,
+                     TypeRange types, DenseBoolArrayAttr byref, ArrayAttr syms)
+      : reductionMod(reductionMod), vars(vars), types(types), byref(byref),
+        syms(syms) {}
 };
 struct AllRegionPrintArgs {
   std::optional<MapPrintArgs> hostEvalArgs;
@@ -833,18 +857,21 @@ struct AllRegionPrintArgs {
 };
 } // namespace
 
-static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
-                                      StringRef clauseName,
-                                      ValueRange argsSubrange,
-                                      ValueRange operands, TypeRange types,
-                                      ArrayAttr symbols = nullptr,
-                                      DenseI64ArrayAttr mapIndices = nullptr,
-                                      DenseBoolArrayAttr byref = nullptr) {
+static void printClauseWithRegionArgs(
+    OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
+    ValueRange argsSubrange, ValueRange operands, TypeRange types,
+    ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
+    DenseBoolArrayAttr byref = nullptr,
+    ReductionModifierAttr reductionMod = nullptr) {
   if (argsSubrange.empty())
     return;
 
   p << clauseName << "(";
 
+  if (reductionMod) {
+    p << "Id: " << stringifyReductionModifier(reductionMod.getValue()) << ", ";
+  }
+
   if (!symbols) {
     llvm::SmallVector<Attribute> values(operands.size(), nullptr);
     symbols = ArrayAttr::get(ctx, values);
@@ -902,10 +929,10 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
                     ValueRange argsSubrange,
                     std::optional<ReductionPrintArgs> reductionArgs) {
   if (reductionArgs)
-    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
-                              reductionArgs->vars, reductionArgs->types,
-                              reductionArgs->syms, /*mapIndices=*/nullptr,
-                              reductionArgs->byref);
+    printClauseWithRegionArgs(
+        p, ctx, clauseName, argsSubrange, reductionArgs->vars,
+        reductionArgs->types, reductionArgs->syms, /*mapIndices=*/nullptr,
+        reductionArgs->byref, reductionArgs->reductionMod);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -937,46 +964,53 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 static void printHostEvalInReductionMapPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
-    TypeRange hostEvalTypes, ValueRange inReductionVars,
-    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
-    ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
-    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+    TypeRange hostEvalTypes, ReductionModifierAttr inReductionMod,
+    ValueRange inReductionVars, TypeRange inReductionTypes,
+    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
+    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
+    TypeRange privateTypes, ArrayAttr privateSyms,
     DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
+  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+                               inReductionTypes, inReductionByref,
+                               inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printInReductionPrivateRegion(
-    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+    OpAsmPrinter &p, Operation *op, Region &region,
+    ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
     ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
+  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+                               inReductionTypes, inReductionByref,
+                               inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printInReductionPrivateReductionRegion(
-    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+    OpAsmPrinter &p, Operation *op, Region &region,
+    ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
-    ArrayAttr privateSyms, ValueRange reductionVars, TypeRange reductionTypes,
+    ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
+    ValueRange reductionVars, TypeRange reductionTypes,
     DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
-  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
-                               inReductionByref, inReductionSyms);
+  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
+                               inReductionTypes, inReductionByref,
+                               inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
-  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+                             reductionByref, reductionSyms);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -991,26 +1025,29 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 static void printPrivateReductionRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
-    TypeRange privateTypes, ArrayAttr privateSyms, ValueRange reductionVars,
+    TypeRange privateTypes, ArrayAttr privateSyms,
+    ReductionModifierAttr reductionMod, ValueRange reductionVars,
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
-  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms);
+  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
+                             reductionByref, reductionSyms);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
                                      Region &region,
+                                     ReductionModifierAttr taskReductionMod,
                                      ValueRange taskReductionVars,
                                      TypeRange taskReductionTypes,
                                      DenseBoolArrayAttr taskReductionByref,
                                      ArrayAttr taskReductionSyms) {
   AllRegionPrintArgs args;
-  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
-                                 taskReductionByref, taskReductionSyms);
+  args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
+                                 taskReductionTypes, taskReductionByref,
+                                 taskReductionSyms);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1727,6 +1764,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
                   clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
                   clauses.hostEvalVars, clauses.ifExpr,
+                  /*in_reduction_mod=*/nullptr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1942,7 +1980,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*proc_bind_kind=*/nullptr,
-                    /*reduction_vars=*/ValueRange(),
+                    /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
                     /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
   state.addAttributes(attributes);
 }
@@ -1953,7 +1991,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.ifExpr, clauses.numThreads, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.procBindKind, clauses.reductionVars,
+                    clauses.procBindKind, clauses.reductionMod,
+                    clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2052,12 +2091,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
                     const TeamsOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms.
-  TeamsOp::build(
-      builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
-      /*private_vars=*/{}, /*private_syms=*/nullptr, clauses.reductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-      makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
+  TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+                 /*private_vars=*/{}, /*private_syms=*/nullptr,
+                 clauses.reductionMod, clauses.reductionVars,
+                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                 makeArrayAttr(ctx, clauses.reductionSyms),
+                 clauses.threadLimit);
 }
 
 LogicalResult TeamsOp::verify() {
@@ -2114,7 +2154,8 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                     clauses.nowait, /*private_vars=*/{},
-                    /*private_syms=*/nullptr, clauses.reductionVars,
+                    /*private_syms=*/nullptr, clauses.reductionMod,
+                    clauses.reductionVars,
                     makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                     makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2221,7 +2262,7 @@ void LoopOp::build(OpBuilder &builder, OperationState &state,
 
   LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
                 makeArrayAttr(ctx, clauses.privateSyms), clauses.order,
-                clauses.orderMod, clauses.reductionVars,
+                clauses.orderMod, clauses.reductionMod, clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms));
 }
@@ -2249,7 +2290,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
         /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
         /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
-        /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
+        nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
         /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
         /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
         /*schedule_simd=*/false);
@@ -2261,15 +2302,16 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO: Store clauses in op: allocateVars, allocatorVars, privateVars,
   // privateSyms.
-  WsloopOp::build(
-      builder, state,
-      /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
-      clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
-      clauses.ordered, clauses.privateVars,
-      makeArrayAttr(ctx, clauses.privateSyms), clauses.reductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-      makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
-      clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
+  WsloopOp::build(builder, state,
+                  /*allocate_vars=*/{}, /*allocator_vars=*/{},
+                  clauses.linearVars, clauses.linearStepVars, clauses.nowait,
+                  clauses.order, clauses.orderMod, clauses.ordered,
+                  clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
+                  clauses.reductionMod, clauses.reductionVars,
+                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                  makeArrayAttr(ctx, clauses.reductionSyms),
+                  clauses.scheduleKind, clauses.scheduleChunk,
+                  clauses.scheduleMod, clauses.scheduleSimd);
 }
 
 LogicalResult WsloopOp::verify() {
@@ -2316,7 +2358,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
                 /*linear_vars=*/{}, /*linear_step_vars=*/{},
                 clauses.nontemporalVars, clauses.order, clauses.orderMod,
                 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
-                clauses.reductionVars,
+                clauses.reductionMod, clauses.reductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
                 clauses.simdlen);
@@ -2497,7 +2539,8 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                clauses.final, clauses.ifExpr, clauses.inReductionVars,
+                clauses.final, clauses.ifExpr, clauses.inReductionMod,
+                clauses.inReductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/clauses.privateVars,
@@ -2523,7 +2566,8 @@ void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
                         const TaskgroupOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   TaskgroupOp::build(builder, state, clauses.allocateVars,
-                     clauses.allocatorVars, clauses.taskReductionVars,
+                     clauses.allocatorVars, clauses.taskReductionMod,
+                     clauses.taskReductionVars,
                      makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
                      makeArrayAttr(ctx, clauses.taskReductionSyms));
 }
@@ -2544,11 +2588,12 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   TaskloopOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
+      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionMod,
+      clauses.inReductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
       makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
       clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
-      /*private_syms=*/nullptr, clauses.reductionVars,
+      /*private_syms=*/nullptr, clauses.reductionMod, clauses.reductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
       makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
 }
@@ -3125,6 +3170,54 @@ void MaskedOp::build(OpBuilder &builder, OperationState &state,
   MaskedOp::build(builder, state, clauses.filteredThreadId);
 }
 
+//===----------------------------------------------------------------------===//
+// Spec 5.2: Scan construct (5.6)
+//===----------------------------------------------------------------------===//
+
+void ScanOp::build(OpBuilder &builder, OperationState &state,
+                   const ScanOperands &clauses) {
+  ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
+}
+
+// Verifies clause exclusivity and the enclosing InScan reduction.
+LogicalResult ScanOp::verify() {
+  // Exactly one clause must be present: reject both *and* neither.
+  if (hasExclusiveVars() == hasInclusiveVars())
+    return emitError(
+        "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
+  OperandRange scanVars =
+      hasExclusiveVars() ? getExclusiveVars() : getInclusiveVars();
+  // True if every scan list item is also a reduction variable of the parent.
+  auto verifyScanVarsInReduction = [&scanVars](OperandRange reductionVars) {
+    return llvm::all_of(scanVars, [&](Value var) {
+      return llvm::is_contained(reductionVars, var);
+    });
+  };
+  // The reduction modifier is optional, so the attribute may be null.
+  auto isInScan = [](ReductionModifierAttr mod) {
+    return mod && mod.getValue() == mlir::omp::ReductionModifier::InScan;
+  };
+  // Consider the enclosing simd first, then the enclosing wsloop; the first
+  // one whose reduction carries the InScan modifier decides.
+  SmallVector<std::pair<ReductionModifierAttr, OperandRange>, 2> parents;
+  if (auto simdOp = (*this)->getParentOfType<mlir::omp::SimdOp>())
+    parents.emplace_back(simdOp.getReductionModAttr(),
+                         simdOp.getReductionVars());
+  if (auto wsloopOp = (*this)->getParentOfType<mlir::omp::WsloopOp>())
+    parents.emplace_back(wsloopOp.getReductionModAttr(),
+                         wsloopOp.getReductionVars());
+  for (auto [mod, reductionVars] : parents) {
+    if (isInScan(mod)) {
+      if (!verifyScanVarsInReduction(reductionVars))
+        return emitError(
+            "List item should appear in REDUCTION clause of the parent");
+      return success();
+    }
+  }
+  return emitError("Scan Operation should be enclosed within a parent "
+                   "WORSKSHARING LOOP or SIMD with INSCAN reduction modifier");
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index c611614265592c..5215e363d43e66 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1825,6 +1825,88 @@ func.func @omp_cancellationpoint2() {
 
 // -----
 
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+  ^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// A scan list item that is not a reduction variable of the parent is rejected.
+func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
+  %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+  %test2f32 = "test.f32"() : () -> (!llvm.ptr)
+  omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+  // expected-error @below {{List item should appear in REDUCTION clause of the parent}}
+       omp.scan inclusive(%test2f32 : !llvm.ptr)
+        omp.yield
+    }
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+  ^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// Giving both inclusive and exclusive clauses is rejected.
+func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
+  %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+  omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+  // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
+       omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr)
+        omp.yield
+    }
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+  ^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
+// Only wsloop and simd parents are accepted, even for an InScan taskloop.
+func.func @scan_test_3(%lb: i32, %ub: i32, %step: i32) {
+  %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+  omp.taskloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+  // expected-error @below {{Scan Operation should be enclosed within a parent WORSKSHARING LOOP or SIMD with INSCAN reduction modifier}}
+       omp.scan inclusive(%test1f32 : !llvm.ptr)
+        omp.yield
+    }
+  }
+  return
+}
+
+// -----
+
 func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
   %testmemref = "test.memref"() : () -> (memref<i32>)
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index b1901c333ade8d..25610d030d3f1c 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -900,6 +900,31 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
   return
 }
 
+// omp.scan list items are pinned to the alloca reduced by the enclosing wsloop.
+// CHECK-LABEL: func @wsloop_inscan_reduction
+func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
+  %c1 = arith.constant 1 : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca
+  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
+  // CHECK: reduction(Id: InScan, @add_f32 %[[ALLOCA]] -> %{{.+}} : !llvm.ptr)
+  omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+       // CHECK: omp.scan inclusive(%[[ALLOCA]] : !llvm.ptr)
+       omp.scan inclusive(%0 : !llvm.ptr)
+       omp.yield
+    }
+  }
+  // CHECK: reduction(Id: InScan, @add_f32 %[[ALLOCA]] -> %{{.+}} : !llvm.ptr)
+  omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+    omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
+       // CHECK: omp.scan exclusive(%[[ALLOCA]] : !llvm.ptr)
+       omp.scan exclusive(%0 : !llvm.ptr)
+       omp.yield
+    }
+  }
+  return
+}
+
 // CHECK-LABEL: func @wsloop_reduction_byref
 func.func @wsloop_reduction_byref(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32

>From 1be9aa4ef98a26d67cf97a2c0c5c849bbec388b4 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Mon, 18 Nov 2024 16:43:45 -0600
Subject: [PATCH 2/6] R2: Addressing a few review comments

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td | 10 +++++-----
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td   |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp    |  8 +++++---
 mlir/test/Dialect/OpenMP/invalid.mlir           |  2 +-
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 25e08aa726af40..129bd03f809c49 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -183,16 +183,16 @@ def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
 // reduction_modifier enum.
 //===----------------------------------------------------------------------===//
 
-def ReductionModifierInScan : I32EnumAttrCase<"InScan", 0>;
-def ReductionModifierTask : I32EnumAttrCase<"Task", 1>;
-def ReductionModifierDefault : I32EnumAttrCase<"Default", 2>;
+def ReductionModifierDefault : I32EnumAttrCase<"Default", 0>;
+def ReductionModifierInScan : I32EnumAttrCase<"InScan", 1>;
+def ReductionModifierTask : I32EnumAttrCase<"Task", 2>;
 
 def ReductionModifier : OpenMP_I32EnumAttr<
     "ReductionModifier",
     "reduction modifier", [
+      ReductionModifierDefault,
       ReductionModifierInScan,
-      ReductionModifierTask,
-      ReductionModifierDefault
+      ReductionModifierTask
     ]>;
 
 def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 6c62c83398b1a7..cc375a0c021e81 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1713,7 +1713,7 @@ def ScanOp : OpenMP_Op<"scan", [
   let summary = "scan directive";
   let description = [{
     The scan directive allows to specify scan reduction. Scan directive
-    should be enclosed with in a parent directive along with which , a
+    should be enclosed with in a parent directive along with which, a
     reduction clause with `InScan` modifier must be specified. Scan directive
     allows to separate code blocks to input phase and scan phase in the region
     enclosed by the parent.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2bc14f5abfc695..08df20fbb349cf 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2290,7 +2290,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
         /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
         /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
         /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
-        nullptr, /*reduction_vars=*/ValueRange(), /*reduction_byref=*/nullptr,
+        /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
+        /*reduction_byref=*/nullptr,
         /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
         /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
         /*schedule_simd=*/false);
@@ -3214,8 +3215,9 @@ LogicalResult ScanOp::verify() {
       return success();
     }
   }
-  return emitError("Scan Operation should be enclosed within a parent "
-                   "WORSKSHARING LOOP or SIMD with INSCAN reduction modifier");
+  return emitError("SCAN directive needs to be enclosed within a parent "
+                   "worksharing loop construct or SIMD construct with INSCAN "
+                   "reduction modifier");
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 5215e363d43e66..d21c3e513277f3 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1894,7 +1894,7 @@ func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
   omp.taskloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
-  // expected-error @below {{Scan Operation should be enclosed within a parent WORSKSHARING LOOP or SIMD with INSCAN reduction modifier}}
+  // expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}}
        omp.scan inclusive(%test1f32 : !llvm.ptr)
         omp.yield
     }

>From eb274aae91fbaf06d0b66cc1eb8a09209a62ca57 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Fri, 6 Dec 2024 17:36:20 -0600
Subject: [PATCH 3/6] R3: Addressing a few review comments

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  12 +-
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |   2 +-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  20 +--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 136 ++++++++----------
 mlir/test/Dialect/OpenMP/invalid.mlir         |   6 +-
 mlir/test/Dialect/OpenMP/ops.mlir             |   8 +-
 6 files changed, 85 insertions(+), 99 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 56ecc15dfc8799..d71dc70ed2d963 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -353,7 +353,7 @@ class OpenMP_ExclusiveClauseSkip<
 
   let extraClassDeclaration = [{
     bool hasExclusiveVars() {
-      return getExclusiveVars().size()>0;
+      return !getExclusiveVars().empty();
     }
   }];
 
@@ -363,6 +363,9 @@ class OpenMP_ExclusiveClauseSkip<
     is specified, the input phase excludes the preceding structured block 
     sequence and instead includes the following structured block sequence, 
     while the scan phase includes the preceding structured block sequence.
+
+    The `exclusive_vars` is a variadic list of operands that specifies the
+    scan-reduction accumulator symbols.
   }];
 }
 
@@ -497,7 +500,7 @@ class OpenMP_InclusiveClauseSkip<
 
   let extraClassDeclaration = [{
     bool hasInclusiveVars() {
-      return getInclusiveVars().size()>0;
+      return !getInclusiveVars().empty();
     }
   }];
 
@@ -506,6 +509,9 @@ class OpenMP_InclusiveClauseSkip<
     structured block into two structured block sequences. If it is specified,
     the input phase includes the preceding structured block sequence and the
     scan phase includes the following structured block sequence.
+
+    The `inclusive_vars` is a variadic list of operands that specifies the
+    scan-reduction accumulator symbols.
   }];
 }
 
@@ -612,7 +618,6 @@ class OpenMP_InReductionClauseSkip<
   ];
 
   let arguments = (ins
-    OptionalAttr<ReductionModifierAttr>:$in_reduction_mod,
     Variadic<OpenMP_PointerLikeType>:$in_reduction_vars,
     OptionalAttr<DenseBoolArrayAttr>:$in_reduction_byref,
     OptionalAttr<SymbolRefArrayAttr>:$in_reduction_syms
@@ -1300,7 +1305,6 @@ class OpenMP_TaskReductionClauseSkip<
   ];
 
   let arguments = (ins
-    OptionalAttr<ReductionModifierAttr>:$task_reduction_mod,
     Variadic<OpenMP_PointerLikeType>:$task_reduction_vars,
     OptionalAttr<DenseBoolArrayAttr>:$task_reduction_byref,
     OptionalAttr<SymbolRefArrayAttr>:$task_reduction_syms
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 129bd03f809c49..22777b4f0964b3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -196,7 +196,7 @@ def ReductionModifier : OpenMP_I32EnumAttr<
     ]>;
 
 def ReductionModifierAttr : OpenMP_EnumAttr<ReductionModifier,
-                                                  "reduction_modifier"> {
+                                            "reduction_modifier"> {
   let assemblyFormat = "`(` $value `)`";
 }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index cc375a0c021e81..e653d0cf7f2f99 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -702,7 +702,7 @@ def TaskOp
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<InReductionPrivateRegion>(
-        $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
+        $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
         type($private_vars), $private_syms) attr-dict
   }];
@@ -780,7 +780,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<InReductionPrivateReductionRegion>(
-        $region, $in_reduction_mod, $in_reduction_vars, type($in_reduction_vars),
+        $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $private_vars,
         type($private_vars), $private_syms, $reduction_mod, $reduction_vars,
         type($reduction_vars), $reduction_byref, $reduction_syms) attr-dict
@@ -827,7 +827,7 @@ def TaskgroupOp : OpenMP_Op<"taskgroup", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<TaskReductionRegion>(
-        $region, $task_reduction_mod, $task_reduction_vars, type($task_reduction_vars),
+        $region, $task_reduction_vars, type($task_reduction_vars),
         $task_reduction_byref, $task_reduction_syms) attr-dict
   }];
 
@@ -1289,7 +1289,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
 
   let assemblyFormat = clausesAssemblyFormat # [{
     custom<HostEvalInReductionMapPrivateRegion>(
-        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_mod, $in_reduction_vars,
+        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
         type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
         $map_vars, type($map_vars), $private_vars, type($private_vars),
         $private_syms, $private_maps) attr-dict
@@ -1707,15 +1707,15 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
 }
 
 def ScanOp : OpenMP_Op<"scan", [
-    AttrSizedOperandSegments, RecipeInterface, IsolatedFromAbove
+    AttrSizedOperandSegments
   ], clauses = [
-  OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
+    OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
   let summary = "scan directive";
   let description = [{
-    The scan directive allows to specify scan reduction. Scan directive
-    should be enclosed with in a parent directive along with which, a
-    reduction clause with `InScan` modifier must be specified. Scan directive
-    allows to separate code blocks to input phase and scan phase in the region
+    The scan directive allows to specify scan reductions. It should be
+    enclosed within a parent directive along with which a reduction clause
+    with `inscan` modifier must be specified. The scan directive allows to
+    split code blocks into input phase and scan phase in the region
     enclosed by the parent.
   }] # clausesDescription;
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 08df20fbb349cf..65b6496ea5b31d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -496,20 +496,17 @@ struct PrivateParseArgs {
 };
 
 struct ReductionParseArgs {
-  ReductionModifierAttr &reductionMod;
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   SmallVectorImpl<Type> &types;
   DenseBoolArrayAttr &byref;
   ArrayAttr &syms;
-  ReductionParseArgs(ReductionModifierAttr &redMod,
-                     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
+  ReductionModifierAttr *modifier;
+  ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
                      SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
-                     ArrayAttr &syms)
-      : reductionMod(redMod), vars(vars), types(types), byref(byref),
-        syms(syms) {}
+                     ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
+      : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
 };
 
-// specifies the arguments needs for `reduction` clause
 struct AllRegionParseArgs {
   std::optional<MapParseArgs> hostEvalArgs;
   std::optional<ReductionParseArgs> inReductionArgs;
@@ -529,7 +526,7 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
     ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
     DenseBoolArrayAttr *byref = nullptr,
-    ReductionModifierAttr *reductionMod = nullptr) {
+    ReductionModifierAttr *modifier = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
   SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
@@ -539,13 +536,17 @@ static ParseResult parseClauseWithRegionArgs(
     return failure();
 
   StringRef enumStr;
-  if (succeeded(parser.parseOptionalKeyword("Id"))) {
+  // Clauses that take no modifier pass a null `modifier`: never accept `mod:`
+  // for them, otherwise the store below would dereference a null pointer.
+  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
     if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
         parser.parseComma())
       return failure();
     std::optional<ReductionModifier> enumValue =
         symbolizeReductionModifier(enumStr);
-    *reductionMod = ReductionModifierAttr::get(parser.getContext(), *enumValue);
+    if (!enumValue.has_value())
+      return failure();
+    *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
   }
 
   if (parser.parseCommaSeparatedList([&]() {
@@ -655,7 +656,7 @@ static ParseResult parseBlockArgClause(
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
             &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
-            &(reductionArgs->reductionMod))))
+            reductionArgs->modifier)))
       return failure();
   }
   return success();
@@ -710,7 +711,7 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
 static ParseResult parseHostEvalInReductionMapPrivateRegion(
     OpAsmParser &parser, Region &region,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
-    SmallVectorImpl<Type> &hostEvalTypes, ReductionModifierAttr &inReductionMod,
+    SmallVectorImpl<Type> &hostEvalTypes,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -721,9 +722,8 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
     DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
-  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
-                               inReductionTypes, inReductionByref,
-                               inReductionSyms);
+  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+                               inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            &privateMaps);
@@ -731,22 +731,21 @@ static ParseResult parseHostEvalInReductionMapPrivateRegion(
 }
 
 static ParseResult parseInReductionPrivateRegion(
-    OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
+    OpAsmParser &parser, Region &region,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   AllRegionParseArgs args;
-  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
-                               inReductionTypes, inReductionByref,
-                               inReductionSyms);
+  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+                               inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
   return parseBlockArgRegion(parser, region, args);
 }
 
 static ParseResult parseInReductionPrivateReductionRegion(
-    OpAsmParser &parser, Region &region, ReductionModifierAttr &inReductionMod,
+    OpAsmParser &parser, Region &region,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -757,12 +756,11 @@ static ParseResult parseInReductionPrivateReductionRegion(
     SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
-  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
-                               inReductionTypes, inReductionByref,
-                               inReductionSyms);
+  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+                               inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
-  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
-                             reductionByref, reductionSyms);
+  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+                             reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -785,21 +783,19 @@ static ParseResult parsePrivateReductionRegion(
     ArrayAttr &reductionSyms) {
   AllRegionParseArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
-  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
-                             reductionByref, reductionSyms);
+  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+                             reductionSyms, &reductionMod);
   return parseBlockArgRegion(parser, region, args);
 }
 
 static ParseResult parseTaskReductionRegion(
     OpAsmParser &parser, Region &region,
-    ReductionModifierAttr &taskReductionMod,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &taskReductionVars,
     SmallVectorImpl<Type> &taskReductionTypes,
     DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
   AllRegionParseArgs args;
-  args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
-                                 taskReductionTypes, taskReductionByref,
-                                 taskReductionSyms);
+  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
+                                 taskReductionByref, taskReductionSyms);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -835,15 +831,14 @@ struct PrivatePrintArgs {
       : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
-  ReductionModifierAttr reductionMod;
   ValueRange vars;
   TypeRange types;
   DenseBoolArrayAttr byref;
   ArrayAttr syms;
-  ReductionPrintArgs(ReductionModifierAttr reductionMod, ValueRange vars,
-                     TypeRange types, DenseBoolArrayAttr byref, ArrayAttr syms)
-      : reductionMod(reductionMod), vars(vars), types(types), byref(byref),
-        syms(syms) {}
+  ReductionModifierAttr *modifier;
+  ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
+                     ArrayAttr syms, ReductionModifierAttr *mod = nullptr)
+      : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
 };
 struct AllRegionPrintArgs {
   std::optional<MapPrintArgs> hostEvalArgs;
@@ -862,15 +857,14 @@ static void printClauseWithRegionArgs(
     ValueRange argsSubrange, ValueRange operands, TypeRange types,
     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
     DenseBoolArrayAttr byref = nullptr,
-    ReductionModifierAttr reductionMod = nullptr) {
+    ReductionModifierAttr *modifier = nullptr) {
   if (argsSubrange.empty())
     return;
 
   p << clauseName << "(";
 
-  if (reductionMod) {
-    p << "Id: " << stringifyReductionModifier(reductionMod.getValue()) << ", ";
-  }
+  if (modifier && *modifier)
+    p << "mod: " << stringifyReductionModifier(modifier->getValue()) << ", ";
 
   if (!symbols) {
     llvm::SmallVector<Attribute> values(operands.size(), nullptr);
@@ -929,10 +923,10 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
                     ValueRange argsSubrange,
                     std::optional<ReductionPrintArgs> reductionArgs) {
   if (reductionArgs)
-    printClauseWithRegionArgs(
-        p, ctx, clauseName, argsSubrange, reductionArgs->vars,
-        reductionArgs->types, reductionArgs->syms, /*mapIndices=*/nullptr,
-        reductionArgs->byref, reductionArgs->reductionMod);
+    printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
+                              reductionArgs->vars, reductionArgs->types,
+                              reductionArgs->syms, /*mapIndices=*/nullptr,
+                              reductionArgs->byref, reductionArgs->modifier);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -964,53 +958,47 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
 
 static void printHostEvalInReductionMapPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
-    TypeRange hostEvalTypes, ReductionModifierAttr inReductionMod,
-    ValueRange inReductionVars, TypeRange inReductionTypes,
-    DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
-    ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
-    TypeRange privateTypes, ArrayAttr privateSyms,
+    TypeRange hostEvalTypes, ValueRange inReductionVars,
+    TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
+    ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
+    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
     DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
-  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
-                               inReductionTypes, inReductionByref,
-                               inReductionSyms);
+  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+                               inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printInReductionPrivateRegion(
-    OpAsmPrinter &p, Operation *op, Region &region,
-    ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
+    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
     ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
-  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
-                               inReductionTypes, inReductionByref,
-                               inReductionSyms);
+  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+                               inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printInReductionPrivateReductionRegion(
-    OpAsmPrinter &p, Operation *op, Region &region,
-    ReductionModifierAttr inReductionMod, ValueRange inReductionVars,
+    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
     ArrayAttr privateSyms, ReductionModifierAttr reductionMod,
     ValueRange reductionVars, TypeRange reductionTypes,
     DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
-  args.inReductionArgs.emplace(inReductionMod, inReductionVars,
-                               inReductionTypes, inReductionByref,
-                               inReductionSyms);
+  args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
+                               inReductionByref, inReductionSyms);
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
-  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
-                             reductionByref, reductionSyms);
+  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+                             reductionSyms, &reductionMod);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1032,22 +1020,20 @@ static void printPrivateReductionRegion(
   AllRegionPrintArgs args;
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
-  args.reductionArgs.emplace(reductionMod, reductionVars, reductionTypes,
-                             reductionByref, reductionSyms);
+  args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
+                             reductionSyms, &reductionMod);
   printBlockArgRegion(p, op, region, args);
 }
 
 static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op,
                                      Region &region,
-                                     ReductionModifierAttr taskReductionMod,
                                      ValueRange taskReductionVars,
                                      TypeRange taskReductionTypes,
                                      DenseBoolArrayAttr taskReductionByref,
                                      ArrayAttr taskReductionSyms) {
   AllRegionPrintArgs args;
-  args.taskReductionArgs.emplace(taskReductionMod, taskReductionVars,
-                                 taskReductionTypes, taskReductionByref,
-                                 taskReductionSyms);
+  args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
+                                 taskReductionByref, taskReductionSyms);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1764,7 +1750,6 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
                   clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
                   clauses.hostEvalVars, clauses.ifExpr,
-                  /*in_reduction_mod=*/nullptr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -2540,8 +2525,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                clauses.final, clauses.ifExpr, clauses.inReductionMod,
-                clauses.inReductionVars,
+                clauses.final, clauses.ifExpr, clauses.inReductionVars,
                 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
                 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
                 clauses.priority, /*private_vars=*/clauses.privateVars,
@@ -2567,8 +2551,7 @@ void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
                         const TaskgroupOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   TaskgroupOp::build(builder, state, clauses.allocateVars,
-                     clauses.allocatorVars, clauses.taskReductionMod,
-                     clauses.taskReductionVars,
+                     clauses.allocatorVars, clauses.taskReductionVars,
                      makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
                      makeArrayAttr(ctx, clauses.taskReductionSyms));
 }
@@ -2589,8 +2572,7 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms.
   TaskloopOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionMod,
-      clauses.inReductionVars,
+      clauses.final, clauses.grainsize, clauses.ifExpr, clauses.inReductionVars,
       makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
       makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
       clauses.nogroup, clauses.numTasks, clauses.priority, /*private_vars=*/{},
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d21c3e513277f3..8d095287db064e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1840,7 +1840,7 @@ combiner {
 func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
   %test2f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+  omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
   // expected-error @below {{List item should appear in REDUCTION clause of the parent}}
        omp.scan inclusive(%test2f32 : !llvm.ptr)
@@ -1866,7 +1866,7 @@ combiner {
 
 func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.wsloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+  omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
   // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
        omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr)
@@ -1892,7 +1892,7 @@ combiner {
 
 func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.taskloop reduction(Id:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+  omp.taskloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
   // expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}}
        omp.scan inclusive(%test1f32 : !llvm.ptr)
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 25610d030d3f1c..2c8e9ad3fe0700 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -904,16 +904,16 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
 func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
-  // CHECK: reduction(Id: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
-  omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+  // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+  omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
        // CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr)
        omp.scan inclusive(%0 : !llvm.ptr)
        omp.yield
     }
   }
-  // CHECK: reduction(Id: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
-  omp.wsloop reduction(Id:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+  // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+  omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
        // CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr)
        omp.scan exclusive(%0 : !llvm.ptr)

>From 5c984a048eefe479f664d287c32ea55e0d12128f Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Fri, 10 Jan 2025 14:15:49 -0600
Subject: [PATCH 4/6] R4: Adding memory write side effects to ScanOp and
 addressing other review comments

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  9 +++---
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |  6 ++--
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 29 ++++++++++-------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  4 +++
 mlir/test/Dialect/OpenMP/invalid.mlir         |  6 ++--
 mlir/test/Dialect/OpenMP/ops.mlir             | 12 +++----
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 32 +++++++++++++++++++
 8 files changed, 71 insertions(+), 29 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d71dc70ed2d963..a8d97a36df79ee 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1188,10 +1188,11 @@ class OpenMP_ReductionClauseSkip<
 
   // Description varies depending on the operation.
   let description = [{
-    Reductions can be performed by specifying reduction accumulator variables in
-    `reduction_vars`, symbols referring to reduction declarations in the
-    `reduction_syms` attribute, and whether the reduction variable should be
-    passed into the reduction region by value or by reference in
+    Reductions can be performed by specifying the reduction modifier
+    (`defaultmod`, `inscan` or `task`) in `reduction_mod`, reduction accumulator
+    variables in `reduction_vars`, symbols referring to reduction declarations
+    in the `reduction_syms` attribute, and whether the reduction variable
+    should be passed into the reduction region by value or by reference in
     `reduction_byref`. Each reduction is identified by the accumulator it uses
     and accumulators must not be repeated in the same reduction. A private
     variable corresponding to the accumulator is used in place of the
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 22777b4f0964b3..4fa7630166f5f8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -183,9 +183,9 @@ def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
 // reduction_modifier enum.
 //===----------------------------------------------------------------------===//
 
-def ReductionModifierDefault : I32EnumAttrCase<"Default", 0>;
-def ReductionModifierInScan : I32EnumAttrCase<"InScan", 1>;
-def ReductionModifierTask : I32EnumAttrCase<"Task", 2>;
+def ReductionModifierDefault : I32EnumAttrCase<"defaultmod", 0>;
+def ReductionModifierInScan : I32EnumAttrCase<"inscan", 1>;
+def ReductionModifierTask : I32EnumAttrCase<"task", 2>;
 
 def ReductionModifier : OpenMP_I32EnumAttr<
     "ReductionModifier",
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index e653d0cf7f2f99..580c9c6ef6fde8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1707,7 +1707,7 @@ def CancellationPointOp : OpenMP_Op<"cancellation_point", clauses = [
 }
 
 def ScanOp : OpenMP_Op<"scan", [
-    AttrSizedOperandSegments
+    AttrSizedOperandSegments, MemoryEffects<[MemWrite]>
   ], clauses = [
     OpenMP_InclusiveClause, OpenMP_ExclusiveClause]> {
   let summary = "scan directive";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 65b6496ea5b31d..a8bf501ef28e6e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -535,8 +535,8 @@ static ParseResult parseClauseWithRegionArgs(
   if (parser.parseLParen())
     return failure();
 
-  StringRef enumStr;
-  if (succeeded(parser.parseOptionalKeyword("mod"))) {
+  if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
+    StringRef enumStr;
     if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
         parser.parseComma())
       return failure();
@@ -3169,28 +3169,33 @@ LogicalResult ScanOp::verify() {
   }
   const OperandRange &scanVars =
       hasExclusiveVars() ? getExclusiveVars() : getInclusiveVars();
-  auto verifyScanVarsInReduction = [&scanVars](OperandRange reductionVars) {
+  auto verifyScanVarsInReduction = [&scanVars](ValueRange reductionVars) {
     for (const auto &it : scanVars)
       if (!llvm::is_contained(reductionVars, it))
         return false;
     return true;
   };
-  if (mlir::omp::WsloopOp parentOp =
+  if (mlir::omp::WsloopOp parentWsLoopOp =
           (*this)->getParentOfType<mlir::omp::WsloopOp>()) {
-    if (parentOp.getReductionModAttr() &&
-        parentOp.getReductionModAttr().getValue() ==
-            mlir::omp::ReductionModifier::InScan) {
-      if (!verifyScanVarsInReduction(parentOp.getReductionVars())) {
+    if (parentWsLoopOp.getReductionModAttr() &&
+        parentWsLoopOp.getReductionModAttr().getValue() ==
+            mlir::omp::ReductionModifier::inscan) {
+      auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
+          parentWsLoopOp.getOperation());
+      if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
         return emitError(
             "List item should appear in REDUCTION clause of the parent");
       }
       return success();
     }
-  } else if (mlir::omp::SimdOp parentOp =
+  } else if (mlir::omp::SimdOp parentSimdOp =
                  (*this)->getParentOfType<mlir::omp::SimdOp>()) {
-    if (parentOp.getReductionModAttr().getValue() ==
-        mlir::omp::ReductionModifier::InScan) {
-      if (!verifyScanVarsInReduction(parentOp.getReductionVars())) {
+    if (parentSimdOp.getReductionModAttr() &&
+        parentSimdOp.getReductionModAttr().getValue() ==
+            mlir::omp::ReductionModifier::inscan) {
+      auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
+          parentSimdOp.getOperation());
+      if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
         return emitError(
             "List item should appear in REDUCTION clause of the parent");
       }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0be515e63b470c..3c456516eeea51 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4591,6 +4591,10 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
       .Case([&](omp::AtomicCaptureOp op) {
         return convertOmpAtomicCapture(op, builder, moduleTranslation);
       })
+      .Case([&](omp::ScanOp) {
+        return op->emitError()
+               << "not yet implemented: " << op->getName() << " operation";
+      })
       .Case([&](omp::SectionsOp) {
         return convertOmpSections(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8d095287db064e..e3e6397d84d77c 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1840,7 +1840,7 @@ combiner {
 func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
   %test2f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+  omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
   // expected-error @below {{List item should appear in REDUCTION clause of the parent}}
        omp.scan inclusive(%test2f32 : !llvm.ptr)
@@ -1866,7 +1866,7 @@ combiner {
 
 func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.wsloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+  omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
   // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
        omp.scan inclusive(%test1f32 : !llvm.ptr) exclusive(%test1f32: !llvm.ptr)
@@ -1892,7 +1892,7 @@ combiner {
 
 func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.taskloop reduction(mod:InScan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+  omp.taskloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
     omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
   // expected-error @below {{SCAN directive needs to be enclosed within a parent worksharing loop construct or SIMD construct with INSCAN reduction modifier}}
        omp.scan inclusive(%test1f32 : !llvm.ptr)
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 2c8e9ad3fe0700..c1259fabe82fba 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -904,19 +904,19 @@ func.func @wsloop_reduction(%lb : index, %ub : index, %step : index) {
 func.func @wsloop_inscan_reduction(%lb : index, %ub : index, %step : index) {
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
-  // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
-  omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+  // CHECK: reduction(mod: inscan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+  omp.wsloop reduction(mod:inscan, @add_f32 %0 -> %prv : !llvm.ptr) {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
        // CHECK: omp.scan inclusive(%{{.*}} : !llvm.ptr)
-       omp.scan inclusive(%0 : !llvm.ptr)
+       omp.scan inclusive(%prv : !llvm.ptr)
        omp.yield
     }
   }
-  // CHECK: reduction(mod: InScan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
-  omp.wsloop reduction(mod:InScan, @add_f32 %0 -> %prv : !llvm.ptr) {
+  // CHECK: reduction(mod: inscan, @add_f32 %{{.+}} -> %[[PRV:.+]] : !llvm.ptr)
+  omp.wsloop reduction(mod:inscan, @add_f32 %0 -> %prv : !llvm.ptr) {
     omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
        // CHECK: omp.scan exclusive(%{{.*}} : !llvm.ptr)
-       omp.scan exclusive(%0 : !llvm.ptr)
+       omp.scan exclusive(%prv : !llvm.ptr)
        omp.yield
     }
   }
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 392a6558dcfa69..d5c43322c03469 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -186,6 +186,38 @@ llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // -----
 
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+  %2 = llvm.load %arg3 : !llvm.ptr -> f32
+  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+  omp.yield
+}
+llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
+  omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      // expected-error@below {{not yet implemented: omp.scan operation}}
+      // expected-error@below {{LLVM Translation failed for operation: omp.scan}}
+      omp.scan inclusive(%prv : !llvm.ptr)
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @single_allocate(%x : !llvm.ptr) {
   // expected-error at below {{not yet implemented: Unhandled clause allocate in omp.single operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.single}}

>From 2adadc13e18269e08a4a2c93b77d8f2c62cda1ce Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 15 Jan 2025 13:17:44 -0600
Subject: [PATCH 5/6] R5: Removing the check for matching scan variables and
 reduction variables

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 25 +++---------------
 mlir/test/Dialect/OpenMP/invalid.mlir        | 27 --------------------
 2 files changed, 3 insertions(+), 49 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a8bf501ef28e6e..aaa0eb8455d2a5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3167,38 +3167,19 @@ LogicalResult ScanOp::verify() {
     return emitError(
         "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
   }
-  const OperandRange &scanVars =
-      hasExclusiveVars() ? getExclusiveVars() : getInclusiveVars();
-  auto verifyScanVarsInReduction = [&scanVars](ValueRange reductionVars) {
-    for (const auto &it : scanVars)
-      if (!llvm::is_contained(reductionVars, it))
-        return false;
-    return true;
-  };
   if (mlir::omp::WsloopOp parentWsLoopOp =
           (*this)->getParentOfType<mlir::omp::WsloopOp>()) {
     if (parentWsLoopOp.getReductionModAttr() &&
         parentWsLoopOp.getReductionModAttr().getValue() ==
             mlir::omp::ReductionModifier::inscan) {
-      auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
-          parentWsLoopOp.getOperation());
-      if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
-        return emitError(
-            "List item should appear in REDUCTION clause of the parent");
-      }
       return success();
     }
-  } else if (mlir::omp::SimdOp parentSimdOp =
-                 (*this)->getParentOfType<mlir::omp::SimdOp>()) {
+  }
+  if (mlir::omp::SimdOp parentSimdOp =
+          (*this)->getParentOfType<mlir::omp::SimdOp>()) {
     if (parentSimdOp.getReductionModAttr() &&
         parentSimdOp.getReductionModAttr().getValue() ==
             mlir::omp::ReductionModifier::inscan) {
-      auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(
-          parentSimdOp.getOperation());
-      if (!verifyScanVarsInReduction(iface.getReductionBlockArgs())) {
-        return emitError(
-            "List item should appear in REDUCTION clause of the parent");
-      }
       return success();
     }
   }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e3e6397d84d77c..514792e425f89e 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1837,33 +1837,6 @@ combiner {
   omp.yield (%1 : f32)
 }
 
-func.func @scan_test_1(%lb: i32, %ub: i32, %step: i32) {
-  %test1f32 = "test.f32"() : () -> (!llvm.ptr)
-  %test2f32 = "test.f32"() : () -> (!llvm.ptr)
-  omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
-    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
-  // expected-error @below {{List item should appear in REDUCTION clause of the parent}}
-       omp.scan inclusive(%test2f32 : !llvm.ptr)
-        omp.yield
-    }
-  }
-  return
-}
-
-// -----
-
-omp.declare_reduction @add_f32 : f32
-init {
- ^bb0(%arg: f32):
-  %0 = arith.constant 0.0 : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-  ^bb1(%arg0: f32, %arg1: f32):
-  %1 = arith.addf %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-
 func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
   omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {

>From c7a03e52842585d121013a4652a821e75966f917 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 21 Jan 2025 14:59:25 -0600
Subject: [PATCH 6/6] R6: Addressing review comments

---
 .../mlir/Dialect/OpenMP/OpenMPEnums.td        |  4 +--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 31 +++++++------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 15 ++++-----
 mlir/test/Dialect/OpenMP/invalid.mlir         | 26 ++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  3 +-
 5 files changed, 49 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 4fa7630166f5f8..690e3df1f685e3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -184,14 +184,14 @@ def OrderModifierAttr : EnumAttr<OpenMP_Dialect, OrderModifier,
 //===----------------------------------------------------------------------===//
 
 def ReductionModifierDefault : I32EnumAttrCase<"defaultmod", 0>;
-def ReductionModifierInScan : I32EnumAttrCase<"inscan", 1>;
+def ReductionModifierInscan : I32EnumAttrCase<"inscan", 1>;
 def ReductionModifierTask : I32EnumAttrCase<"task", 2>;
 
 def ReductionModifier : OpenMP_I32EnumAttr<
     "ReductionModifier",
     "reduction modifier", [
       ReductionModifierDefault,
-      ReductionModifierInScan,
+      ReductionModifierInscan,
       ReductionModifierTask
     ]>;
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index aaa0eb8455d2a5..3e0ec356bfc368 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -835,9 +835,9 @@ struct ReductionPrintArgs {
   TypeRange types;
   DenseBoolArrayAttr byref;
   ArrayAttr syms;
-  ReductionModifierAttr *modifier;
+  ReductionModifierAttr modifier;
   ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
-                     ArrayAttr syms, ReductionModifierAttr *mod = nullptr)
+                     ArrayAttr syms, ReductionModifierAttr mod = nullptr)
       : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
 };
 struct AllRegionPrintArgs {
@@ -857,14 +857,14 @@ static void printClauseWithRegionArgs(
     ValueRange argsSubrange, ValueRange operands, TypeRange types,
     ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
     DenseBoolArrayAttr byref = nullptr,
-    ReductionModifierAttr *modifier = nullptr) {
+    ReductionModifierAttr modifier = nullptr) {
   if (argsSubrange.empty())
     return;
 
   p << clauseName << "(";
 
-  if (modifier && *modifier)
-    p << "mod: " << stringifyReductionModifier(modifier->getValue()) << ", ";
+  if (modifier)
+    p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
 
   if (!symbols) {
     llvm::SmallVector<Attribute> values(operands.size(), nullptr);
@@ -998,7 +998,7 @@ static void printInReductionPrivateReductionRegion(
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms, &reductionMod);
+                             reductionSyms, reductionMod);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -1021,7 +1021,7 @@ static void printPrivateReductionRegion(
   args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
                            /*mapIndices=*/nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
-                             reductionSyms, &reductionMod);
+                             reductionSyms, reductionMod);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -3163,26 +3163,19 @@ void ScanOp::build(OpBuilder &builder, OperationState &state,
 }
 
 LogicalResult ScanOp::verify() {
-  if (hasExclusiveVars() && hasInclusiveVars()) {
+  if (hasExclusiveVars() == hasInclusiveVars())
     return emitError(
         "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
-  }
-  if (mlir::omp::WsloopOp parentWsLoopOp =
-          (*this)->getParentOfType<mlir::omp::WsloopOp>()) {
+  if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>())
     if (parentWsLoopOp.getReductionModAttr() &&
         parentWsLoopOp.getReductionModAttr().getValue() ==
-            mlir::omp::ReductionModifier::inscan) {
+            mlir::omp::ReductionModifier::inscan)
       return success();
-    }
-  }
-  if (mlir::omp::SimdOp parentSimdOp =
-          (*this)->getParentOfType<mlir::omp::SimdOp>()) {
+  if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>())
     if (parentSimdOp.getReductionModAttr() &&
         parentSimdOp.getReductionModAttr().getValue() ==
-            mlir::omp::ReductionModifier::inscan) {
+            mlir::omp::ReductionModifier::inscan)
       return success();
-    }
-  }
   return emitError("SCAN directive needs to be enclosed within a parent "
                    "worksharing loop construct or SIMD construct with INSCAN "
                    "reduction modifier");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3c456516eeea51..7efead79ba9ee7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -241,9 +241,13 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     }
   };
   auto checkReduction = [&todo](auto op, LogicalResult &result) {
-    if (!op.getReductionVars().empty() || op.getReductionByref() ||
-        op.getReductionSyms())
-      result = todo("reduction");
+    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
+      if (!op.getReductionVars().empty() || op.getReductionByref() ||
+          op.getReductionSyms())
+        result = todo("reduction");
+    if (op.getReductionMod() &&
+        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
+      result = todo("reduction with modifier");
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -286,6 +290,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkNowait(op, result);
       })
       .Case([&](omp::WsloopOp op) {
+        checkReduction(op, result);
         checkAllocate(op, result);
         checkLinear(op, result);
         checkOrder(op, result);
@@ -4591,10 +4596,6 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
       .Case([&](omp::AtomicCaptureOp op) {
         return convertOmpAtomicCapture(op, builder, moduleTranslation);
       })
-      .Case([&](omp::ScanOp) {
-        return op->emitError()
-               << "not yet implemented: " << op->getName() << " operation";
-      })
       .Case([&](omp::SectionsOp) {
         return convertOmpSections(*op, builder, moduleTranslation);
       })
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 514792e425f89e..8ee40208b14e1a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1837,6 +1837,32 @@ combiner {
   omp.yield (%1 : f32)
 }
 
+func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
+  %test1f32 = "test.f32"() : () -> (!llvm.ptr)
+  omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
+    omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
+  // expected-error @below {{Exactly one of EXCLUSIVE or INCLUSIVE clause is expected}}
+       omp.scan
+        omp.yield
+    }
+  }
+  return
+}
+
+// -----
+
+omp.declare_reduction @add_f32 : f32
+init {
+ ^bb0(%arg: f32):
+  %0 = arith.constant 0.0 : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+  ^bb1(%arg0: f32, %arg1: f32):
+  %1 = arith.addf %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+
 func.func @scan_test_2(%lb: i32, %ub: i32, %step: i32) {
   %test1f32 = "test.f32"() : () -> (!llvm.ptr)
   omp.wsloop reduction(mod:inscan, @add_f32 %test1f32 -> %arg1 : !llvm.ptr) {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d5c43322c03469..30833474256a4e 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -204,11 +204,10 @@ atomic {
   omp.yield
 }
 llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
   omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
     omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      // expected-error@below {{not yet implemented: omp.scan operation}}
-      // expected-error@below {{LLVM Translation failed for operation: omp.scan}}
       omp.scan inclusive(%prv : !llvm.ptr)
       omp.yield
     }



More information about the Mlir-commits mailing list