[Mlir-commits] [flang] [mlir] [MLIR][LLVMIR] Adding scan lowering to llvm on the mlir side (PR #165788)

Anchu Rajendran S llvmlistbot at llvm.org
Thu Oct 30 15:09:01 PDT 2025


https://github.com/anchuraj created https://github.com/llvm/llvm-project/pull/165788

Scan reductions are supported in OpenMP with the help of the `scan` directive. The reduction clause of the worksharing-loop (`for`/`do`) and `simd` directives takes an `inscan` modifier when a scan reduction is specified. With an `inscan` modifier, the body of the directive must contain a `scan` directive. This PR implements the lowering logic for scan reductions in OpenMP worksharing loops. OpenMPIRBuilder support can be found in https://github.com/llvm/llvm-project/pull/136035. Support for nested loops and the `exclusive` clause is not included in this PR.



>From cb38aae937f89b0c77aed5f147998f02db7c2d44 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 30 Oct 2025 17:01:59 -0500
Subject: [PATCH] [MLIR][LLVMIR] Adding scan lowering to llvm on the mlir side

---
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  39 +-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 380 ++++++++++++++----
 .../Target/LLVMIR/openmp-reduction-scan.mlir  | 130 ++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  38 +-
 4 files changed, 503 insertions(+), 84 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f86ee01355104..5d82466889b1e 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -2326,12 +2326,41 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
 static mlir::omp::ScanOp
 genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
-          semantics::SemanticsContext &semaCtx, mlir::Location loc,
-          const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+          semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+          mlir::Location loc, const ConstructQueue &queue,
+          ConstructQueue::const_iterator item) {
   mlir::omp::ScanOperands clauseOps;
   genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
-  return mlir::omp::ScanOp::create(converter.getFirOpBuilder(),
-                                   converter.getCurrentLocation(), clauseOps);
+  mlir::omp::ScanOp scanOp = mlir::omp::ScanOp::create(
+      converter.getFirOpBuilder(), converter.getCurrentLocation(), clauseOps);
+  // OpenMP forbids intra-iteration dependences from statements preceding a
+  // scan directive to statements following it, except through the list items
+  // of its inclusive or exclusive clause. Re-assign the DO variable from the
+  // loop nest's induction variable after the scan so that the index value is
+  // not carried across the scan directive.
+  // TODO: Scan directives inside nested loops are not handled yet.
+  mlir::omp::LoopNestOp loopNestOp =
+      scanOp->getParentOfType<mlir::omp::LoopNestOp>();
+  assert(loopNestOp && loopNestOp.getNumLoops() == 1 &&
+         "Scan directive is only supported inside a single non-nested loop.");
+  mlir::Region &region = loopNestOp->getRegion(0);
+  mlir::Value indexVal = fir::getBase(region.getArgument(0));
+  lower::pft::Evaluation *doConstructEval = eval.parentConstruct;
+  assert(doConstructEval && "Scan directive must be nested in a DO construct");
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  lower::pft::Evaluation *doLoop = &doConstructEval->getFirstNestedEvaluation();
+  auto *doStmt = doLoop->getIf<parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  const auto &loopControl =
+      std::get<std::optional<parser::LoopControl>>(doStmt->t);
+  assert(loopControl && "Expected the do loop to have loop control");
+  const parser::LoopControl::Bounds *bounds =
+      std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+  assert(bounds && "Expected the do loop to have iteration bounds");
+  mlir::Operation *storeOp =
+      setLoopVar(converter, loc, indexVal, bounds->name.thing.symbol);
+  firOpBuilder.setInsertionPointAfter(storeOp);
+  return scanOp;
 }
 
 static mlir::omp::SectionsOp
@@ -3416,7 +3445,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
                                   loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_scan:
-    newOp = genScanOp(converter, symTable, semaCtx, loc, queue, item);
+    newOp = genScanOp(converter, symTable, semaCtx, eval, loc, queue, item);
     break;
   case llvm::omp::Directive::OMPD_section:
     llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1e2099d6cc1b2..eb4378dd6f719 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -37,6 +37,7 @@
 #include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
