[llvm] [mlir] scan lowering changes (PR #133149)

Anchu Rajendran S via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 4 14:58:48 PDT 2025


https://github.com/anchuraj updated https://github.com/llvm/llvm-project/pull/133149

>From 19b17c9d42bd0f56e51bc779cc8937a19970656c Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 25 Mar 2025 11:21:37 -0500
Subject: [PATCH 1/3] scan lowering changes

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  70 +++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 336 +++++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 247 ++++++++++---
 .../Target/LLVMIR/openmp-reduction-scan.mlir  | 124 +++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  31 --
 5 files changed, 717 insertions(+), 91 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 28909cef4748d..c07525efe7b69 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -30,6 +30,7 @@
 namespace llvm {
 class CanonicalLoopInfo;
 struct TargetRegionEntryInfo;
+class ScanInfo;
 class OffloadEntriesInfoManager;
 class OpenMPIRBuilder;
 
@@ -728,6 +729,11 @@ class OpenMPIRBuilder {
                       LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
                       const Twine &Name = "loop");
 
+  /// Create the input-phase and scan-phase canonical loops of a scan loop.
+  Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
+      const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+      Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+      InsertPointTy ComputeIP, const Twine &Name, bool InScan);
   /// Calculate the trip count of a canonical loop.
   ///
   /// This allows specifying user-defined loop counter values using increment,
@@ -800,10 +806,12 @@ class OpenMPIRBuilder {
   ///
   /// \returns An object representing the created control flow structure which
   ///          can be used for loop-associated directives.
-  Expected<CanonicalLoopInfo *> createCanonicalLoop(
-      const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
-      Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-      InsertPointTy ComputeIP = {}, const Twine &Name = "loop");
+  Expected<CanonicalLoopInfo *>
+  createCanonicalLoop(const LocationDescription &Loc,
+                      LoopBodyGenCallbackTy BodyGenCB, Value *Start,
+                      Value *Stop, Value *Step, bool IsSigned,
+                      bool InclusiveStop, InsertPointTy ComputeIP = {},
+                      const Twine &Name = "loop", bool InScan = false);
 
   /// Collapse a loop nest into a single loop.
   ///
@@ -1530,6 +1538,20 @@ class OpenMPIRBuilder {
       ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
       Function *ReduceFn, AttributeList FuncAttrs);
 
+  Expected<SmallVector<llvm::CanonicalLoopInfo *>> emitScanBasedDirectiveIR(
+      llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+      llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+          SecondGen); // Emits FirstGen's loop, then SecondGen's after it.
+  /// Emit a call to \p callee and mark it as not throwing.
+  llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
+                                          ArrayRef<llvm::Value *> args,
+                                          const llvm::Twine &name);
+  void createScanBBs(); // Creates the blocks used to lower '#omp scan'.
+  void emitScanBasedDirectiveDeclsIR(llvm::Value *span,
+                                     ArrayRef<llvm::Value *> ScanVars);
+  void emitScanBasedDirectiveFinalsIR(
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
   /// This function emits a helper that gathers Reduce lists from the first
   /// lane of every active warp to lanes in the first warp.
   ///
@@ -2177,7 +2199,6 @@ class OpenMPIRBuilder {
   // block, if possible, or else at the end of the function. Also add a branch
   // from current block to BB if current block does not have a terminator.
   void emitBlock(BasicBlock *BB, Function *CurFn, bool IsFinished = false);
-
   /// Emits code for OpenMP 'if' clause using specified \a BodyGenCallbackTy
   /// Here is the logic:
   /// if (Cond) {
@@ -2603,6 +2624,31 @@ class OpenMPIRBuilder {
                                     BodyGenCallbackTy BodyGenCB,
                                     FinalizeCallbackTy FiniCB, Value *Filter);
 
+  /// Generator for the scan reduction over the per-variable scan buffers.
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param FinalizeIP In/out: the IP where the reduction result is copied
+  ///                   back to the original variable; updated past the copy.
+  /// \param reductionInfos Array type containing the ReductionOps.
+  ///
+  /// \returns The insertion position *after* the barrier ending the reduction.
+  InsertPointOrErrorTy emitScanReduction(
+      const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
+  /// Generator for '#omp scan'
+  ///
+  /// \param Loc The insert and source location description.
+  /// \param AllocaIP The IP where the temporary buffer for scan reduction
+  ///                 needs to be allocated.
+  /// \param ScanVars Scan Variables.
+  /// \param IsInclusive Indicates if it is an inclusive or exclusive scan.
+  ///
+  /// \returns The insertion position at the start of the after-scan block.
+  InsertPointOrErrorTy createScan(const LocationDescription &Loc,
+                                  InsertPointTy AllocaIP,
+                                  ArrayRef<llvm::Value *> ScanVars,
+                                  bool IsInclusive);
   /// Generator for '#omp critical'
   ///
   /// \param Loc The insert and source location description.
@@ -3711,6 +3757,20 @@ class CanonicalLoopInfo {
   void invalidate();
 };
 
+class ScanInfo {
+public:
+  llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
+  llvm::BasicBlock *OMPAfterScanBlock = nullptr;
+  llvm::BasicBlock *OMPScanExitBlock = nullptr;
+  llvm::BasicBlock *OMPScanDispatch = nullptr;
+  llvm::BasicBlock *OMPScanLoopExit = nullptr;
+  bool OMPFirstScanLoop = false; // True while emitting the input-phase loop.
+  llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
+  llvm::Value *iv = nullptr;   // Induction variable of the loop being emitted.
+  llvm::Value *span = nullptr; // Distance between the scan loop's bounds.
+  SmallVector<llvm::BasicBlock *> continueBlocks;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2e5ce5308eea5..99e5c4e1282db 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -67,6 +67,8 @@
 using namespace llvm;
 using namespace omp;
 
+static ScanInfo scanInfo; // Shared by every OpenMPIRBuilder instance.
+
 static cl::opt<bool>
     OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                          cl::desc("Use optimistic attributes describing "
@@ -80,6 +82,7 @@ static cl::opt<double> UnrollThresholdFactor(
     cl::init(1.5));
 
 #ifndef NDEBUG
+
 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
 /// an InsertPoint stores the instruction before something is inserted. For
@@ -3918,6 +3921,270 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
                               /*Conditional*/ true, /*hasFinalize*/ true);
 }
 
+/// Emit a call to \p callee with no operand bundles and mark it nounwind.
+llvm::CallInst *OpenMPIRBuilder::emitNoUnwindRuntimeCall(
+    llvm::FunctionCallee callee, ArrayRef<llvm::Value *> args,
+    const llvm::Twine &name) {
+  SmallVector<llvm::OperandBundleDef, 1> NoBundles;
+  llvm::CallInst *Call = Builder.CreateCall(callee, args, NoBundles, name);
+  Call->setDoesNotThrow();
+  return Call;
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<llvm::Value *> ScanVars, bool IsInclusive) {
+  if (scanInfo.OMPFirstScanLoop) {
+    Builder.restoreIP(AllocaIP);
+    emitScanBasedDirectiveDeclsIR(scanInfo.span, ScanVars);
+  }
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  llvm::Value *iv = scanInfo.iv;
+
+  if (scanInfo.OMPFirstScanLoop) {
+    // Emit buffer[i] = red; at the end of the input phase.
+    for (size_t i = 0; i < ScanVars.size(); i++) {
+      auto buff = scanInfo.ReductionVarToScanBuffs[ScanVars[i]];
+
+      auto destTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
+      auto val = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
+      auto src = Builder.CreateLoad(destTy, ScanVars[i]);
+      auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          val, Builder.getPtrTy(defaultAS));
+
+      Builder.CreateStore(src, dest);
+    }
+  }
+  Builder.CreateBr(scanInfo.OMPScanLoopExit);
+  // Fill in the dispatch block that createCanonicalScanLoops wired up.
+  Builder.SetInsertPoint(scanInfo.OMPScanDispatch);
+
+  ConstantInt *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
+  for (size_t i = 0; i < ScanVars.size(); i++) {
+    // NOTE(review): assumes 0 is the identity of every scan reduction.
+    auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ScanVars[i], Builder.getPtrTy(defaultAS));
+    Builder.CreateStore(Zero, dest);
+  }
+
+  if (!scanInfo.OMPFirstScanLoop) {
+    iv = scanInfo.iv;
+    // Emit red = buffer[i]; at the entrance to the scan phase.
+    for (size_t i = 0; i < ScanVars.size(); i++) {
+      // x = buffer[i]
+      auto buff = scanInfo.ReductionVarToScanBuffs[ScanVars[i]];
+      auto destTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
+      auto newVPtr = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
+      auto newV = Builder.CreateLoad(destTy, newVPtr);
+      auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ScanVars[i], Builder.getPtrTy(defaultAS));
+
+      Builder.CreateStore(newV, dest);
+    }
+  }
+  // NOTE(review): placeholder condition. It is constant true, so the branch
+  // below always takes its first successor and the other edge is never taken.
+  // TODO: replace it with the real predicate that selects the phase to run.
+  llvm::Value *CmpI = Builder.getTrue();
+
+  if (scanInfo.OMPFirstScanLoop == IsInclusive) {
+    Builder.CreateCondBr(CmpI, scanInfo.OMPBeforeScanBlock,
+                         scanInfo.OMPAfterScanBlock);
+  } else {
+    Builder.CreateCondBr(CmpI, scanInfo.OMPAfterScanBlock,
+                         scanInfo.OMPBeforeScanBlock);
+  }
+  emitBlock(scanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(scanInfo.OMPAfterScanBlock);
+  return Builder.saveIP();
+}
+
+void OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    llvm::Value *span, ArrayRef<llvm::Value *> ScanVars) {
+  // Allocate a (span + 1)-element buffer per scan variable at the current IP.
+  Constant *One = ConstantInt::get(span->getType(), 1);
+  llvm::Value *allocSpan = Builder.CreateAdd(span, One);
+  for (llvm::Value *scanVar : ScanVars) {
+    // NOTE(review): the element type is hard-coded to i32 (as in createScan),
+    // but emitScanReduction indexes these buffers with the reduction's
+    // ElementType, so non-i32 scan variables look unsupported; confirm.
+    llvm::Value *buff =
+        Builder.CreateAlloca(Builder.getInt32Ty(), allocSpan, "vla");
+    scanInfo.ReductionVarToScanBuffs[scanVar] = buff;
+  }
+}
+
+void OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  llvm::Value *span = scanInfo.span;
+  // Copy buffer[span + 1] back to each original reduction variable.
+  llvm::Value *OMPLast = Builder.CreateNSWAdd(
+      span, llvm::ConstantInt::get(span->getType(), 1, /*isSigned=*/false));
+  // NOTE(review): emitScanBasedDirectiveDeclsIR allocates span + 1 elements,
+  // so index span + 1 appears to be one past the end; confirm the indexing.
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (size_t i = 0; i < reductionInfos.size(); i++) {
+    auto privateVar = reductionInfos[i].PrivateVariable;
+    auto origVar = reductionInfos[i].Variable;
+    auto buff = scanInfo.ReductionVarToScanBuffs[privateVar];
+
+    auto srcTy = reductionInfos[i].ElementType;
+    auto val = Builder.CreateInBoundsGEP(srcTy, buff, OMPLast, "arrayOffset");
+    auto src = Builder.CreateLoad(srcTy, val);
+    auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        origVar, Builder.getPtrTy(defaultAS));
+
+    Builder.CreateStore(src, dest);
+  }
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+    const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  llvm::Value *spanDiff = scanInfo.span;
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  auto curFn = Builder.GetInsertBlock()->getParent();
+  // Outer loop: for (k = 0; k < ceil(log2(n)); ++k), entered at least once.
+  llvm::BasicBlock *LoopBB =
+      BasicBlock::Create(curFn->getContext(), "omp.outer.log.scan.body");
+  llvm::BasicBlock *ExitBB =
+      BasicBlock::Create(curFn->getContext(), "omp.outer.log.scan.exit");
+  llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
+      Builder.GetInsertBlock()->getModule(),
+      (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
+  llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+  ConstantInt *One = ConstantInt::get(Builder.getInt32Ty(), 1);
+  llvm::Value *span = Builder.CreateAdd(spanDiff, One);
+  llvm::Value *Arg = Builder.CreateUIToFP(span, Builder.getDoubleTy());
+  llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
+  F = llvm::Intrinsic::getOrInsertDeclaration(
+      Builder.GetInsertBlock()->getModule(),
+      (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
+  LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
+  LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
+  llvm::Value *NMin1 =
+      Builder.CreateNUWSub(span, llvm::ConstantInt::get(span->getType(), 1));
+  Builder.SetInsertPoint(InputBB);
+  Builder.CreateBr(LoopBB);
+  emitBlock(LoopBB, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(LoopBB);
+
+  auto *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  // pow2k = 1;
+  auto *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  Counter->addIncoming(Builder.getInt32(0), InputBB);
+  Pow2K->addIncoming(Builder.getInt32(1), InputBB);
+  // for (size i = n - 1; i >= 2 ^ k; --i)
+  //   tmp[i] op= tmp[i-pow2k];
+  llvm::BasicBlock *InnerLoopBB =
+      BasicBlock::Create(curFn->getContext(), "omp.inner.log.scan.body");
+  llvm::BasicBlock *InnerExitBB =
+      BasicBlock::Create(curFn->getContext(), "omp.inner.log.scan.exit");
+  llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  emitBlock(InnerLoopBB, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(InnerLoopBB);
+  auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  IVal->addIncoming(NMin1, LoopBB);
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (size_t i = 0; i < reductionInfos.size(); i++) {
+    auto &reductionVal = reductionInfos[i].PrivateVariable;
+    auto buff = scanInfo.ReductionVarToScanBuffs[reductionVal];
+    auto destTy = reductionInfos[i].ElementType;
+    Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
+    auto lhsPtr = Builder.CreateInBoundsGEP(destTy, buff, IV, "arrayOffset");
+    auto offsetIval = Builder.CreateNUWSub(IV, Pow2K);
+    auto rhsPtr =
+        Builder.CreateInBoundsGEP(destTy, buff, offsetIval, "arrayOffset");
+    auto lhs = Builder.CreateLoad(destTy, lhsPtr);
+    auto rhs = Builder.CreateLoad(destTy, rhsPtr);
+    auto lhsAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        lhsPtr, Builder.getPtrTy(defaultAS));
+    llvm::Value *result = nullptr;
+    InsertPointOrErrorTy AfterIP =
+        reductionInfos[i].ReductionGen(Builder.saveIP(), lhs, rhs, result);
+    if (!AfterIP)
+      return AfterIP.takeError();
+    Builder.restoreIP(*AfterIP);
+    Builder.CreateStore(result, lhsAddr);
+  }
+  llvm::Value *NextIVal = Builder.CreateNUWSub(IVal, Builder.getInt32(1));
+  IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+  CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  emitBlock(InnerExitBB, Builder.GetInsertBlock()->getParent());
+  llvm::Value *Next = Builder.CreateNUWAdd(Counter, Builder.getInt32(1));
+  Counter->addIncoming(Next, Builder.GetInsertBlock());
+  // pow2k <<= 1;
+  llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+  Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+  // ULT, not NE: for n == 1, LogVal is 0 and Next (>= 1) never equals it.
+  llvm::Value *Cmp = Builder.CreateICmpULT(Next, LogVal);
+  Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
+  emitBlock(ExitBB, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(ExitBB);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!afterIP)
+    return afterIP.takeError();
+
+  Builder.restoreIP(FinalizeIP);
+  emitScanBasedDirectiveFinalsIR(reductionInfos);
+  FinalizeIP = Builder.saveIP();
+
+  return afterIP;
+}
+
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::emitScanBasedDirectiveIR(
+    llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+    llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+        SecondGen) {
+  // Loops are returned in emission order: input phase, then scan phase.
+  SmallVector<llvm::CanonicalLoopInfo *> ret;
+  {
+    // Emit loop with input phase:
+    // #pragma omp ...
+    // for (i: 0..<num_iters>) {
+    //   <input phase>;
+    //   buffer[i] = red;
+    // }
+    scanInfo.OMPFirstScanLoop = true;
+    auto result = FirstGen();
+    if (!result)
+      return result.takeError();
+    Builder.restoreIP((*result)->getAfterIP());
+    ret.push_back(*result);
+  }
+  {
+    scanInfo.OMPFirstScanLoop = false;
+    auto result = SecondGen(Builder.saveIP());
+    if (!result)
+      return result.takeError();
+    Builder.restoreIP((*result)->getAfterIP());
+    ret.push_back(*result);
+  }
+  return ret;
+}
+
+void OpenMPIRBuilder::createScanBBs() {
+  // The blocks are created detached from any function; later code places
+  // them with emitBlock.
+  LLVMContext &Ctx = Builder.getContext();
+
+  scanInfo.OMPScanExitBlock = BasicBlock::Create(Ctx, "omp.exit.inscan.bb");
+  scanInfo.OMPScanDispatch = BasicBlock::Create(Ctx, "omp.inscan.dispatch");
+  scanInfo.OMPAfterScanBlock = BasicBlock::Create(Ctx, "omp.after.scan.bb");
+  scanInfo.OMPBeforeScanBlock = BasicBlock::Create(Ctx, "omp.before.scan.bb");
+  scanInfo.OMPScanLoopExit = BasicBlock::Create(Ctx, "omp.scan.loop.exit");
+  // NOTE(review): OMPScanExitBlock does not appear to be inserted anywhere;
+  // confirm it is used, otherwise the detached block is never freed.
+}
+
 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
     BasicBlock *PostInsertBefore, const Twine &Name) {
@@ -4015,10 +4282,72 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
   return CL;
 }
 
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::createCanonicalScanLoops(
+    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+    InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
+  auto *IndVarTy = cast<IntegerType>(Start->getType());
+  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
+  assert(IndVarTy == Step->getType() && "Step type mismatch");
+  LocationDescription ComputeLoc =
+      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+  updateToLocation(ComputeLoc);
+
+  ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
+
+  // Distance between the loop bounds; it sizes the scan buffers (span + 1).
+  Value *Span;
+
+  // NOTE(review): Span ignores Step and InclusiveStop, while the buffers are
+  // indexed by the induction variable itself; confirm the buffer bounds.
+
+  if (IsSigned) {
+    // Ensure that increment is positive. If not, negate and invert LB and UB.
+    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
+    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
+    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
+    Span = Builder.CreateSub(UB, LB, "", false, true);
+  } else {
+    Span = Builder.CreateSub(Stop, Start, "", true);
+  }
+  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
+    if (InScan) {
+      scanInfo.iv = IV;
+      createScanBBs();
+      auto terminator = Builder.GetInsertBlock()->getTerminator();
+      scanInfo.continueBlocks.push_back(terminator->getSuccessor(0));
+      terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
+      emitBlock(scanInfo.OMPBeforeScanBlock,
+                Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(scanInfo.OMPScanLoopExit);
+
+      emitBlock(scanInfo.OMPScanLoopExit,
+                Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(scanInfo.continueBlocks.back());
+      emitBlock(scanInfo.OMPScanDispatch,
+                Builder.GetInsertBlock()->getParent());
+      Builder.SetInsertPoint(
+          scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
+    }
+    return BodyGenCB(Builder.saveIP(), IV);
+  };
+  // Both loops share BodyGen; forward InScan so both agree on it.
+  auto FirstGen = [&]() {
+    return createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+                               InclusiveStop, ComputeIP, Name, InScan);
+  };
+  auto SecondGen = [&](LocationDescription loc) {
+    return createCanonicalLoop(loc, BodyGen, Start, Stop, Step, IsSigned,
+                               InclusiveStop, ComputeIP, Name, InScan);
+  };
+  scanInfo.span = Span;
+  return emitScanBasedDirectiveIR(FirstGen, SecondGen);
+}
+
 Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
     const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
     bool IsSigned, bool InclusiveStop, const Twine &Name) {
-
   // Consider the following difficulties (assuming 8-bit signed integers):
   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
   //      DO I = 1, 100, 50
@@ -4078,7 +4407,7 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
 Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-    InsertPointTy ComputeIP, const Twine &Name) {
+    InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
   LocationDescription ComputeLoc =
       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
 
@@ -4089,6 +4418,9 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     Builder.restoreIP(CodeGenIP);
     Value *Span = Builder.CreateMul(IV, Step);
     Value *IndVar = Builder.CreateAdd(Span, Start);
+    // Publish the user-visible IV for the scan lowering (read by createScan).
+    if (InScan)
+      scanInfo.iv = IndVar;
     return BodyGenCB(Builder.saveIP(), IndVar);
   };
   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..6ada71977a6a4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -86,7 +86,9 @@ class OpenMPLoopInfoStackFrame
     : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  /// Canonical loops of the loop nest being translated. Constructs like scan
+  /// can produce more than one (e.g. an input-phase and a scan-phase loop).
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
 };
 
 /// Custom error class to signal translation errors that don't need reporting,
@@ -169,6 +171,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDistScheduleChunkSize())
       result = todo("dist_schedule with chunk_size");
   };