+#include <cassert>
 #include <cstdint>
 #include <iterator>
 #include <numeric>
@@ -77,6 +78,22 @@ class OpenMPAllocaStackFrame
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
 };
 
+/// ModuleTranslation stack frame for omp.parallel. It holds the alloca
+/// insertion point in effect just before the parallel region is created, where
+/// variables that must be shared by the threads executing the region can be
+/// allocated. Lowering of scan reductions uses it to allocate the shared
+/// pointers to the temporary scan buffers.
+class OpenMPParallelAllocaStackFrame
+    : public StateStackFrameBase<OpenMPParallelAllocaStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPParallelAllocaStackFrame)
+
+  explicit OpenMPParallelAllocaStackFrame(
+      llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
+      : allocaInsertPoint(allocaIP) {}
+  llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
+};
+
 /// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
 /// collapsed canonical loop information corresponding to an \c omp.loop_nest
 /// operation.
@@ -84,7 +101,19 @@ class OpenMPLoopInfoStackFrame
     : public StateStackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  /// Canonical loops generated for the loop nest. For constructs like scan a
+  /// single omp.loop_nest is split into an input loop and a scan loop, so one
+  /// frame may hold several loops.
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
+  /// Scan bookkeeping from OpenMPIRBuilder::scanInfoInitialize(); only set
+  /// for loops with an `inscan` reduction, otherwise null.
+  llvm::ScanInfo *scanInfo = nullptr;
+  /// LLVM types of the private reduction variables, recorded when the
+  /// reduction clause is translated and looked up when omp.scan is translated.
+  /// The map lives in the frame so it is released together with it.
+  llvm::DenseMap<llvm::Value *, llvm::Type *> reductionVarToTypeStorage;
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      &reductionVarToTypeStorage;
 };
 
 /// Custom error class to signal translation errors that don't need reporting,
@@ -323,6 +346,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDistScheduleChunkSize())
       result = todo("dist_schedule with chunk_size");
   };
+  auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+    if (!op.getExclusiveVars().empty())
+      result = todo("exclusive");
+  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -371,9 +398,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       if (!op.getReductionVars().empty() || op.getReductionByref() ||
           op.getReductionSyms())
         result = todo("reduction");
-    if (op.getReductionMod() &&
-        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
-      result = todo("reduction with modifier");
+    if (op.getReductionMod()) {
+      if (isa<omp::WsloopOp>(op)) {
+        if (op.getReductionMod().value() == omp::ReductionModifier::task)
+          result = todo("reduction with task modifier");
+      } else {
+        result = todo("reduction with modifier");
+      }
+    }
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -397,6 +429,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+      .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
       .Case([&](omp::SectionsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
@@ -531,15 +564,59 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
 /// Find the loop information structure for the loop nest being translated. It
 /// will return a `null` value unless called from the translation function for
 /// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        loopInfo = frame.loopInfo;
+        loopInfos = frame.loopInfos;
         return WalkResult::interrupt();
       });
-  return loopInfo;
+  return loopInfos;
+}
+
+/// Find the ScanInfo of the innermost loop nest being translated. It is
+/// created by scanInfoInitialize() when the enclosing worksharing loop has an
+/// `inscan` reduction modifier and is consumed when the omp.scan directive in
+/// its body is translated. Returns null if there is no enclosing loop frame.
+static llvm::ScanInfo *
+findScanInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::ScanInfo *scanInfo = nullptr;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        scanInfo = frame.scanInfo;
+        return WalkResult::interrupt();
+      });
+  return scanInfo;
+}
+
+/// Find the map from private reduction variables to their LLVM types for the
+/// innermost loop nest being translated. It is filled in when the reduction
+/// clause is translated and queried when the omp.scan directive in the loop
+/// body is translated. Returns null if there is no enclosing loop frame.
+static llvm::DenseMap<llvm::Value *, llvm::Type *> *
+findReductionVarTypes(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType = nullptr;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        reductionVarToType = frame.reductionVarToType;
+        return WalkResult::interrupt();
+      });
+  return reductionVarToType;
+}
+
+/// Find the alloca insertion point recorded by the innermost enclosing
+/// omp.parallel; scan reductions allocate their shared buffer pointers there.
+/// Returns an unset insertion point when not nested inside omp.parallel.
+static llvm::OpenMPIRBuilder::InsertPointTy
+findParallelAllocaIP(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder::InsertPointTy parallelAllocaIP;
+  moduleTranslation.stackWalk<OpenMPParallelAllocaStackFrame>(
+      [&](OpenMPParallelAllocaStackFrame &frame) {
+        parallelAllocaIP = frame.allocaInsertPoint;
+        return WalkResult::interrupt();
+      });
+  return parallelAllocaIP;
 }
 
 /// Converts the given region that appears within an OpenMP dialect operation to
@@ -1254,11 +1331,17 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
   for (auto [data, addr] : deferredStores)
     builder.CreateStore(data, addr);
 
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   // Before the loop, store the initial values of reductions into reduction
   // variables. Although this could be done after allocas, we don't want to mess
   // up with the alloca insertion point.
   for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
     SmallVector<llvm::Value *, 1> phis;
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
 
     // map block argument to initializer region
     mapInitializationArgs(op, moduleTranslation, reductionDecls,
@@ -1330,15 +1413,20 @@ static void collectReductionInfo(
 
   // Collect the reduction information.
   reductionInfos.reserve(numReductions);
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
   for (unsigned i = 0; i < numReductions; ++i) {
     llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
     if (owningAtomicReductionGens[i])
       atomicGen = owningAtomicReductionGens[i];
     llvm::Value *variable =
         moduleTranslation.lookupValue(loop.getReductionVars()[i]);
+    llvm::Type *reductionType =
+        moduleTranslation.convertType(reductionDecls[i].getType());
+    if (reductionVarToType != nullptr)
+      (*reductionVarToType)[privateReductionVariables[i]] = reductionType;
     reductionInfos.push_back(
-        {moduleTranslation.convertType(reductionDecls[i].getType()), variable,
-         privateReductionVariables[i],
+        {reductionType, variable, privateReductionVariables[i],
          /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
          owningReductionGens[i],
          /*ReductionGenClang=*/nullptr, atomicGen});
@@ -2543,6 +2631,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
   bool isSimd = wsloopOp.getScheduleSimd();
   bool loopNeedsBarrier = !wsloopOp.getNowait();
+  bool isInScanRegion =
+      wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                     mlir::omp::ReductionModifier::inscan);
 
   // The only legal way for the direct parent to be omp.distribute is that this
   // represents 'distribute parallel do'. Otherwise, this is a regular
@@ -2574,20 +2665,87 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(regionBlock, opInst)))
     return failure();

-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+
+  // Applies the worksharing-loop schedule to one canonical loop and emits the
+  // linear-clause code around it. For an `inscan` reduction the loop nest is
+  // split into an input loop and a scan loop: linear variables are initialized
+  // only for the input loop and finalized only after the scan loop, and the
+  // cancellation finalization callback is popped only for the scan loop.
+  const auto &&wsloopCodeGen = [&](llvm::CanonicalLoopInfo *loopInfo,
+                                   bool noLoopMode, bool inputScanLoop) {
+    bool emitLinearVarInit = !isInScanRegion || inputScanLoop;
+    // Emit Initialization and Update IR for linear variables
+    if (emitLinearVarInit && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                              loopInfo->getPreheader());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      builder.restoreIP(*afterBarrierIP);
+      linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                            loopInfo->getIndVar());
+      linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                        loopInfo->getExit());
+    }