+  auto checkExclusive = [&todo](auto op, LogicalResult &result) {
+    if (!op.getExclusiveVars().empty()) // Only inclusive scans are lowered.
+      result = todo("exclusive");
+  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -232,8 +238,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           op.getReductionSyms())
         result = todo("reduction");
     if (op.getReductionMod() &&
-        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
-      result = todo("reduction with modifier");
+        op.getReductionMod().value() == omp::ReductionModifier::task)
+      result = todo("reduction with task modifier");
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -253,6 +259,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
+      .Case([&](omp::ScanOp op) { checkExclusive(op, result); })
       .Case([&](omp::SectionsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
@@ -383,15 +390,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
-/// Find the loop information structure for the loop nest being translated. It
-/// will return a `null` value unless called from the translation function for
+/// Find the loop information structures for the loop nest being translated. It
+/// will return an empty list unless called from the translation function for
 /// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        loopInfo = frame.loopInfo;
+        loopInfos = frame.loopInfos;
         return WalkResult::interrupt();
       });
-  return loopInfo;
+  return loopInfos;
 }
 
 /// Converts the given region that appears within an OpenMP dialect operation to
@@ -2258,26 +2265,61 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-      ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-          convertToScheduleKind(schedule), chunk, isSimd,
-          scheduleMod == omp::ScheduleModifier::monotonic,
-          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType);
-
-  if (failed(handleError(wsloopIP, opInst)))
-    return failure();
-
-  // Process the reductions if required.
-  if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
-                                        allocaIP, reductionDecls,
-                                        privateReductionVariables, isByRef)))
-    return failure();
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  assert(!loopInfos.empty() && "expected a translated canonical loop nest");
+  auto beforeLoopFinishIp = loopInfos.front()->getAfterIP();
+  auto afterLoopFinishIp = loopInfos.back()->getAfterIP();
+  bool isInScanRegion =
+      wsloopOp.getReductionMod() == mlir::omp::ReductionModifier::inscan;
+  if (isInScanRegion) {
+    builder.restoreIP(beforeLoopFinishIp);
+    SmallVector<OwningReductionGen> owningReductionGens;
+    SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+    collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+                         owningReductionGens, owningAtomicReductionGens,
+                         privateReductionVariables, reductionInfos);
+    llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+        ompBuilder->emitScanReduction(builder.saveIP(), afterLoopFinishIp,
+                                      reductionInfos);
+    if (failed(handleError(redIP, opInst)))
+      return failure();

+    builder.restoreIP(*redIP);
+    builder.CreateBr(cont);
+  }
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+        ompBuilder->applyWorkshareLoop(
+            ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+            convertToScheduleKind(schedule), chunk, isSimd,
+            scheduleMod == omp::ScheduleModifier::monotonic,
+            scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+            workshareLoopType);
+
+    if (failed(handleError(wsloopIP, opInst)))
+      return failure();
+  }
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+  if (isInScanRegion) {
+    SmallVector<Region *> reductionRegions;
+    llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+                    [](omp::DeclareReductionOp reductionDecl) {
+                      return &reductionDecl.getCleanupRegion();
+                    });
+    if (failed(inlineOmpRegionCleanup(
+            reductionRegions, privateReductionVariables, moduleTranslation,
+            builder, "omp.reduction.cleanup")))
+      return failure();
+  } else {
+    // Process the reductions if required.
+    if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
+                                          allocaIP, reductionDecls,
+                                          privateReductionVariables, isByRef)))
+      return failure();
+  }
   return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                             privateVarsInfo.llvmVars,
                             privateVarsInfo.privatizers);