-  // Emit Initialization and Update IR for linear variables
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
-                                            loopInfo->getPreheader());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+        ompBuilder->applyWorkshareLoop(
+            ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+            convertToScheduleKind(schedule), chunk, isSimd,
+            scheduleMod == omp::ScheduleModifier::monotonic,
+            scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+            workshareLoopType, noLoopMode);
+
+    if (failed(handleError(wsloopIP, opInst)))
+      return failure();
+
+    bool emitLinearVarFinalize = !isInScanRegion || !inputScanLoop;
+    // Emit finalization and in-place rewrites for linear vars.
+    if (emitLinearVarFinalize && !wsloopOp.getLinearVars().empty()) {
+      llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+      // finalizeLinearVar() needs the `lastiter` value of the loop.
+      if (!loopInfo->getLastIter()) {
+        opInst.emitError("`lastiter` in CanonicalLoopInfo is nullptr");
+        return failure();
+      }
+      llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+          linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+                                                  loopInfo->getLastIter());
+      if (failed(handleError(afterBarrierIP, *loopOp)))
+        return failure();
+      for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+        linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+                                             index);
+      builder.restoreIP(oldIP);
+    }
+    if (!inputScanLoop || !isInScanRegion)
+      popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
+
+    return llvm::success();
+  };
+
+  if (isInScanRegion) {
+    auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+    builder.restoreIP(inputLoopFinishIp);
+    SmallVector<OwningReductionGen> owningReductionGens;
+    SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+    collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+                         owningReductionGens, owningAtomicReductionGens,
+                         privateReductionVariables, reductionInfos);
+    llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+    llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+        ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos,
+                                      scanInfo);
+    if (failed(handleError(redIP, opInst)))
       return failure();
-    builder.restoreIP(*afterBarrierIP);
-    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
-                                          loopInfo->getIndVar());
-    linearClauseProcessor.outlineLinearFinalizationBB(builder,
-                                                      loopInfo->getExit());
+
+    builder.restoreIP(*redIP);
+    builder.CreateBr(cont);
   }
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2612,42 +2764,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     }
   }
 
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-      ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-          convertToScheduleKind(schedule), chunk, isSimd,
-          scheduleMod == omp::ScheduleModifier::monotonic,
-          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType, noLoopMode);
-
-  if (failed(handleError(wsloopIP, opInst)))
-    return failure();
-
-  // Emit finalization and in-place rewrites for linear vars.
-  if (!wsloopOp.getLinearVars().empty()) {
-    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
-    assert(loopInfo->getLastIter() &&
-           "`lastiter` in CanonicalLoopInfo is nullptr");
-    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
-        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
-                                                loopInfo->getLastIter());
-    if (failed(handleError(afterBarrierIP, *loopOp)))
+  // For Scan loops input loop need not pop cancellation CB and hence, it is set
+  // false for the first loop
+  bool inputScanLoop = isInScanRegion;
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    if (failed(wsloopCodeGen(loopInfo, noLoopMode, inputScanLoop)))
       return failure();
-    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
-      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
-                                           index);
-    builder.restoreIP(oldIP);
+    inputScanLoop = false;
   }
 
-  // Set the correct branch target for task cancellation
-  popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
-
-  // Process the reductions if required.
-  if (failed(createReductionsAndCleanup(
-          wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
-          privateReductionVariables, isByRef, wsloopOp.getNowait(),
-          /*isTeamsReduction=*/false)))
-    return failure();
+  // todo: change builder.saveIP to wsLoopIP
+  if (isInScanRegion) {
+    SmallVector<Region *> reductionRegions;
+    llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+                    [](omp::DeclareReductionOp reductionDecl) {
+                      return &reductionDecl.getCleanupRegion();
+                    });
+    if (failed(inlineOmpRegionCleanup(
+            reductionRegions, privateReductionVariables, moduleTranslation,
+            builder, "omp.reduction.cleanup")))
+      return failure();
+  } else {
+    // Process the reductions if required.
+    if (failed(createReductionsAndCleanup(
+            wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
+            privateReductionVariables, isByRef, wsloopOp.getNowait(),
+            /*isTeamsReduction=*/false)))
+      return failure();
+  }
 
   return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                             privateVarsInfo.llvmVars,
@@ -2815,6 +2959,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
       findAllocaInsertPoint(builder, moduleTranslation);
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  LLVM::ModuleTranslation::SaveStack<OpenMPParallelAllocaStackFrame> frame(
+      moduleTranslation, allocaIP);
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
@@ -2935,12 +3081,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-  ompBuilder->applySimd(loopInfo, alignedVars,
-                        simdOp.getIfExpr()
-                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                            : nullptr,
-                        order, simdlen, safelen);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    ompBuilder->applySimd(
+        loopInfo, alignedVars,
+        simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                           : nullptr,
+        order, simdlen, safelen);
+  }
 
   // We now need to reduce the per-simd-lane reduction variable into the
   // original variable. This works a bit differently to other reductions (e.g.
@@ -2991,6 +3140,61 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                             privateVarsInfo.privatizers);
 }

+/// Converts an omp.scan directive into LLVM IR using OpenMPIRBuilder. The
+/// scan bookkeeping and the reduction variable types come from the enclosing
+/// loop frame, filled in while translating the worksharing loop's `inscan`
+/// reduction; the scan buffer pointers are allocated at the alloca insertion
+/// point recorded by the enclosing omp.parallel.
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+               LLVM::ModuleTranslation &moduleTranslation) {
+  if (failed(checkImplementationStatus(opInst)))
+    return failure();
+  auto scanOp = cast<omp::ScanOp>(opInst);
+  bool isInclusive = scanOp.hasInclusiveVars();
+  mlir::OperandRange mlirScanVars = isInclusive ? scanOp.getInclusiveVars()
+                                                : scanOp.getExclusiveVars();
+
+  llvm::DenseMap<llvm::Value *, llvm::Type *> *reductionVarToType =
+      findReductionVarTypes(moduleTranslation);
+  llvm::ScanInfo *scanInfo = findScanInfo(moduleTranslation);
+  if (!reductionVarToType || !scanInfo)
+    return opInst.emitError(
+        "scan directive must be nested in a loop with an inscan reduction");
+
+  // An unset insertion point means there is no enclosing omp.parallel, which
+  // this lowering does not handle.
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findParallelAllocaIP(moduleTranslation);
+  if (!allocaIP.isSet())
+    return opInst.emitError(
+        "scan reduction outside of an omp.parallel region is not supported");
+
+  SmallVector<llvm::Value *> llvmScanVars;
+  SmallVector<llvm::Type *> llvmScanVarsType;
+  for (mlir::Value val : mlirScanVars) {
+    llvm::Value *llvmVal = moduleTranslation.lookupValue(val);
+    // lookup() does not insert, so an unknown list item is diagnosed rather
+    // than silently paired with a null type.
+    llvm::Type *llvmType = reductionVarToType->lookup(llvmVal);
+    if (!llvmType)
+      return opInst.emitError(
+          "scan list item is not an inscan reduction variable");
+    llvmScanVars.push_back(llvmVal);
+    llvmScanVarsType.push_back(llvmType);
+  }
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      moduleTranslation.getOpenMPBuilder()->createScan(
+          ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive,
+          scanInfo);
+  if (failed(handleError(afterIP, opInst)))
+    return failure();
+  builder.restoreIP(*afterIP);
+  return success();
+}
+
 /// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -3052,14 +3235,50 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
       computeIP = loopInfos.front()->getPreheaderIP();
     }
 
+    bool isInScanRegion = false;
+    if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>())
+      isInScanRegion =
+          wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                         mlir::omp::ReductionModifier::inscan);
+    if (isInScanRegion) {
+      // TODO: Handle nesting if Scan loop is nested in a loop
+      assert(loopOp.getNumLoops() == 1 &&
+             "Scan directive inside nested do loops is not handled yet.");
+      llvm::Expected<llvm::ScanInfo *> res = ompBuilder->scanInfoInitialize();
+      if (failed(handleError(res, *loopOp)))
+        return failure();
+      llvm::ScanInfo *scanInfo = res.get();
+      moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+          [&](OpenMPLoopInfoStackFrame &frame) {
+            frame.scanInfo = scanInfo;
+            return WalkResult::interrupt();
+          });
+      llvm::Expected<llvm::SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+          ompBuilder->createCanonicalScanLoops(
+              loc, bodyGen, lowerBound, upperBound, step,
+              /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+              scanInfo);
+
+      if (failed(handleError(loopResults, *loopOp)))
+        return failure();
+      llvm::CanonicalLoopInfo *inputLoop = loopResults.get().front();
+      llvm::CanonicalLoopInfo *scanLoop = loopResults.get().back();
+      moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+          [&](OpenMPLoopInfoStackFrame &frame) {
+            frame.loopInfos.push_back(inputLoop);
+            frame.loopInfos.push_back(scanLoop);
+            return WalkResult::interrupt();
+          });
+      builder.restoreIP(scanLoop->getAfterIP());
+      // TODO: tiling and collapse are not yet implemented for scan reduction
+      return success();
+    }
     llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
         ompBuilder->createCanonicalLoop(
             loc, bodyGen, lowerBound, upperBound, step,
             /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
     if (failed(handleError(loopResult, *loopOp)))
       return failure();
-
     loopInfos.push_back(*loopResult);
   }
 