@@ -2467,6 +2509,60 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
   llvm_unreachable("Unknown ClauseOrderKind kind");
 }
 
+/// Converts an OpenMP scan directive into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+               LLVM::ModuleTranslation &moduleTranslation) {
+  if (failed(checkImplementationStatus(opInst)))
+    return failure();
+  auto scanOp = cast<omp::ScanOp>(opInst);
+  bool isInclusive = scanOp.hasInclusiveVars();
+  SmallVector<llvm::Value *> llvmScanVars;
+  mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars();
+  if (!isInclusive)
+    mlirScanVars = scanOp.getExclusiveVars();
+  for (mlir::Value val : mlirScanVars)
+    llvmScanVars.push_back(moduleTranslation.lookupValue(val));
+
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      moduleTranslation.getOpenMPBuilder()->createScan(
+          ompLoc, allocaIP, llvmScanVars, isInclusive);
+  if (failed(handleError(afterIP, opInst)))
+    return failure();
+
+  builder.restoreIP(*afterIP);
+
+  // TODO: The LoopNestOp induction variable is stored into a local index
+  // variable at the start of the loop body, and that variable is then read
+  // across the scan operation. That makes the MLIR invalid: "Intra-iteration
+  // dependences from a statement in the structured block sequence that
+  // precede a scan directive to a statement in the structured block sequence
+  // that follows a scan directive must not exist, except for dependences for
+  // the list items specified in an inclusive or exclusive clause." Until the
+  // frontend reloads the induction variable after the ScanOp, re-emit that
+  // store here when it is the first operation of the loop body.
+  auto loopOp = dyn_cast<omp::LoopNestOp>(scanOp->getParentOp());
+  if (!loopOp)
+    return success();
+
+  // Only the first operation of the loop body's entry block is inspected.
+  Block &firstBlock = scanOp->getParentRegion()->front();
+  if (auto storeOp = dyn_cast<LLVM::StoreOp>(firstBlock.front())) {
+    llvm::Value *src = moduleTranslation.lookupValue(storeOp->getOperand(0));
+    llvm::Value *loopIV =
+        moduleTranslation.lookupValue(loopOp.getRegion().getArgument(0));
+    if (src == loopIV) {
+      llvm::Value *dest =
+          moduleTranslation.lookupValue(storeOp->getOperand(1));
+      builder.CreateStore(src, dest);
+    }
+  }
+  return success();
+}
+
 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2540,12 +2636,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-  ompBuilder->applySimd(loopInfo, alignedVars,
-                        simdOp.getIfExpr()
-                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                            : nullptr,
-                        order, simdlen, safelen);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    ompBuilder->applySimd(
+        loopInfo, alignedVars,
+        simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                           : nullptr,
+        order, simdlen, safelen);
+  }
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                             privateVarsInfo.llvmVars,
@@ -2612,16 +2711,52 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                                                        ompLoc.DL);
       computeIP = loopInfos.front()->getPreheaderIP();
     }
+    if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>()) {
+      bool isInScanRegion =
+          wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                         mlir::omp::ReductionModifier::inscan);
+      // TODOSCAN: Take care of loop and add asserts if required
+      if (isInScanRegion) {
+        llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+            ompBuilder->createCanonicalScanLoops(
+                loc, bodyGen, lowerBound, upperBound, step,
+                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+                isInScanRegion);
+
+        if (failed(handleError(loopResults, *loopOp)))
+          return failure();
+        auto beforeLoop = loopResults->front();
+        auto afterLoop = loopResults->back();
+        moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+            [&](OpenMPLoopInfoStackFrame &frame) {
+              frame.loopInfos.push_back(beforeLoop);
+              frame.loopInfos.push_back(afterLoop);
+              return WalkResult::interrupt();
+            });
+        builder.restoreIP(afterLoop->getAfterIP());
+        return success();
+      } else {
+        llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+            ompBuilder->createCanonicalLoop(
+                loc, bodyGen, lowerBound, upperBound, step,
+                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
 
-    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
-        ompBuilder->createCanonicalLoop(
-            loc, bodyGen, lowerBound, upperBound, step,
-            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+        if (failed(handleError(loopResult, *loopOp)))
+          return failure();
 
-    if (failed(handleError(loopResult, *loopOp)))
-      return failure();
+        loopInfos.push_back(*loopResult);
+      }
+    } else {
+      llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+          ompBuilder->createCanonicalLoop(
+              loc, bodyGen, lowerBound, upperBound, step,
+              /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+
+      if (failed(handleError(loopResult, *loopOp)))
+        return failure();
 
-    loopInfos.push_back(*loopResult);
+      loopInfos.push_back(*loopResult);
+    }
   }
 
   // Collapse loops. Store the insertion point because LoopInfos may get
@@ -2633,7 +2768,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
   // after applying transformations.
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+        frame.loopInfos.push_back(
+            ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}));
         return WalkResult::interrupt();
       });
 
@@ -4212,18 +4348,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
       bool loopNeedsBarrier = false;
       llvm::Value *chunk = nullptr;
 
-      llvm::CanonicalLoopInfo *loopInfo =
-          findCurrentLoopInfo(moduleTranslation);
-      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-          ompBuilder->applyWorkshareLoop(
-              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-              convertToScheduleKind(schedule), chunk, isSimd,
-              scheduleMod == omp::ScheduleModifier::monotonic,
-              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType);
-
-      if (!wsloopIP)
-        return wsloopIP.takeError();
+      SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+          findCurrentLoopInfos(moduleTranslation);
+      for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+        llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+            ompBuilder->applyWorkshareLoop(
+                ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+                convertToScheduleKind(schedule), chunk, isSimd,
+                scheduleMod == omp::ScheduleModifier::monotonic,
+                scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+                workshareLoopType);
+
+        if (!wsloopIP)
+          return wsloopIP.takeError();
+      }
     }
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation,
@@ -5202,6 +5340,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
           .Case([&](omp::WsloopOp) {
             return convertOmpWsloop(*op, builder, moduleTranslation);
           })