@@ -3102,7 +3321,7 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
   assert(newTopLoopInfo && "New top loop information is missing");
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        frame.loopInfo = newTopLoopInfo;
+        frame.loopInfos.push_back(newTopLoopInfo);
         return WalkResult::interrupt();
       });
 
@@ -4965,18 +5184,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
       bool loopNeedsBarrier = false;
       llvm::Value *chunk = nullptr;
 
-      llvm::CanonicalLoopInfo *loopInfo =
-          findCurrentLoopInfo(moduleTranslation);
-      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-          ompBuilder->applyWorkshareLoop(
-              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-              convertToScheduleKind(schedule), chunk, isSimd,
-              scheduleMod == omp::ScheduleModifier::monotonic,
-              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType);
-
-      if (!wsloopIP)
-        return wsloopIP.takeError();
+      SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+          findCurrentLoopInfos(moduleTranslation);
+      for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+        llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+            ompBuilder->applyWorkshareLoop(
+                ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+                convertToScheduleKind(schedule), chunk, isSimd,
+                scheduleMod == omp::ScheduleModifier::monotonic,
+                scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+                workshareLoopType);
+
+        if (!wsloopIP)
+          return wsloopIP.takeError();
+      }
     }
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation,
@@ -6167,6 +6388,11 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
           .Case([&](omp::WsloopOp) {
             return convertOmpWsloop(*op, builder, moduleTranslation);
           })