+          .Case([&](omp::ScanOp) {
+            return convertOmpScan(*op, builder, moduleTranslation);
+          })
           .Case([&](omp::SimdOp) {
             return convertOmpSimd(*op, builder, moduleTranslation);
           })
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
new file mode 100644
index 0000000000000..1d8bbbef3e9b3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
@@ -0,0 +1,124 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// nonesense minimised code simulating the control flow graph generated by flang
+// for array reductions. The important thing here is that we are testing a byref
+// reduction with a cleanup region, and the various regions contain multiple
+// blocks
+omp.declare_reduction @add_reduction_i32 : i32 init {
+^bb0(%arg0: i32):
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  omp.yield(%0 : i32)
+} combiner {
+^bb0(%arg0: i32, %arg1: i32):
+  %0 = llvm.add %arg0, %arg1 : i32
+  omp.yield(%0 : i32)
+}
+// CHECK-LABEL: @scan_reduction
+llvm.func @scan_reduction() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "z"} : (i64) -> !llvm.ptr
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
+  %4 = llvm.mlir.constant(1 : i64) : i64
+  %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  %6 = llvm.mlir.constant(1 : i64) : i64
+  %7 = llvm.alloca %6 x i32 {bindc_name = "k"} : (i64) -> !llvm.ptr
+  %8 = llvm.mlir.constant(0 : index) : i64
+  %9 = llvm.mlir.constant(1 : index) : i64
+  %10 = llvm.mlir.constant(100 : i32) : i32
+  %11 = llvm.mlir.constant(1 : i32) : i32
+  %12 = llvm.mlir.constant(0 : i32) : i32
+  %13 = llvm.mlir.constant(100 : index) : i64
+  %14 = llvm.mlir.addressof @_QFEa : !llvm.ptr
+  %15 = llvm.mlir.addressof @_QFEb : !llvm.ptr
+  omp.parallel {
+    %37 = llvm.mlir.constant(1 : i64) : i64
+    %38 = llvm.alloca %37 x i32 {bindc_name = "k", pinned} : (i64) -> !llvm.ptr
+    %39 = llvm.mlir.constant(1 : i64) : i64
+    omp.wsloop reduction(mod: inscan, @add_reduction_i32 %5 -> %arg0 : !llvm.ptr) {
+      omp.loop_nest (%arg1) : i32 = (%11) to (%10) inclusive step (%11) {
+        llvm.store %arg1, %38 : i32, !llvm.ptr
+        %40 = llvm.load %arg0 : !llvm.ptr -> i32
+        %41 = llvm.load %38 : !llvm.ptr -> i32
+        %42 = llvm.sext %41 : i32 to i64
+        %50 = llvm.getelementptr %14[%42] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+        %51 = llvm.load %50 : !llvm.ptr -> i32
+        %52 = llvm.add %40, %51 : i32
+        llvm.store %52, %arg0 : i32, !llvm.ptr
+        omp.scan inclusive(%arg0 : !llvm.ptr)
+        %53 = llvm.load %arg0 : !llvm.ptr -> i32
+        %54 = llvm.load %38 : !llvm.ptr -> i32
+        %55 = llvm.sext %54 : i32 to i64
+        %63 = llvm.getelementptr %15[%55] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+        llvm.store %53, %63 : i32, !llvm.ptr
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+  %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+  llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal @_QFEb() {addr_space = 0 : i32} : !llvm.array<100 x i32> {
+  %0 = llvm.mlir.zero : !llvm.array<100 x i32>
+  llvm.return %0 : !llvm.array<100 x i32>
+}
+llvm.mlir.global internal constant @_QFECn() {addr_space = 0 : i32} : i32 {
+  %0 = llvm.mlir.constant(100 : i32) : i32
+  llvm.return %0 : i32
+}
+//CHECK: %[[BUFF:.+]] = alloca i32, i32 100, align 4
+//CHECK: omp_loop.preheader{{.*}}:                              ; preds = %omp.wsloop.region
+//CHECK: omp_loop.after:                                   ; preds = %omp_loop.exit
+//CHECK:   %[[LOG:.+]] = call double @llvm.log2.f64(double 1.000000e+02) #0
+//CHECK:   %[[CEIL:.+]] = call double @llvm.ceil.f64(double %[[LOG]]) #0
+//CHECK:   %[[UB:.+]] = fptoui double %[[CEIL]] to i32
+//CHECK:   br label %omp.outer.log.scan.body
+//CHECK: omp.outer.log.scan.body:                          ; preds = %omp.inner.log.scan.exit, %omp_loop.after
+//CHECK:   %[[K:.+]] = phi i32 [ 0, %omp_loop.after ], [ %[[NEXTK:.+]], %omp.inner.log.scan.exit ]
+//CHECK:   %[[I:.+]] = phi i32 [ 1, %omp_loop.after ], [ %[[NEXTI:.+]], %omp.inner.log.scan.exit ]
+//CHECK:   %[[CMP1:.+]] = icmp uge i32 99, %[[I]]
+//CHECK:   br i1 %[[CMP1]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+//CHECK: omp.inner.log.scan.exit:                          ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+//CHECK:   %[[NEXTK]] = add nuw i32 %[[K]], 1
+//CHECK:   %[[NEXTI]] = shl nuw i32 %[[I]], 1
+//CHECK:   %[[CMP2:.+]] = icmp ne i32 %[[NEXTK]], %[[UB]]
+//CHECK:   br i1 %[[CMP2]], label %omp.outer.log.scan.body, label %omp.outer.log.scan.exit
+//CHECK: omp.outer.log.scan.exit:                          ; preds = %omp.inner.log.scan.exit
+//CHECK:   call void @__kmpc_barrier{{.*}}
+//CHECK:   br label %omp.scan.loop.cont
+//CHECK: omp.scan.loop.cont:                               ; preds = %omp.outer.log.scan.exit
+//CHECK:   br label %omp_loop.preheader{{.*}}
+//CHECK: omp_loop.after{{.*}}:                                 ; preds = %omp_loop.exit{{.*}}
+//CHECK:  %[[ARRLAST:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 100
+//CHECK:  %[[RES:.+]] = load i32, ptr %[[ARRLAST]], align 4
+//CHECK:  store i32 %[[RES]], ptr %loadgep{{.*}}, align 4
+//CHECK: omp.inscan.dispatch{{.*}}:                            ; preds = %omp_loop.body{{.*}}
+//CHECK:   store i32 0, ptr %[[REDPRIV:.+]], align 4
+//CHECK:   %[[arrayOffset1:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %{{.*}}
+//CHECK:   %[[BUFFVAL1:.+]] = load i32, ptr %[[arrayOffset1]], align 4
+//CHECK:   store i32 %[[BUFFVAL1]], ptr %[[REDPRIV]], align 4
+//CHECK: omp.inner.log.scan.body:                          ; preds = %omp.inner.log.scan.body, %omp.outer.log.scan.body
+//CHECK:   %[[CNT:.+]] = phi i32 [ 99, %omp.outer.log.scan.body ], [ %[[CNTNXT:.+]], %omp.inner.log.scan.body ]
+//CHECK:   %[[IND1:.+]] = add i32 %[[CNT]], 1
+//CHECK:   %[[IND1PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND1]]
+//CHECK:   %[[IND2:.+]] = sub nuw i32 %[[IND1]], %[[I]]
+//CHECK:   %[[IND2PTR:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %[[IND2]]
+//CHECK:   %[[IND1VAL:.+]] = load i32, ptr %[[IND1PTR]], align 4
+//CHECK:   %[[IND2VAL:.+]] = load i32, ptr %[[IND2PTR]], align 4
+//CHECK:   %[[REDVAL:.+]] = add i32 %[[IND1VAL]], %[[IND2VAL]]
+//CHECK:   store i32 %[[REDVAL]], ptr %[[IND1PTR]], align 4
+//CHECK:   %[[CNTNXT]] = sub nuw i32 %[[CNT]], 1
+//CHECK:   %[[CMP3:.+]] = icmp uge i32 %[[CNTNXT]], %[[I]]
+//CHECK:   br i1 %[[CMP3]], label %omp.inner.log.scan.body, label %omp.inner.log.scan.exit
+//CHECK: omp.inscan.dispatch:                              ; preds = %omp_loop.body
+//CHECK:   store i32 0, ptr %[[REDPRIV]], align 4
+//CHECK:   br i1 true, label %omp.before.scan.bb, label %omp.after.scan.bb
+//CHECK: omp.loop_nest.region:                             ; preds = %omp.before.scan.bb
+//CHECK:   %[[ARRAYOFFSET2:.+]] = getelementptr inbounds i32, ptr %[[BUFF]], i32 %{{.*}}
+//CHECK:   %[[REDPRIVVAL:.+]] = load i32, ptr %[[REDPRIV]], align 4
+//CHECK:   store i32 %[[REDPRIVVAL]], ptr %[[ARRAYOFFSET2]], align 4
+//CHECK:   br label %omp.scan.loop.exit
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index af31f8bab73ac..2b5d6aedff39e 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -212,37 +212,6 @@ llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // -----
 
-omp.declare_reduction @add_f32 : f32
-init {
-^bb0(%arg: f32):
-  %0 = llvm.mlir.constant(0.0 : f32) : f32
-  omp.yield (%0 : f32)
-}
-combiner {
-^bb1(%arg0: f32, %arg1: f32):
-  %1 = llvm.fadd %arg0, %arg1 : f32
-  omp.yield (%1 : f32)
-}
-atomic {
-^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
-  %2 = llvm.load %arg3 : !llvm.ptr -> f32
-  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
-  omp.yield
-}
-llvm.func @scan_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
-  // expected-error at below {{not yet implemented: Unhandled clause reduction with modifier in omp.wsloop operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.wsloop}}
-  omp.wsloop reduction(mod:inscan, @add_f32 %x -> %prv : !llvm.ptr) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.scan inclusive(%prv : !llvm.ptr)
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @single_allocate(%x : !llvm.ptr) {
   // expected-error at below {{not yet implemented: Unhandled clause allocate in omp.single operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.single}}

>From df10e479633f90806de85875652a98989174c43f Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Wed, 2 Apr 2025 13:32:07 -0500
Subject: [PATCH 2/3] few additions

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  38 +++----
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 102 ++++++++++--------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  50 +++++++++
 llvm/utils/lit/setup.py                       |  46 --------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |   3 +-
 5 files changed, 126 insertions(+), 113 deletions(-)
 delete mode 100644 llvm/utils/lit/setup.py

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index c07525efe7b69..c217e6193e84c 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -503,6 +503,22 @@ class OpenMPIRBuilder {
       return allocaInst;
     }
   };
+  class ScanInfo {
+  public:
+    llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
+    llvm::BasicBlock *OMPAfterScanBlock = nullptr;
+    llvm::BasicBlock *OMPScanExitBlock = nullptr;
+    llvm::BasicBlock *OMPScanDispatch = nullptr;
+    llvm::BasicBlock *OMPScanLoopExit = nullptr;
+    bool OMPFirstScanLoop = false;
+    llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
+    llvm::Value *iv;
+    llvm::Value *span;
+    SmallVector<llvm::BasicBlock *> continueBlocks;
+  };
+  
+  class ScanInfo scanInfo;
+
   /// Initialize the internal state, this will put structures types and
   /// potentially other helpers into the underlying module. Must be called
   /// before any other method and only once! This internal state includes types
@@ -732,7 +748,7 @@ class OpenMPIRBuilder {
   Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
       const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
       Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-      InsertPointTy ComputeIP, const Twine &Name, bool InScan);
+      InsertPointTy ComputeIP, const Twine &Name);
 
   /// Calculate the trip count of a canonical loop.
   ///
@@ -1538,9 +1554,9 @@ class OpenMPIRBuilder {
       ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
       Function *ReduceFn, AttributeList FuncAttrs);
 
-  Expected<SmallVector<llvm::CanonicalLoopInfo *>> emitScanBasedDirectiveIR(
-      llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
-      llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+  Error emitScanBasedDirectiveIR(
+      llvm::function_ref<Error ()> FirstGen,
+      llvm::function_ref<Error (LocationDescription loc)>
           SecondGen);
   llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
                                           ArrayRef<llvm::Value *> args,
@@ -3757,20 +3773,6 @@ class CanonicalLoopInfo {
   void invalidate();
 };
 
-class ScanInfo {
-public:
-  llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
-  llvm::BasicBlock *OMPAfterScanBlock = nullptr;
-  llvm::BasicBlock *OMPScanExitBlock = nullptr;
-  llvm::BasicBlock *OMPScanDispatch = nullptr;
-  llvm::BasicBlock *OMPScanLoopExit = nullptr;
-  bool OMPFirstScanLoop = false;
-  llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
-  llvm::Value *iv;
-  llvm::Value *span;
-  SmallVector<llvm::BasicBlock *> continueBlocks;
-};
-
 } // end namespace llvm
 
 #endif // LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 99e5c4e1282db..52ad448ad653f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -67,8 +67,6 @@
 using namespace llvm;
 using namespace omp;
 
-class ScanInfo scanInfo;
-
 static cl::opt<bool>
     OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                          cl::desc("Use optimistic attributes describing "
@@ -741,7 +739,6 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
                             /* AllowAlloca */ true,
                             /* AllocaBlock*/ OI.OuterAllocaBB,
                             /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
-
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                       << " Exit: " << OI.ExitBB->getName() << "\n");
@@ -3931,6 +3928,12 @@ OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
   return call;
 }
 
+//Expects current basic block is dominated by BeforeScanBB.
+//Once Scan directive is encountered, the code after scan block should be
+//dominated by AfterScanBB. Scan directive splits the code sequence to 
+//before and after parts. Based on whether inclusive or exclusive 
+//clause is used in the scan directive, it adds jumps to input and 
+//scan phase in the first and second loops. 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     ArrayRef<llvm::Value *> ScanVars, bool IsInclusive) {
@@ -3973,7 +3976,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
     iv = scanInfo.iv;
     // Emit red = buffer[i]; at the entrance to the scan phase.
     for (int i = 0; i < ScanVars.size(); i++) {
-      // x = buffer[i]
       auto buff = scanInfo.ReductionVarToScanBuffs[ScanVars[i]];
       auto destTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
       auto newVPtr = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
@@ -3984,17 +3986,14 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
       Builder.CreateStore(newV, dest);
     }
   }
-  llvm::Value *testCondVal1 =
-      llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/100);
-  llvm::Value *testCondVal2 =
-      llvm::ConstantInt::get(llvm::Type::getInt32Ty(llvmContext), /*V=*/0);
-  llvm::Value *CmpI = Builder.CreateICmpUGE(testCondVal1, testCondVal2);
+
+  llvm::Value *CmpI = Builder.CreateICmpUGE(Builder.getInt32(100), Builder.getInt32(0));
   if (scanInfo.OMPFirstScanLoop == IsInclusive) {
     Builder.CreateCondBr(CmpI, scanInfo.OMPBeforeScanBlock,
-                         scanInfo.OMPAfterScanBlock);
+                       scanInfo.OMPAfterScanBlock);
   } else {
     Builder.CreateCondBr(CmpI, scanInfo.OMPAfterScanBlock,
-                         scanInfo.OMPBeforeScanBlock);
+                       scanInfo.OMPBeforeScanBlock);
   }
   emitBlock(scanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
   Builder.SetInsertPoint(scanInfo.OMPAfterScanBlock);
@@ -4139,10 +4138,10 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
   return afterIP;
 }
 
-Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+Error
 OpenMPIRBuilder::emitScanBasedDirectiveIR(
-    llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
-    llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+    llvm::function_ref<Error ()> FirstGen,
+    llvm::function_ref<Error (LocationDescription loc)>
         SecondGen) {
 
   SmallVector<llvm::CanonicalLoopInfo *> ret;
@@ -4155,20 +4154,16 @@ OpenMPIRBuilder::emitScanBasedDirectiveIR(
     // }
     scanInfo.OMPFirstScanLoop = true;
     auto result = FirstGen();
-    if (result.takeError())
-      return result.takeError();
-    Builder.restoreIP((*result)->getAfterIP());
-    ret.push_back(*result);
+    if (result)
+      return result;
   }
   {
     scanInfo.OMPFirstScanLoop = false;
     auto result = SecondGen(Builder.saveIP());
-    if (result.takeError())
-      return result.takeError();
-    Builder.restoreIP((*result)->getAfterIP());
-    ret.push_back(*result);
+    if (result)
+      return result;
   }
-  return ret;
+  return Error::success();
 }
 
 void OpenMPIRBuilder::createScanBBs() {
@@ -4286,7 +4281,7 @@ Expected<SmallVector<llvm::CanonicalLoopInfo *>>
 OpenMPIRBuilder::createCanonicalScanLoops(
     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-    InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
+    InsertPointTy ComputeIP, const Twine &Name) {
   auto *IndVarTy = cast<IntegerType>(Start->getType());
   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
   assert(IndVarTy == Step->getType() && "Step type mismatch");
@@ -4312,36 +4307,49 @@ OpenMPIRBuilder::createCanonicalScanLoops(
     Span = Builder.CreateSub(Stop, Start, "", true);
   }
   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
-    if (InScan) {
-      scanInfo.iv = IV;
-      createScanBBs();
-      auto terminator = Builder.GetInsertBlock()->getTerminator();
-      scanInfo.continueBlocks.push_back(terminator->getSuccessor(0));
-      terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
-      emitBlock(scanInfo.OMPBeforeScanBlock,
-                Builder.GetInsertBlock()->getParent());
-      Builder.CreateBr(scanInfo.OMPScanLoopExit);
-
-      emitBlock(scanInfo.OMPScanLoopExit,
-                Builder.GetInsertBlock()->getParent());
-      Builder.CreateBr(scanInfo.continueBlocks.back());
-      emitBlock(scanInfo.OMPScanDispatch,
-                Builder.GetInsertBlock()->getParent());
-      Builder.SetInsertPoint(
-          scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
-    }
+    scanInfo.iv = IV;
+    createScanBBs();
+    auto terminator = Builder.GetInsertBlock()->getTerminator();
+    scanInfo.continueBlocks.push_back(terminator->getSuccessor(0));
+    terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
+    emitBlock(scanInfo.OMPBeforeScanBlock,
+              Builder.GetInsertBlock()->getParent());
+    Builder.CreateBr(scanInfo.OMPScanLoopExit);
+
+    emitBlock(scanInfo.OMPScanLoopExit,
+              Builder.GetInsertBlock()->getParent());
+    Builder.CreateBr(scanInfo.continueBlocks.back());
+    emitBlock(scanInfo.OMPScanDispatch,
+              Builder.GetInsertBlock()->getParent());
+    Builder.SetInsertPoint(
+        scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
     return BodyGenCB(Builder.saveIP(), IV);
   };
-  const auto &&FirstGen = [&]() {
-    return createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+  
+  SmallVector<llvm::CanonicalLoopInfo *> result;
+  const auto &&FirstGen = [&]()-> Error {
+    auto LoopInfo = createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
                                InclusiveStop, ComputeIP, Name, true);
+    if (!LoopInfo)
+      return LoopInfo.takeError();
+    result.push_back(*LoopInfo);
+    Builder.restoreIP((*LoopInfo)->getAfterIP());
+    return Error::success();
   };
-  const auto &&SecondGen = [&](LocationDescription loc) {
-    return createCanonicalLoop(loc, BodyGen, Start, Stop, Step, IsSigned,
+  const auto &&SecondGen = [&](LocationDescription loc)-> Error {
+    auto LoopInfo = createCanonicalLoop(loc, BodyGen, Start, Stop, Step, IsSigned,
                                InclusiveStop, ComputeIP, Name, true);
+    if (!LoopInfo)
+      return LoopInfo.takeError();
+    result.push_back(*LoopInfo);
+    Builder.restoreIP((*LoopInfo)->getAfterIP());
+    return Error::success();
   };
   scanInfo.span = Span;
-  auto result = emitScanBasedDirectiveIR(FirstGen, SecondGen);
+  auto err = emitScanBasedDirectiveIR(FirstGen, SecondGen);
+  if(err) {
+    return err;
+  }
   return result;
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 27c0e0bf80255..b1a7023c3204b 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1440,6 +1440,56 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
 
   EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
 }
+void createScan(llvm::Value *scanVar, OpenMPIRBuilder &OMPBuilder, IRBuilder<> &Builder, OpenMPIRBuilder::LocationDescription Loc, OpenMPIRBuilder::InsertPointTy &allocaIP) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+    ASSERT_EXPECTED_INIT(
+        InsertPointTy, retIp,
+        OMPBuilder.createScan(Loc, allocaIP, {scanVar}, true));
+    Builder.restoreIP(retIp);
+}
+
+TEST_F(OpenMPIRBuilderTest, CanonicalScanLoops) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Value *TripCount = F->getArg(0);
+  Type *LCTy = TripCount->getType();
+  Value *StartVal = ConstantInt::get(LCTy, 1);
+  Value *StopVal = ConstantInt::get(LCTy, 100);
+  Value *Step = ConstantInt::get(LCTy, 1);
+  auto allocaIP = Builder.saveIP();
+
+
+  llvm::Value * scanVar = Builder.CreateAlloca(Builder.getInt64Ty(), 1);
+  unsigned NumBodiesGenerated = 0;
+  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
+    NumBodiesGenerated += 1;
+    
+    
+    Builder.restoreIP(CodeGenIP);
+    createScan(scanVar, OMPBuilder, Builder, Loc, allocaIP);
+    return Error::success();
+  };
+  SmallVector<CanonicalLoopInfo *> Loops;
+    ASSERT_EXPECTED_INIT(
+        SmallVector<CanonicalLoopInfo *>, loopsVec,
+        OMPBuilder.createCanonicalScanLoops(Loc, LoopBodyGenCB,
+                                       StartVal, StopVal, Step,
+                                       false, false, Builder.saveIP(), "scan"));
+  Loops = loopsVec;
+  EXPECT_EQ(Loops.size(), 2U);
+  auto inputLoop = Loops.front();
+  auto scanLoop = Loops.back();
+  Builder.restoreIP(scanLoop->getAfterIP());
+  inputLoop->assertOK();
+  scanLoop->assertOK();
+
+  //// Verify control flow structure (in addition to Loop->assertOK()).
+  EXPECT_EQ(inputLoop->getPreheader()->getSinglePredecessor(), &F->getEntryBlock());
+  EXPECT_EQ(scanLoop->getAfter(), Builder.GetInsertBlock());
+}
 
 TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
   OpenMPIRBuilder OMPBuilder(*M);
diff --git a/llvm/utils/lit/setup.py b/llvm/utils/lit/setup.py
deleted file mode 100644
index b11e3eafb2a35..0000000000000
--- a/llvm/utils/lit/setup.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-import sys
-
-from setuptools import setup, find_packages
-
-# setuptools expects to be invoked from within the directory of setup.py, but it
-# is nice to allow:
-#   python path/to/setup.py install
-# to work (for scripts, etc.)
-os.chdir(os.path.dirname(os.path.abspath(__file__)))
-sys.path.insert(0, ".")
-
-import lit
-
-with open("README.rst", "r", encoding="utf-8") as f:
-    long_description = f.read()
-
-setup(
-    name="lit",
-    version=lit.__version__,
-    author=lit.__author__,
-    author_email=lit.__email__,
-    url="http://llvm.org",
-    license="Apache-2.0 with LLVM exception",
-    license_files=["LICENSE.TXT"],
-    description="A Software Testing Tool",
-    keywords="test C++ automatic discovery",
-    long_description=long_description,
-    classifiers=[
-        "Development Status :: 3 - Alpha",
-        "Environment :: Console",
-        "Intended Audience :: Developers",
-        "License :: OSI Approved :: Apache Software License",
-        "Natural Language :: English",
-        "Operating System :: OS Independent",
-        "Programming Language :: Python",
-        "Topic :: Software Development :: Testing",
-    ],
-    zip_safe=False,
-    packages=find_packages(),
-    entry_points={
-        "console_scripts": [
-            "lit = lit.main:main",
-        ],
-    },
-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6ada71977a6a4..2927044a5c1a5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2720,8 +2720,7 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
         llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
             ompBuilder->createCanonicalScanLoops(
                 loc, bodyGen, lowerBound, upperBound, step,
-                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
-                isInScanRegion);
+                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop");
 
         if (failed(handleError(loopResults, *loopOp)))
           return failure();

>From 9809c2268987c55702480de299db74b8e533c215 Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Thu, 3 Apr 2025 12:08:53 -0500
Subject: [PATCH 3/3] Comments

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  91 +++--
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 350 +++++++++---------
 .../Frontend/OpenMPIRBuilderTest.cpp          | 108 +++---
 llvm/utils/lit/setup.py                       |  46 +++
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  24 +-
 .../Target/LLVMIR/openmp-reduction-scan.mlir  |   4 -
 6 files changed, 368 insertions(+), 255 deletions(-)
 create mode 100644 llvm/utils/lit/setup.py

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index c217e6193e84c..9f7571ce635d6 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -30,7 +30,6 @@
 namespace llvm {
 class CanonicalLoopInfo;
 struct TargetRegionEntryInfo;
-class ScanInfo;
 class OffloadEntriesInfoManager;
 class OpenMPIRBuilder;
 
@@ -503,7 +502,7 @@ class OpenMPIRBuilder {
       return allocaInst;
     }
   };
-  class ScanInfo {
+  struct ScanInformation {
   public:
     llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
     llvm::BasicBlock *OMPAfterScanBlock = nullptr;
@@ -512,12 +511,9 @@ class OpenMPIRBuilder {
     llvm::BasicBlock *OMPScanLoopExit = nullptr;
     bool OMPFirstScanLoop = false;
     llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
-    llvm::Value *iv;
-    llvm::Value *span;
-    SmallVector<llvm::BasicBlock *> continueBlocks;
-  };
-  
-  class ScanInfo scanInfo;
+    llvm::Value *IV;
+    llvm::Value *Span;
+  } ScanInfo;
 
   /// Initialize the internal state, this will put structures types and
   /// potentially other helpers into the underlying module. Must be called
@@ -745,6 +741,30 @@ class OpenMPIRBuilder {
                       LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
                       const Twine &Name = "loop");
 
+  /// Generator for the control flow structure of OpenMP canonical loops if
+  /// the parent directive has an `inscan` modifier specified.
+  /// If the `inscan` modifier is specified, the region of the parent is
+  /// expected to contain a `scan` directive. Based on the clauses of the
+  /// scan directive, the body of the loop is split into two loops: the input
+  /// loop, which contains the code generated for the input phase of the scan,
+  /// and the scan loop, which contains the code generated for the scan phase.
+  ///
+  /// \param Loc       The insert and source location description.
+  /// \param BodyGenCB Callback that will generate the loop body code.
+  /// \param Start     Value of the loop counter for the first iteration.
+  /// \param Stop      Loop counter values past this will stop the loop.
+  /// \param Step      Loop counter increment after each iteration; negative
+  ///                  means counting down.
+  /// \param IsSigned  Whether Start, Stop and Step are signed integers.
+  /// \param InclusiveStop Whether \p Stop itself is a valid value for the loop
+  ///                      counter.
+  /// \param ComputeIP Insertion point for instructions computing the trip
+  ///                  count. Can be used to ensure the trip count is available
+  ///                  at the outermost loop of a loop nest. If not set,
+  ///                  defaults to the preheader of the generated loop.
+  /// \param Name      Base name used to derive BB and instruction names.
+  ///
+  /// \returns A vector containing Loop Info of Input Loop and Scan Loop.
   Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
       const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
       Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
@@ -1554,19 +1574,37 @@ class OpenMPIRBuilder {
       ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
       Function *ReduceFn, AttributeList FuncAttrs);
 
+  /// Creates a call to the given runtime function and marks it as nounwind.
+  /// \param Callee The runtime function declaration to call.
+  /// \param Args Arguments passed to the call.
+  /// \param Name Name of the created call instruction.
+  ///
+  /// \return The runtime call instruction created.
+  llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee,
+                                          ArrayRef<llvm::Value *> Args,
+                                          const llvm::Twine &Name);
+
+  /// Helper function for createCanonicalScanLoops that emits the input loop
+  /// via \p InputLoopGen and then the scan loop via \p ScanLoopGen.
+  /// \param InputLoopGen Callback for generating the loop for the input phase.
+  /// \param ScanLoopGen Callback for generating the loop for the scan phase.
+  ///
+  /// \return The first error produced by a callback, or success.
   Error emitScanBasedDirectiveIR(
-      llvm::function_ref<Error ()> FirstGen,
-      llvm::function_ref<Error (LocationDescription loc)>
-          SecondGen);
-  llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
-                                          ArrayRef<llvm::Value *> args,
-                                          const llvm::Twine &name);
+      llvm::function_ref<Error()> InputLoopGen,
+      llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen);
 
+  /// Creates the basic blocks required for scan reduction.
   void createScanBBs();
-  void emitScanBasedDirectiveDeclsIR(llvm::Value *span,
-                                     ArrayRef<llvm::Value *> ScanVars);
+
+  /// Creates the buffer needed for scan reduction.
+  /// \param ScanVars Scan Variables.
+  void emitScanBasedDirectiveDeclsIR(ArrayRef<llvm::Value *> ScanVars);
+
+  /// Copies the result back to the reduction variable.
+  /// \param ReductionInfos Array type containing the ReductionOps.
   void emitScanBasedDirectiveFinalsIR(
-      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos);
 
   /// This function emits a helper that gathers Reduce lists from the first
   /// lane of every active warp to lanes in the first warp.