+          .Case([&](omp::ScanOp) {
+            // convertOmpScan performs checkImplementationStatus itself, so
+            // it is not repeated here.
+            return convertOmpScan(*op, builder, moduleTranslation);
+          })
           .Case([&](omp::SimdOp) {
             return convertOmpSimd(*op, builder, moduleTranslation);
           })
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
new file mode 100644
index 0000000000000..ed04a069b998f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
@@ -0,0 +1,130 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+// CHECK-LABEL: @scan_reduction
+llvm.func @scan_reduction() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
+  %4 = llvm.mlir.constant(1 : i64) : i64
+  %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  %6 = llvm.mlir.constant(1 : i64) : i64
+  %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr
+  %8 = llvm.mlir.constant(0 : index) : i64
+  %9 = llvm.mlir.constant(1 : index) : i64
+  %10 = llvm.mlir.constant(100 : i32) : i32
+  %11 = llvm.mlir.constant(1 : i32) : i32
+  %12 = llvm.mlir.constant(0 : i32) : i32
+  %13 = llvm.mlir.constant(100 : index) : i64
+  %14 = llvm.mlir.addressof @_QFEa : !llvm.ptr
+  %15 = llvm.mlir.addressof @_QFEb : !llvm.ptr
+  omp.parallel {
+    %37 = llvm.mlir.constant(1 : i64) : i64
+    %38 = llvm.alloca %37 x i32 {bindc_name = "k", pinned} : (i64) -> !llvm.ptr
+    %39 = llvm.mlir.constant(1 : i64) : i64
+    omp.wsloop reduction(mod: inscan, @add_reduction_i32 %5 -> %arg0 : !llvm.ptr) {
+      omp.loop_nest (%arg1) : i32 = (%11) to (%10) inclusive step (%11) {
+        llvm.store %arg1, %38 : i32, !llvm.ptr
+        %40 = llvm.load %arg0 : !llvm.ptr -> i32
+        %41 = llvm.load %38 : !llvm.ptr -> i32
+        %42 = llvm.sext %41 : i32 to i64
+        %50 = llvm.getelementptr %14[%42] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+        %51 = llvm.load %50 : !llvm.ptr -> i32
+        %52 = llvm.add %40, %51 : i32
+        llvm.store %52, %arg0 : i32, !llvm.ptr
+        omp.scan inclusive(%arg0 : !llvm.ptr)
+        llvm.store %arg1, %38 : i32, !llvm.ptr
+        %53 = llvm.load %arg0 : !llvm.ptr -> i32
+        %54 = llvm.load %38 : !llvm.ptr -> i32
+        %55 = llvm.sext %54 : i32 to i64
+        %63 = llvm.getelementptr %15[%55] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+        llvm.store %53, %63 : i32, !llvm.ptr
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+  %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+  llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+  %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+  llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 {
+  %0 = llvm.mlir.constant(100 : i32) : i32
+  llvm.return %0 : i32
+}
+//CHECK: %vla = alloca ptr, align 8
+//CHECK: omp_parallel
+//CHECK: store ptr %vla, ptr %gep_vla, align 8
+//CHECK: @__kmpc_fork_call
+//CHECK: void @scan_reduction..omp_par
+//CHECK: %[[BUFF_PTR:.+]] = load ptr, ptr %gep_vla
+//CHECK: @__kmpc_masked
+//CHECK: @__kmpc_barrier
+//CHECK: @__kmpc_masked
+//CHECK: @__kmpc_barrier
+//CHECK: omp.scan.loop.cont:
+//CHECK: @__kmpc_masked
+//CHECK: @__kmpc_barrier
+//CHECK: %[[FREE_VAR:.+]] = load ptr, ptr %[[BUFF_PTR]], align 8
+//CHECK:  %[[ARRLAST:.+]] = getelementptr inbounds i32, ptr %[[FREE_VAR]], i32 100
+//CHECK:  %[[RES:.+]] = load i32, ptr %[[ARRLAST]], align 4
+//CHECK:  store i32 %[[RES]], ptr %loadgep{{.*}}, align 4
+//CHECK: tail call void @free(ptr %[[FREE_VAR]])
+//CHECK: @__kmpc_end_masked
+//CHECK: omp.inscan.dispatch{{.*}}:                            ; preds = %omp_loop.body{{.*}}
+//CHECK:   %[[BUFFVAR:.+]] = load ptr, ptr %[[BUFF_PTR]], align 8
+//CHECK:   %[[arrayOffset1:.+]] = getelementptr inbounds i32, ptr %[[BUFFVAR]], i32 %{{.*}}
+//CHECK:   %[[BUFFVAL1:.+]] = load i32, ptr %[[arrayOffset1]], align 4
+//CHECK:   store i32 %[[BUFFVAL1]], ptr %{{.*}}, align 4
+//CHECK:   %[[LOG:.+]] = call double @llvm.log2.f64(double 1.000000e+02) #0
+//CHECK:   %[[CEIL:.+]] = call double @llvm.ceil.f64(double %[[LOG]]) #0
+//CHECK:   %[[UB:.+]] = fptoui double %[[CEIL]] to i32
+//CHECK:   br label %omp.outer.log.scan.body
+//CHECK: omp.outer.log.scan.body:
+//CHECK:   %[[K:.+]] = phi i32 [ 0, %{{.*}} ], [ %[[NEXTK:.+]], %omp.inner.log.scan.exit ]
+//CHECK:   %[[I:.+]] = phi i32 [ 1, %{{.*}} ], [ %[[NEXTI:.+]], %omp.inner.log.scan.exit ]
+//CHECK:   %[[CMP1:.+]] = icmp uge i32 99, %[[I]]
+//CHECK:   br i1 %[[CMP1]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+//CHECK: omp.inner.log.scan.exit:                          ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+//CHECK:   %[[NEXTK]] = add nuw i32 %[[K]], 1
+//CHECK:   %[[NEXTI]] = shl nuw i32 %[[I]], 1
+//CHECK:   %[[CMP2:.+]] = icmp ne i32 %[[NEXTK]], %[[UB]]
+//CHECK:   br i1 %[[CMP2]], label %omp.outer.log.scan.body, label %omp.outer.log.scan.exit
+//CHECK: omp.outer.log.scan.exit:                          ; preds = %omp.inner.log.scan.exit
+//CHECK: @__kmpc_end_masked
+//CHECK: omp.inner.log.scan.body:                          ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+//CHECK:   %[[CNT:.+]] = phi i32 [ 99, %omp.outer.log.scan.body ], [ %[[CNTNXT:.+]], %omp.inner.log.scan.body ]
+//CHECK:   %[[BUFF:.+]] = load ptr, ptr %[[BUFF_PTR]]
+//CHECK:   %[[IND1:.+]] = add i32 %[[CNT]], 1
+//CHECK:   %[[IND1PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND1]]
+//CHECK:   %[[IND2:.+]] = sub nuw i32 %[[IND1]], %[[I]]
+//CHECK:   %[[IND2PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND2]]
+//CHECK:   %[[IND1VAL:.+]] = load i32, ptr %[[IND1PTR]], align 4
+//CHECK:   %[[IND2VAL:.+]] = load i32, ptr %[[IND2PTR]], align 4
+//CHECK:   %[[REDVAL:.+]] = add i32 %[[IND1VAL]], %[[IND2VAL]]
+//CHECK:   store i32 %[[REDVAL]], ptr %[[IND1PTR]], align 4
+//CHECK:   %[[CNTNXT]] = sub nuw i32 %[[CNT]], 1
+//CHECK:   %[[CMP3:.+]] = icmp uge i32 %[[CNTNXT]], %[[I]]
+//CHECK:   br i1 %[[CMP3]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+//CHECK: omp.inscan.dispatch:                              ; preds = %omp_loop.body
+//CHECK:   br i1 true, label %omp.before.scan.bb, label %omp.after.scan.bb
+//CHECK: omp.loop_nest.region:                             ; preds = %omp.before.scan.bb
+//CHECK:   %[[BUFFER:.+]] = load ptr, ptr %loadgep_vla, align 8
+//CHECK:   %[[ARRAYOFFSET2:.+]] = getelementptr inbounds i32, ptr %[[BUFFER]], i32 %{{.*}}
+//CHECK-NEXT:   %[[REDPRIVVAL:.+]] = load i32, ptr %{{.*}}, align 4
+//CHECK:   store i32 %[[REDPRIVVAL]], ptr %[[ARRAYOFFSET2]], align 4
+//CHECK:   br label %omp.scan.loop.exit
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 2fa4470bb8300..32784b14f5302 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -129,6 +129,37 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // -----
 
+omp.declare_reduction @add_f32 : f32
+init {
+^bb0(%arg: f32):
+  %0 = llvm.mlir.constant(0.0 : f32) : f32
+  omp.yield (%0 : f32)
+}
+combiner {
+^bb1(%arg0: f32, %arg1: f32):
+  %1 = llvm.fadd %arg0, %arg1 : f32
+  omp.yield (%1 : f32)
+}
+atomic {
+^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
+  %2 = llvm.load %arg3 : !llvm.ptr -> f32
+  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
+  omp.yield
+}
+llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
+  // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.simd operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.simd}}
+  omp.simd reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.scan inclusive(%prv : !llvm.ptr)
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// -----
+
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):
@@ -147,17 +178,20 @@ atomic {
   omp.yield
 }
 llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
   omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
+    // expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
     omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.scan inclusive(%prv : !llvm.ptr)
+      // expected-error@below {{not yet implemented: Unhandled clause exclusive in omp.scan operation}}
+      // expected-error@below {{LLVM Translation failed for operation: omp.scan}}
+      omp.scan exclusive(%prv : !llvm.ptr)
       omp.yield
     }
   }
   llvm.return
 }
 
+
 // -----
 
 llvm.func @single_allocate(%x : !llvm.ptr) {



More information about the Mlir-commits mailing list