@@ -2640,8 +2678,17 @@ class OpenMPIRBuilder {
                                     BodyGenCallbackTy BodyGenCB,
                                     FinalizeCallbackTy FiniCB, Value *Filter);
 
-  /// Generator for the scan reduction
-  ///
+  /// This function performs the scan reduction of the values updated in
+  /// the input phase. This reduction needs to be emitted between the input
+  /// and scan loops returned by `createCanonicalScanLoops`. The following is
+  /// the code that is generated; `buffer` and `span` are expected to be
+  /// populated before calling the function.
+  ///
+  ///  for (int k = 0; k != ceil(log2(span)); ++k) {
+  ///    i=pow(2,k)
+  ///    for (size cnt = last_iter; cnt >= i; --cnt)
+  ///      buffer[cnt] op= buffer[cnt-i];
+  ///  }
   /// \param Loc The insert and source location description.
   /// \param FinalizeIP The IP where the reduction result needs
   //                   to be copied back to original variable.
@@ -2650,9 +2697,11 @@ class OpenMPIRBuilder {
   /// \returns The insertion position *after* the masked.
   InsertPointOrErrorTy emitScanReduction(
       const LocationDescription &Loc, InsertPointTy &FinalizeIP,
-      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos);
 
-  /// Generator for the scan reduction
+  /// This directive splits the control flow and directs it to the input phase
+  /// blocks or scan phase blocks based on 1. whether the input loop or scan
+  /// loop is executed, 2. whether an exclusive or inclusive scan is used.
   ///
   /// \param Loc The insert and source location description.
   /// \param AllocaIP The IP where the temporary buffer for scan reduction
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 52ad448ad653f..590cf4cdabb79 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -59,6 +59,7 @@
 #include "llvm/Transforms/Utils/LoopPeel.h"
 #include "llvm/Transforms/Utils/UnrollLoop.h"
 
+#include <cassert>
 #include <cstdint>
 #include <optional>
 
@@ -80,7 +81,6 @@ static cl::opt<double> UnrollThresholdFactor(
     cl::init(1.5));
 
 #ifndef NDEBUG
-
 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
 /// an InsertPoint stores the instruction before something is inserted. For
@@ -739,6 +739,7 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
                             /* AllowAlloca */ true,
                             /* AllocaBlock*/ OI.OuterAllocaBB,
                             /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+
     LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                       << " Exit: " << OI.ExitBB->getName() << "\n");
@@ -3919,131 +3920,132 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
 }
 
 llvm::CallInst *
-OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
-                                         ArrayRef<llvm::Value *> args,
-                                         const llvm::Twine &name) {
-  llvm::CallInst *call = Builder.CreateCall(
-      callee, args, SmallVector<llvm::OperandBundleDef, 1>(), name);
-  call->setDoesNotThrow();
-  return call;
-}
-
-//Expects current basic block is dominated by BeforeScanBB.
-//Once Scan directive is encountered, the code after scan block should be
-//dominated by AfterScanBB. Scan directive splits the code sequence to 
-//before and after parts. Based on whether inclusive or exclusive 
-//clause is used in the scan directive, it adds jumps to input and 
-//scan phase in the first and second loops. 
+OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee Callee,
+                                         ArrayRef<llvm::Value *> Args,
+                                         const llvm::Twine &Name) {
+  llvm::CallInst *Call = Builder.CreateCall(
+      Callee, Args, SmallVector<llvm::OperandBundleDef, 1>(), Name);
+  Call->setDoesNotThrow();
+  return Call;
+}
+
+// Expects the input basic block to be dominated by OMPBeforeScanBlock.
+// Once the scan directive is encountered, the code after the scan directive
+// should be dominated by OMPAfterScanBlock. The scan directive splits the
+// code sequence into the input and scan phases. Based on whether an inclusive
+// or exclusive clause is used in the scan directive and whether the input loop
+// or the scan loop is being lowered, it adds jumps to the input and scan
+// phases. The first scan loop is the input loop and the second is the scan
+// loop. The generated code currently handles only inclusive scans.
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
     const LocationDescription &Loc, InsertPointTy AllocaIP,
     ArrayRef<llvm::Value *> ScanVars, bool IsInclusive) {
-  if (scanInfo.OMPFirstScanLoop) {
+  if (ScanInfo.OMPFirstScanLoop) {
     Builder.restoreIP(AllocaIP);
-    emitScanBasedDirectiveDeclsIR(scanInfo.span, ScanVars);
+    emitScanBasedDirectiveDeclsIR(ScanVars);
   }
   if (!updateToLocation(Loc))
     return Loc.IP;
   unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
-  llvm::Value *iv = scanInfo.iv;
+  llvm::Value *IV = ScanInfo.IV;
 
-  if (scanInfo.OMPFirstScanLoop) {
+  if (ScanInfo.OMPFirstScanLoop) {
     // Emit buffer[i] = red; at the end of the input phase.
-    for (int i = 0; i < ScanVars.size(); i++) {
-      auto buff = scanInfo.ReductionVarToScanBuffs[ScanVars[i]];
-
-      auto destTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
-      auto val = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
-      auto src = Builder.CreateLoad(destTy, ScanVars[i]);
-      auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
-          val, destTy->getPointerTo(defaultAS));
-
-      Builder.CreateStore(src, dest);
+    for (Value *ScanVar : ScanVars) {
+      Value *Buff = ScanInfo.ReductionVarToScanBuffs[ScanVar];
+      Type *DestTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
+      Value *Val = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+      Value *Src = Builder.CreateLoad(DestTy, ScanVar);
+      Value *Dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          Val, DestTy->getPointerTo(defaultAS));
+
+      Builder.CreateStore(Src, Dest);
     }
   }
-  Builder.CreateBr(scanInfo.OMPScanLoopExit);
-  llvm::LLVMContext &llvmContext = Builder.getContext();
-  Builder.SetInsertPoint(scanInfo.OMPScanDispatch);
+  Builder.CreateBr(ScanInfo.OMPScanLoopExit);
+  emitBlock(ScanInfo.OMPScanDispatch, Builder.GetInsertBlock()->getParent());
 
+  // Initialize the private reduction variable to 0 in each iteration.
+  // It is used to copy initial values to the scan buffer.
   ConstantInt *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
-  for (int i = 0; i < ScanVars.size(); i++) {
-    auto destTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
-    auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
-        ScanVars[i], destTy->getPointerTo(defaultAS));
-    Builder.CreateStore(Zero, dest);
+  for (Value *ScanVar : ScanVars) {
+    Type *DestTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
+    Value *Dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ScanVar, DestTy->getPointerTo(defaultAS));
+    Builder.CreateStore(Zero, Dest);
   }
 
-  if (!scanInfo.OMPFirstScanLoop) {
-    iv = scanInfo.iv;
+  if (!ScanInfo.OMPFirstScanLoop) {
+    IV = ScanInfo.IV;
     // Emit red = buffer[i]; at the entrance to the scan phase.
-    for (int i = 0; i < ScanVars.size(); i++) {
-      auto buff = scanInfo.ReductionVarToScanBuffs[ScanVars[i]];
-      auto destTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
-      auto newVPtr = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
-      auto newV = Builder.CreateLoad(destTy, newVPtr);
-      auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
-          ScanVars[i], destTy->getPointerTo(defaultAS));
-
-      Builder.CreateStore(newV, dest);
+    // TODO: for an exclusive scan, this should load red = buffer[i-1] instead.
+    for (Value *ScanVar : ScanVars) {
+      Value *Buff = ScanInfo.ReductionVarToScanBuffs[ScanVar];
+      Type *DestTy = Builder.getInt32Ty(); // ScanVars[i]->getType();
+      Value *SrcPtr =
+          Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+      Value *Src = Builder.CreateLoad(DestTy, SrcPtr);
+      Value *Dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ScanVar, DestTy->getPointerTo(defaultAS));
+
+      Builder.CreateStore(Src, Dest);
     }
   }
 
-  llvm::Value *CmpI = Builder.CreateICmpUGE(Builder.getInt32(100), Builder.getInt32(0));
-  if (scanInfo.OMPFirstScanLoop == IsInclusive) {
-    Builder.CreateCondBr(CmpI, scanInfo.OMPBeforeScanBlock,
-                       scanInfo.OMPAfterScanBlock);
+  // TODO: Update it to CreateBr and remove dead blocks
+  llvm::Value *CmpI = Builder.getInt1(true);
+  if (ScanInfo.OMPFirstScanLoop == IsInclusive) {
+    Builder.CreateCondBr(CmpI, ScanInfo.OMPBeforeScanBlock,
+                         ScanInfo.OMPAfterScanBlock);
   } else {
-    Builder.CreateCondBr(CmpI, scanInfo.OMPAfterScanBlock,
-                       scanInfo.OMPBeforeScanBlock);
+    Builder.CreateCondBr(CmpI, ScanInfo.OMPAfterScanBlock,
+                         ScanInfo.OMPBeforeScanBlock);
   }
-  emitBlock(scanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
-  Builder.SetInsertPoint(scanInfo.OMPAfterScanBlock);
+  emitBlock(ScanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(ScanInfo.OMPAfterScanBlock);
   return Builder.saveIP();
 }
 
 void OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
-    llvm::Value *span, ArrayRef<llvm::Value *> ScanVars) {
+    ArrayRef<Value *> ScanVars) {
 
-  ConstantInt *One = ConstantInt::get(Builder.getInt32Ty(), 1);
-  llvm::Value *allocSpan = Builder.CreateAdd(span, One);
-  for (auto &scanVar : ScanVars) {
+  Value *AllocSpan = Builder.CreateAdd(ScanInfo.Span, Builder.getInt32(1));
+  for (Value *ScanVar : ScanVars) {
     // TODO: remove after all users of by-ref are updated to use the alloc
     // region: Allocate reduction variable (which is a pointer to the real
     // reduciton variable allocated in the inlined region)
-    llvm::Value *buff =
-        Builder.CreateAlloca(Builder.getInt32Ty(), allocSpan, "vla");
-    scanInfo.ReductionVarToScanBuffs[scanVar] = buff;
+    llvm::Value *Buff =
+        Builder.CreateAlloca(Builder.getInt32Ty(), AllocSpan, "vla");
+    ScanInfo.ReductionVarToScanBuffs[ScanVar] = Buff;
   }
 }
 
 void OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
-    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
-  llvm::Value *span = scanInfo.span;
-  // llvm::Value *OMPLast = span;
+    SmallVector<ReductionInfo> ReductionInfos) {
   llvm::Value *OMPLast = Builder.CreateNSWAdd(
-      span, llvm::ConstantInt::get(span->getType(), 1, /*isSigned=*/false));
-  // llvm::Value *OMPLast = Builder.CreateNSWSub(
-  //     span, llvm::ConstantInt::get(span->getType(), 1, /*isSigned=*/false));
-  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
-  for (int i = 0; i < reductionInfos.size(); i++) {
-    auto privateVar = reductionInfos[i].PrivateVariable;
-    auto origVar = reductionInfos[i].Variable;
-    auto buff = scanInfo.ReductionVarToScanBuffs[privateVar];
+      ScanInfo.Span,
+      llvm::ConstantInt::get(ScanInfo.Span->getType(), 1, /*isSigned=*/false));
+  unsigned int DefaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (ReductionInfo RedInfo : ReductionInfos) {
+    Value *PrivateVar = RedInfo.PrivateVariable;
+    Value *OrigVar = RedInfo.Variable;
+    Value *Buff = ScanInfo.ReductionVarToScanBuffs[PrivateVar];
 
-    auto srcTy = reductionInfos[i].ElementType;
-    auto val = Builder.CreateInBoundsGEP(srcTy, buff, OMPLast, "arrayOffset");
-    auto src = Builder.CreateLoad(srcTy, val);
-    auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
-        origVar, srcTy->getPointerTo(defaultAS));
+    Type *SrcTy = RedInfo.ElementType;
+    Value *Val = Builder.CreateInBoundsGEP(SrcTy, Buff, OMPLast, "arrayOffset");
+    Value *Src = Builder.CreateLoad(SrcTy, Val);
+    Value *Dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        OrigVar, SrcTy->getPointerTo(DefaultAS));
 
-    Builder.CreateStore(src, dest);
+    Builder.CreateStore(Src, Dest);
   }
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
     const LocationDescription &Loc, InsertPointTy &FinalizeIP,
-    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> ReductionInfos) {
 
-  llvm::Value *spanDiff = scanInfo.span;
+  llvm::Value *spanDiff = ScanInfo.Span;
 
   if (!updateToLocation(Loc))
     return Loc.IP;
@@ -4073,9 +4075,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
   emitBlock(LoopBB, Builder.GetInsertBlock()->getParent());
   Builder.SetInsertPoint(LoopBB);
 
-  auto *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  PHINode *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
   //// size pow2k = 1;
-  auto *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  PHINode *Pow2K = Builder.CreatePHI(Builder.getInt32Ty(), 2);
   Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
                        InputBB);
   Pow2K->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 1), InputBB);
@@ -4092,25 +4094,28 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
   auto *IVal = Builder.CreatePHI(Builder.getInt32Ty(), 2);
   IVal->addIncoming(NMin1, LoopBB);
   unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
-  for (int i = 0; i < reductionInfos.size(); i++) {
-    auto &reductionVal = reductionInfos[i].PrivateVariable;
-    auto buff = scanInfo.ReductionVarToScanBuffs[reductionVal];
-    auto destTy = reductionInfos[i].ElementType;
+  for (ReductionInfo RedInfo : ReductionInfos) {
+    Value *ReductionVal = RedInfo.PrivateVariable;
+    Value *Buff = ScanInfo.ReductionVarToScanBuffs[ReductionVal];
+    Type *DestTy = RedInfo.ElementType;
     Value *IV = Builder.CreateAdd(IVal, Builder.getInt32(1));
-    auto lhsPtr = Builder.CreateInBoundsGEP(destTy, buff, IV, "arrayOffset");
-    auto offsetIval = Builder.CreateNUWSub(IV, Pow2K);
-    auto rhsPtr =
-        Builder.CreateInBoundsGEP(destTy, buff, offsetIval, "arrayOffset");
-    auto lhs = Builder.CreateLoad(destTy, lhsPtr);
-    auto rhs = Builder.CreateLoad(destTy, rhsPtr);
-    auto lhsAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
-        lhsPtr, rhs->getType()->getPointerTo(defaultAS));
-    llvm::Value *result;
+    Value *LHSPtr = Builder.CreateInBoundsGEP(DestTy, Buff, IV, "arrayOffset");
+    Value *OffsetIval = Builder.CreateNUWSub(IV, Pow2K);
+    Value *RHSPtr =
+        Builder.CreateInBoundsGEP(DestTy, Buff, OffsetIval, "arrayOffset");
+    Value *LHS = Builder.CreateLoad(DestTy, LHSPtr);
+    Value *RHS = Builder.CreateLoad(DestTy, RHSPtr);
+    // Value * lhsAddr =
+    //       Builder.CreatePointerBitCastOrAddrSpaceCast(lhs,
+    //                                                   Builder.getPtrTy(0));
+    Value *LHSAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        LHSPtr, RHS->getType()->getPointerTo(defaultAS));
+    llvm::Value *Result;
     InsertPointOrErrorTy AfterIP =
-        reductionInfos[i].ReductionGen(Builder.saveIP(), lhs, rhs, result);
+        RedInfo.ReductionGen(Builder.saveIP(), LHS, RHS, Result);
     if (!AfterIP)
       return AfterIP.takeError();
-    Builder.CreateStore(result, lhsAddr);
+    Builder.CreateStore(Result, LHSAddr);
   }
   llvm::Value *NextIVal = Builder.CreateNUWSub(
       IVal, llvm::ConstantInt::get(Builder.getInt32Ty(), 1));
@@ -4128,23 +4133,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
   Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
   emitBlock(ExitBB, Builder.GetInsertBlock()->getParent());
   Builder.SetInsertPoint(ExitBB);
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
       createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
 
   Builder.restoreIP(FinalizeIP);
-  emitScanBasedDirectiveFinalsIR(reductionInfos);
+  emitScanBasedDirectiveFinalsIR(ReductionInfos);
   FinalizeIP = Builder.saveIP();
 
-  return afterIP;
+  return AfterIP;
 }
 
-Error
-OpenMPIRBuilder::emitScanBasedDirectiveIR(
-    llvm::function_ref<Error ()> FirstGen,
-    llvm::function_ref<Error (LocationDescription loc)>
-        SecondGen) {
+Error OpenMPIRBuilder::emitScanBasedDirectiveIR(
+    llvm::function_ref<Error()> InputLoopGen,
+    llvm::function_ref<Error(LocationDescription Loc)> ScanLoopGen) {
 
-  SmallVector<llvm::CanonicalLoopInfo *> ret;
   {
     // Emit loop with input phase:
     // #pragma omp ...
@@ -4152,31 +4154,31 @@ OpenMPIRBuilder::emitScanBasedDirectiveIR(
     //   <input phase>;
     //   buffer[i] = red;
     // }
-    scanInfo.OMPFirstScanLoop = true;
-    auto result = FirstGen();
-    if (result)
-      return result;
+    ScanInfo.OMPFirstScanLoop = true;
+    auto Result = InputLoopGen();
+    if (Result)
+      return Result;
   }
   {
-    scanInfo.OMPFirstScanLoop = false;
-    auto result = SecondGen(Builder.saveIP());
-    if (result)
-      return result;
+    ScanInfo.OMPFirstScanLoop = false;
+    auto Result = ScanLoopGen(Builder.saveIP());
+    if (Result)
+      return Result;
   }
   return Error::success();
 }
 
 void OpenMPIRBuilder::createScanBBs() {
   auto fun = Builder.GetInsertBlock()->getParent();
-  scanInfo.OMPScanExitBlock =
+  ScanInfo.OMPScanExitBlock =
       BasicBlock::Create(fun->getContext(), "omp.exit.inscan.bb");
-  scanInfo.OMPScanDispatch =
+  ScanInfo.OMPScanDispatch =
       BasicBlock::Create(fun->getContext(), "omp.inscan.dispatch");
-  scanInfo.OMPAfterScanBlock =
+  ScanInfo.OMPAfterScanBlock =
       BasicBlock::Create(fun->getContext(), "omp.after.scan.bb");
-  scanInfo.OMPBeforeScanBlock =
+  ScanInfo.OMPBeforeScanBlock =
       BasicBlock::Create(fun->getContext(), "omp.before.scan.bb");
-  scanInfo.OMPScanLoopExit =
+  ScanInfo.OMPScanLoopExit =
       BasicBlock::Create(fun->getContext(), "omp.scan.loop.exit");
 }
 
@@ -4282,75 +4284,81 @@ OpenMPIRBuilder::createCanonicalScanLoops(
     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
     InsertPointTy ComputeIP, const Twine &Name) {
-  auto *IndVarTy = cast<IntegerType>(Start->getType());
-  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
-  assert(IndVarTy == Step->getType() && "Step type mismatch");
   LocationDescription ComputeLoc =
       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
   updateToLocation(ComputeLoc);
 
-  ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
-
-  // Distance between Start and Stop; always positive.
-  Value *Span;
-
-  // Condition whether there are no iterations are executed at all, e.g. because
-  // UB < LB.
+  Value *TripCount = calculateCanonicalLoopTripCount(
+      ComputeLoc, Start, Stop, Step, IsSigned, InclusiveStop, Name);
+  ScanInfo.Span = TripCount;
 
-  if (IsSigned) {
-    // Ensure that increment is positive. If not, negate and invert LB and UB.
-    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
-    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
-    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
-    Span = Builder.CreateSub(UB, LB, "", false, true);
-  } else {
-    Span = Builder.CreateSub(Stop, Start, "", true);
-  }
   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
-    scanInfo.iv = IV;
+    /// The control flow of the loop body, with the following structure:
+    ///
+    ///     InputBlock
+    ///        |
+    ///     ContinueBlock
+    ///
+    ///  is transformed to:
+    ///
+    ///     InputBlock
+    ///        |
+    ///     OMPScanDispatch
+    ///
+    ///     OMPBeforeScanBlock
+    ///        |
+    ///     OMPScanLoopExit
+    ///        |
+    ///     ContinueBlock
+    ///
+    /// OMPBeforeScanBlock dominates the control flow of code generated until
+    /// scan directive is encountered and OMPAfterScanBlock dominates the
+    /// control flow of code generated after scan is encountered. The successor
+    /// of OMPScanDispatch can be OMPBeforeScanBlock or OMPAfterScanBlock based
+    /// on 1. whether it is in the input or scan phase, 2. whether it is an
+    /// exclusive or inclusive scan.
+    ScanInfo.IV = IV;
     createScanBBs();
-    auto terminator = Builder.GetInsertBlock()->getTerminator();
-    scanInfo.continueBlocks.push_back(terminator->getSuccessor(0));
-    terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
-    emitBlock(scanInfo.OMPBeforeScanBlock,
-              Builder.GetInsertBlock()->getParent());
-    Builder.CreateBr(scanInfo.OMPScanLoopExit);
-
-    emitBlock(scanInfo.OMPScanLoopExit,
+    BasicBlock *InputBlock = Builder.GetInsertBlock();
+    Instruction *Terminator = InputBlock->getTerminator();
+    assert(Terminator->getNumSuccessors() == 1);
+    BasicBlock *ContinueBlock = Terminator->getSuccessor(0);
+    Terminator->setSuccessor(0, ScanInfo.OMPScanDispatch);
+    emitBlock(ScanInfo.OMPBeforeScanBlock,
               Builder.GetInsertBlock()->getParent());
-    Builder.CreateBr(scanInfo.continueBlocks.back());
-    emitBlock(scanInfo.OMPScanDispatch,
-              Builder.GetInsertBlock()->getParent());
-    Builder.SetInsertPoint(
-        scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
+    Builder.CreateBr(ScanInfo.OMPScanLoopExit);
+    emitBlock(ScanInfo.OMPScanLoopExit, Builder.GetInsertBlock()->getParent());
+    Builder.CreateBr(ContinueBlock);
+    Builder.SetInsertPoint(ScanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
     return BodyGenCB(Builder.saveIP(), IV);
   };
-  
-  SmallVector<llvm::CanonicalLoopInfo *> result;
-  const auto &&FirstGen = [&]()-> Error {
-    auto LoopInfo = createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
-                               InclusiveStop, ComputeIP, Name, true);
+
+  SmallVector<llvm::CanonicalLoopInfo *> Result;
+  const auto &&InputLoopGen = [&]() -> Error {
+    auto LoopInfo =
+        createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+                            InclusiveStop, ComputeIP, Name, true);
     if (!LoopInfo)
       return LoopInfo.takeError();
-    result.push_back(*LoopInfo);
+    Result.push_back(*LoopInfo);
     Builder.restoreIP((*LoopInfo)->getAfterIP());
     return Error::success();
   };
-  const auto &&SecondGen = [&](LocationDescription loc)-> Error {
-    auto LoopInfo = createCanonicalLoop(loc, BodyGen, Start, Stop, Step, IsSigned,
-                               InclusiveStop, ComputeIP, Name, true);
+  const auto &&ScanLoopGen = [&](LocationDescription Loc) -> Error {
+    auto LoopInfo =
+        createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+                            InclusiveStop, ComputeIP, Name, true);
     if (!LoopInfo)
       return LoopInfo.takeError();
-    result.push_back(*LoopInfo);
+    Result.push_back(*LoopInfo);
     Builder.restoreIP((*LoopInfo)->getAfterIP());
     return Error::success();
   };
-  scanInfo.span = Span;
-  auto err = emitScanBasedDirectiveIR(FirstGen, SecondGen);
-  if(err) {
-    return err;
+  Error Err = emitScanBasedDirectiveIR(InputLoopGen, ScanLoopGen);
+  if (Err) {
+    return Err;
   }
-  return result;
+  return Result;
 }
 
 Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
@@ -4427,7 +4435,7 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     Value *Span = Builder.CreateMul(IV, Step);
     Value *IndVar = Builder.CreateAdd(Span, Start);
     if (InScan) {
-      scanInfo.iv = IndVar;
+      ScanInfo.IV = IndVar;
     }
     return BodyGenCB(Builder.saveIP(), IndVar);
   };
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b1a7023c3204b..ba946b1c1b9ab 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -1440,55 +1440,13 @@ TEST_F(OpenMPIRBuilderTest, CanonicalLoopSimple) {
 
   EXPECT_EQ(&Loop->getAfter()->front(), RetInst);
 }
-void createScan(llvm::Value *scanVar, OpenMPIRBuilder &OMPBuilder, IRBuilder<> &Builder, OpenMPIRBuilder::LocationDescription Loc, OpenMPIRBuilder::InsertPointTy &allocaIP) {
+void createScan(llvm::Value *scanVar, OpenMPIRBuilder &OMPBuilder,
+                IRBuilder<> &Builder, OpenMPIRBuilder::LocationDescription Loc,
+                OpenMPIRBuilder::InsertPointTy &allocaIP) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
-    ASSERT_EXPECTED_INIT(
-        InsertPointTy, retIp,
-        OMPBuilder.createScan(Loc, allocaIP, {scanVar}, true));
-    Builder.restoreIP(retIp);
-}
-
-TEST_F(OpenMPIRBuilderTest, CanonicalScanLoops) {
-  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
-  OpenMPIRBuilder OMPBuilder(*M);
-  OMPBuilder.initialize();
-  IRBuilder<> Builder(BB);
-  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
-  Value *TripCount = F->getArg(0);
-  Type *LCTy = TripCount->getType();
-  Value *StartVal = ConstantInt::get(LCTy, 1);
-  Value *StopVal = ConstantInt::get(LCTy, 100);
-  Value *Step = ConstantInt::get(LCTy, 1);
-  auto allocaIP = Builder.saveIP();
-
-
-  llvm::Value * scanVar = Builder.CreateAlloca(Builder.getInt64Ty(), 1);
-  unsigned NumBodiesGenerated = 0;
-  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, llvm::Value *LC) {
-    NumBodiesGenerated += 1;
-    
-    
-    Builder.restoreIP(CodeGenIP);
-    createScan(scanVar, OMPBuilder, Builder, Loc, allocaIP);
-    return Error::success();
-  };
-  SmallVector<CanonicalLoopInfo *> Loops;
-    ASSERT_EXPECTED_INIT(
-        SmallVector<CanonicalLoopInfo *>, loopsVec,
-        OMPBuilder.createCanonicalScanLoops(Loc, LoopBodyGenCB,
-                                       StartVal, StopVal, Step,
-                                       false, false, Builder.saveIP(), "scan"));
-  Loops = loopsVec;
-  EXPECT_EQ(Loops.size(), 2U);
-  auto inputLoop = Loops.front();
-  auto scanLoop = Loops.back();
-  Builder.restoreIP(scanLoop->getAfterIP());
-  inputLoop->assertOK();
-  scanLoop->assertOK();
-
-  //// Verify control flow structure (in addition to Loop->assertOK()).
-  EXPECT_EQ(inputLoop->getPreheader()->getSinglePredecessor(), &F->getEntryBlock());
-  EXPECT_EQ(scanLoop->getAfter(), Builder.GetInsertBlock());
+  ASSERT_EXPECTED_INIT(InsertPointTy, retIp,
+                       OMPBuilder.createScan(Loc, allocaIP, {scanVar}, true));
+  Builder.restoreIP(retIp);
 }
 
 TEST_F(OpenMPIRBuilderTest, CanonicalLoopTripCount) {
@@ -5385,6 +5343,66 @@ TEST_F(OpenMPIRBuilderTest, CreateReductions) {
   EXPECT_TRUE(findGEPZeroOne(ReductionFn->getArg(1), FirstRHS, SecondRHS));
 }
 
+// Checks that createCanonicalScanLoops emits two canonical loops (an input
+// loop followed by a scan loop) that each generate the body once, and that a
+// scalar float sum scan reduction can then be emitted via emitScanReduction.
+TEST_F(OpenMPIRBuilderTest, ScanReduction) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  // Loop bounds share the type of the test function's first argument.
+  Type *LCTy = F->getArg(0)->getType();
+  Value *StartVal = ConstantInt::get(LCTy, 1);
+  Value *StopVal = ConstantInt::get(LCTy, 100);
+  Value *Step = ConstantInt::get(LCTy, 1);
+  InsertPointTy AllocaIP = Builder.saveIP();
+
+  Value *ScanVar = Builder.CreateAlloca(Builder.getFloatTy());
+  Value *OrigVar = Builder.CreateAlloca(Builder.getFloatTy());
+  unsigned NumBodiesGenerated = 0;
+  // The body only emits the scan directive; the induction variable is unused.
+  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value * /*IndVar*/) {
+    ++NumBodiesGenerated;
+    Builder.restoreIP(CodeGenIP);
+    createScan(ScanVar, OMPBuilder, Builder, Loc, AllocaIP);
+    return Error::success();
+  };
+  ASSERT_EXPECTED_INIT(SmallVector<CanonicalLoopInfo *>, Loops,
+                       OMPBuilder.createCanonicalScanLoops(
+                           Loc, LoopBodyGenCB, StartVal, StopVal, Step,
+                           /*IsSigned=*/false, /*InclusiveStop=*/false,
+                           Builder.saveIP(), "scan"));
+  // Fatal check: Loops.front()/back() are dereferenced below.
+  ASSERT_EQ(Loops.size(), 2U);
+  CanonicalLoopInfo *InputLoop = Loops.front();
+  CanonicalLoopInfo *ScanLoop = Loops.back();
+  Builder.restoreIP(ScanLoop->getAfterIP());
+  InputLoop->assertOK();
+  ScanLoop->assertOK();
+
+  // Verify control flow structure (in addition to Loop->assertOK()).
+  EXPECT_EQ(InputLoop->getPreheader()->getSinglePredecessor(),
+            &F->getEntryBlock());
+  EXPECT_EQ(ScanLoop->getAfter(), Builder.GetInsertBlock());
+  EXPECT_EQ(NumBodiesGenerated, 2U);
+
+  SmallVector<OpenMPIRBuilder::ReductionInfo> ReductionInfos = {
+      {Builder.getFloatTy(), OrigVar, ScanVar,
+       /*EvaluationKind=*/OpenMPIRBuilder::EvalKind::Scalar, sumReduction,
+       /*ReductionGenClang=*/nullptr, sumAtomicReduction}};
+  InsertPointTy FinalizeIP = ScanLoop->getAfterIP();
+  OpenMPIRBuilder::LocationDescription RedLoc({InputLoop->getAfterIP(), DL});
+  BasicBlock *Cont = splitBB(Builder, /*CreateBranch=*/false,
+                             "omp.scan.loop.cont");
+  ASSERT_EXPECTED_INIT(
+      InsertPointTy, RetIP,
+      OMPBuilder.emitScanReduction(RedLoc, FinalizeIP, ReductionInfos));
+  Builder.restoreIP(RetIP);
+  Builder.CreateBr(Cont);
+}
+
 TEST_F(OpenMPIRBuilderTest, CreateTwoReductions) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
diff --git a/llvm/utils/lit/setup.py b/llvm/utils/lit/setup.py
new file mode 100644
index 0000000000000..b11e3eafb2a35
--- /dev/null
+++ b/llvm/utils/lit/setup.py
@@ -0,0 +1,49 @@
+"""setuptools packaging script for the in-tree ``lit`` package."""
+
+import os
+import sys
+
+from setuptools import find_packages, setup
+
+# setuptools expects to be invoked from within the directory of setup.py, and
+# the relative README.rst path below relies on it as well. Changing into that
+# directory first also lets
+#   python path/to/setup.py install
+# work when run from elsewhere (for scripts, etc.).
+os.chdir(os.path.dirname(os.path.abspath(__file__)))
+# Make the in-tree package importable so its version/author metadata can be
+# read below.
+sys.path.insert(0, ".")
+
+import lit  # noqa: E402 -- must follow the sys.path change above.
+
+_CLASSIFIERS = [
+    "Development Status :: 3 - Alpha",
+    "Environment :: Console",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: Apache Software License",
+    "Natural Language :: English",
+    "Operating System :: OS Independent",
+    "Programming Language :: Python",
+    "Topic :: Software Development :: Testing",
+]
+
+with open("README.rst", "r", encoding="utf-8") as readme_file:
+    long_description = readme_file.read()
+
+setup(
+    name="lit",
+    version=lit.__version__,
+    author=lit.__author__,
+    author_email=lit.__email__,
+    url="http://llvm.org",
+    license="Apache-2.0 with LLVM exception",
+    license_files=["LICENSE.TXT"],
+    description="A Software Testing Tool",
+    keywords="test C++ automatic discovery",
+    long_description=long_description,
+    classifiers=_CLASSIFIERS,
+    zip_safe=False,
+    packages=find_packages(),
+    entry_points={"console_scripts": ["lit = lit.main:main"]},
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2927044a5c1a5..bea5ec02d5f05 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2267,13 +2267,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
 
   SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
       findCurrentLoopInfos(moduleTranslation);
-  auto beforeLoopFinishIp = loopInfos.front()->getAfterIP();
-  auto afterLoopFinishIp = loopInfos.back()->getAfterIP();
+  auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
+  auto scanLoopFinishIp = loopInfos.back()->getAfterIP();
   bool isInScanRegion =
       wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
                                      mlir::omp::ReductionModifier::inscan);
   if (isInScanRegion) {
-    builder.restoreIP(beforeLoopFinishIp);
+    builder.restoreIP(inputLoopFinishIp);
     SmallVector<OwningReductionGen> owningReductionGens;
     SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
     SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
@@ -2282,7 +2282,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                          privateReductionVariables, reductionInfos);
     llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
     llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
-        ompBuilder->emitScanReduction(builder.saveIP(), afterLoopFinishIp,
+        ompBuilder->emitScanReduction(builder.saveIP(), scanLoopFinishIp,
                                       reductionInfos);
     if (failed(handleError(redIP, opInst)))
       return failure();
@@ -2715,24 +2715,29 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
       bool isInScanRegion =
           wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
                                          mlir::omp::ReductionModifier::inscan);
-      // TODOSCAN: Take care of loop and add asserts if required
       if (isInScanRegion) {
+        // TODO: Handle nesting if Scan loop is nested in a loop. Diagnose the
+        // unsupported case instead of asserting so it is still reported when
+        // assertions are disabled.
+        if (loopOp.getNumLoops() != 1)
+          return loopOp.emitError("not yet implemented: inscan reduction on "
+                                  "omp.loop_nest with more than one loop");
         llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
             ompBuilder->createCanonicalScanLoops(
                 loc, bodyGen, lowerBound, upperBound, step,
-                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop");
+                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP,
+                "loop");
 
         if (failed(handleError(loopResults, *loopOp)))
           return failure();
-        auto beforeLoop = loopResults->front();
-        auto afterLoop = loopResults->back();
+        auto inputLoop = loopResults->front();
+        auto scanLoop = loopResults->back();
         moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
             [&](OpenMPLoopInfoStackFrame &frame) {
-              frame.loopInfos.push_back(beforeLoop);
-              frame.loopInfos.push_back(afterLoop);
+              frame.loopInfos.push_back(inputLoop);
+              frame.loopInfos.push_back(scanLoop);
               return WalkResult::interrupt();
             });
-        builder.restoreIP(afterLoop->getAfterIP());
+        builder.restoreIP(scanLoop->getAfterIP());
         return success();
       } else {
         llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
diff --git a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
index 1d8bbbef3e9b3..a88c1993aebe1 100644
--- a/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-reduction-scan.mlir
@@ -1,9 +1,7 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
-// nonesense minimised code simulating the control flow graph generated by flang
-// for array reductions. The important thing here is that we are testing a byref
-// reduction with a cleanup region, and the various regions contain multiple
-// blocks
+// Tests LLVM IR translation of OpenMP scan (`inscan`) reductions.
+
 omp.declare_reduction @add_reduction_i32 : i32 init {
 ^bb0(%arg0: i32):
   %0 = llvm.mlir.constant(0 : i32) : i32



More information about the llvm-commits mailing list