[Mlir-commits] [llvm] [mlir] scan lowering changes (PR #133149)

Anchu Rajendran S llvmlistbot at llvm.org
Wed Mar 26 13:10:32 PDT 2025


https://github.com/anchuraj created https://github.com/llvm/llvm-project/pull/133149

None

>From af71598bf25890f111dc47676edc381f5389421e Mon Sep 17 00:00:00 2001
From: Anchu Rajendran <asudhaku at amd.com>
Date: Tue, 25 Mar 2025 11:21:37 -0500
Subject: [PATCH] scan lowering changes

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  58 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 338 +++++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 241 ++++++++++---
 3 files changed, 575 insertions(+), 62 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 28909cef4748d..3e5916e18a444 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -30,6 +30,7 @@
 namespace llvm {
 class CanonicalLoopInfo;
 struct TargetRegionEntryInfo;
+class ScanInfo;
 class OffloadEntriesInfoManager;
 class OpenMPIRBuilder;
 
@@ -728,6 +729,11 @@ class OpenMPIRBuilder {
                       LoopBodyGenCallbackTy BodyGenCB, Value *TripCount,
                       const Twine &Name = "loop");
 
+  Expected<SmallVector<llvm::CanonicalLoopInfo *>> createCanonicalScanLoops(
+      const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+      Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+      InsertPointTy ComputeIP, const Twine &Name, bool InScan);
+
   /// Calculate the trip count of a canonical loop.
   ///
   /// This allows specifying user-defined loop counter values using increment,
@@ -800,10 +806,12 @@ class OpenMPIRBuilder {
   ///
   /// \returns An object representing the created control flow structure which
   ///          can be used for loop-associated directives.
-  Expected<CanonicalLoopInfo *> createCanonicalLoop(
-      const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
-      Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-      InsertPointTy ComputeIP = {}, const Twine &Name = "loop");
+  Expected<CanonicalLoopInfo *>
+  createCanonicalLoop(const LocationDescription &Loc,
+                      LoopBodyGenCallbackTy BodyGenCB, Value *Start,
+                      Value *Stop, Value *Step, bool IsSigned,
+                      bool InclusiveStop, InsertPointTy ComputeIP = {},
+                      const Twine &Name = "loop", bool InScan = false);
 
   /// Collapse a loop nest into a single loop.
   ///
@@ -1530,6 +1538,16 @@ class OpenMPIRBuilder {
       ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
       Function *ReduceFn, AttributeList FuncAttrs);
 
+  llvm::CallInst *emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
+                                          ArrayRef<llvm::Value *> args,
+                                          const llvm::Twine &name);
+
+  void createScanBBs();
+  void emitScanBasedDirectiveDeclsIR(llvm::Value *span,
+                                     ArrayRef<llvm::Value *> ScanVars);
+  void emitScanBasedDirectiveFinalsIR(
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
   /// This function emits a helper that gathers Reduce lists from the first
   /// lane of every active warp to lanes in the first warp.
   ///
@@ -1981,7 +1999,7 @@ class OpenMPIRBuilder {
   ///                           in reductions.
   /// \param ReductionInfos     A list of info on each reduction variable.
   /// \param IsNoWait           A flag set if the reduction is marked as nowait.
-  /// \param IsByRef            A flag set if the reduction is using reference
+  /// \param IsByRef            A flag set if the reduction is using reference
   /// or direct value.
   InsertPointOrErrorTy createReductions(const LocationDescription &Loc,
                                         InsertPointTy AllocaIP,
@@ -2177,7 +2195,6 @@ class OpenMPIRBuilder {
   // block, if possible, or else at the end of the function. Also add a branch
   // from current block to BB if current block does not have a terminator.
   void emitBlock(BasicBlock *BB, Function *CurFn, bool IsFinished = false);
-
   /// Emits code for OpenMP 'if' clause using specified \a BodyGenCallbackTy
   /// Here is the logic:
   /// if (Cond) {
@@ -2602,7 +2619,18 @@ class OpenMPIRBuilder {
   InsertPointOrErrorTy createMasked(const LocationDescription &Loc,
                                     BodyGenCallbackTy BodyGenCB,
                                     FinalizeCallbackTy FiniCB, Value *Filter);
-
+  InsertPointOrErrorTy emitScanReduction(
+      const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos);
+
+  Expected<SmallVector<llvm::CanonicalLoopInfo *>> emitScanBasedDirectiveIR(
+      llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+      llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+          SecondGen);
+  InsertPointOrErrorTy createScan(const LocationDescription &Loc,
+                                  InsertPointTy AllocaIP,
+                                  ArrayRef<llvm::Value *> ScanVars,
+                                  const Twine &Name, bool IsInclusive);
   /// Generator for '#omp critical'
   ///
   /// \param Loc The insert and source location description.
@@ -3711,6 +3739,22 @@ class CanonicalLoopInfo {
   void invalidate();
 };
 
+class ScanInfo { // Bookkeeping shared by the OpenMP scan lowering helpers.
+public:
+  llvm::BasicBlock *OMPBeforeScanBlock = nullptr;
+  llvm::BasicBlock *OMPAfterScanBlock = nullptr;
+  llvm::BasicBlock *OMPScanExitBlock = nullptr;
+  llvm::BasicBlock *OMPScanDispatch = nullptr;
+  llvm::BasicBlock *OMPScanLoopExit = nullptr;
+  bool OMPFirstScanLoop = false; // True while emitting the input-phase loop.
+  llvm::SmallDenseMap<llvm::Value *, llvm::Value *> ReductionVarToScanBuffs;
+  SmallVector<llvm::Value *> privateReductionVariables;
+  SmallVector<llvm::Value *> originalReductionVariables;
+  llvm::Value *iv = nullptr;   // Loop IV used to index the scan buffers.
+  llvm::Value *span = nullptr; // Distance between the scan loop bounds.
+  SmallVector<llvm::BasicBlock *> continueBlocks;
+};
+
 } // end namespace llvm
 
 #endif // LLVM_FRONTEND_OPENMP_OMPIRBUILDER_H
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2e5ce5308eea5..c74727585d4d9 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -67,6 +67,8 @@
 using namespace llvm;
 using namespace omp;
 
+static ScanInfo scanInfo; // NOTE(review): one global for all builders.
+
 static cl::opt<bool>
     OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                          cl::desc("Use optimistic attributes describing "
@@ -80,6 +82,7 @@ static cl::opt<double> UnrollThresholdFactor(
     cl::init(1.5));
 
 #ifndef NDEBUG
+
 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
 /// an InsertPoint stores the instruction before something is inserted. For
@@ -3918,6 +3921,272 @@ OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
                               /*Conditional*/ true, /*hasFinalize*/ true);
 }
 
+llvm::CallInst *
+OpenMPIRBuilder::emitNoUnwindRuntimeCall(llvm::FunctionCallee callee,
+                                         ArrayRef<llvm::Value *> args,
+                                         const llvm::Twine &name) {
+  ArrayRef<llvm::OperandBundleDef> NoBundles; // Explicitly empty bundle list.
+  llvm::CallInst *Call = Builder.CreateCall(callee, args, NoBundles, name);
+  Call->setDoesNotThrow();
+  return Call;
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createScan(
+    const LocationDescription &Loc, InsertPointTy AllocaIP,
+    ArrayRef<llvm::Value *> ScanVars, const Twine &Name, bool IsInclusive) {
+  if (scanInfo.OMPFirstScanLoop) {
+    Builder.restoreIP(AllocaIP);
+    emitScanBasedDirectiveDeclsIR(scanInfo.span, ScanVars);
+  }
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  llvm::Value *iv = scanInfo.iv;
+
+  if (scanInfo.OMPFirstScanLoop) {
+    // Emit buffer[i] = red; at the end of the input phase.
+    for (llvm::Value *ScanVar : ScanVars) {
+      llvm::Value *buff = scanInfo.ReductionVarToScanBuffs.lookup(ScanVar);
+      assert(buff && "scan variable has no scan buffer");
+      auto destTy = Builder.getInt32Ty(); // TODO: use the element type.
+      auto val = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
+      auto src = Builder.CreateLoad(destTy, ScanVar);
+      auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          val, destTy->getPointerTo(defaultAS));
+
+      Builder.CreateStore(src, dest);
+    }
+  }
+  Builder.CreateBr(scanInfo.OMPScanLoopExit);
+  // Continue in the dispatch block, which picks the phase to run.
+  Builder.SetInsertPoint(scanInfo.OMPScanDispatch);
+
+  ConstantInt *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
+  for (llvm::Value *ScanVar : ScanVars) {
+    auto destTy = Builder.getInt32Ty(); // TODO: use the element type.
+    auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        ScanVar, destTy->getPointerTo(defaultAS));
+    Builder.CreateStore(Zero, dest);
+  }
+
+  if (!scanInfo.OMPFirstScanLoop) {
+    // Second loop: read back the scanned value for this iteration.
+    // Emit red = buffer[i]; at the entrance to the scan phase.
+    for (llvm::Value *ScanVar : ScanVars) {
+      llvm::Value *buff = scanInfo.ReductionVarToScanBuffs.lookup(ScanVar);
+      assert(buff && "scan variable has no scan buffer");
+      auto destTy = Builder.getInt32Ty(); // TODO: use the element type.
+      auto newVPtr = Builder.CreateInBoundsGEP(destTy, buff, iv, "arrayOffset");
+      auto newV = Builder.CreateLoad(destTy, newVPtr);
+      auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+          ScanVar, destTy->getPointerTo(defaultAS));
+
+      Builder.CreateStore(newV, dest);
+    }
+  }
+  // TODO: The phase selector is a placeholder (formerly the constant
+  // comparison 100 >= 0), so the dispatch block always enters its first
+  // successor: the before-scan code when OMPFirstScanLoop == IsInclusive,
+  // and the after-scan code otherwise.
+  llvm::Value *CmpI = Builder.getTrue();
+  if (scanInfo.OMPFirstScanLoop == IsInclusive) {
+    Builder.CreateCondBr(CmpI, scanInfo.OMPBeforeScanBlock,
+                         scanInfo.OMPAfterScanBlock);
+  } else {
+    Builder.CreateCondBr(CmpI, scanInfo.OMPAfterScanBlock,
+                         scanInfo.OMPBeforeScanBlock);
+  }
+  emitBlock(scanInfo.OMPAfterScanBlock, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(scanInfo.OMPAfterScanBlock);
+  return Builder.saveIP();
+}
+
+void OpenMPIRBuilder::emitScanBasedDirectiveDeclsIR(
+    llvm::Value *span, ArrayRef<llvm::Value *> ScanVars) {
+  // Allocate one scan buffer per scan variable at the current insert point.
+  ConstantInt *Two = ConstantInt::get(span->getType(), 2);
+  llvm::Value *allocSpan = Builder.CreateAdd(span, Two);
+  for (llvm::Value *scanVar : ScanVars) {
+    // emitScanReduction and emitScanBasedDirectiveFinalsIR both access
+    // element span + 1, so each buffer needs span + 2 elements.
+    // TODO: use the reduction element type instead of i32.
+    llvm::Value *buff =
+        Builder.CreateAlloca(Builder.getInt32Ty(), allocSpan, "vla");
+    scanInfo.ReductionVarToScanBuffs[scanVar] = buff;
+  }
+}
+
+void OpenMPIRBuilder::emitScanBasedDirectiveFinalsIR(
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  // Copy element span + 1 of each scan buffer (the last element updated by
+  // emitScanReduction) into the original reduction variable. This is emitted
+  // at the FinalizeIP handed to emitScanReduction.
+  llvm::Value *span = scanInfo.span;
+  llvm::Value *OMPLast = Builder.CreateNSWAdd(
+      span, llvm::ConstantInt::get(span->getType(), 1, /*isSigned=*/false));
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+
+  for (const llvm::OpenMPIRBuilder::ReductionInfo &RI : reductionInfos) {
+    llvm::Value *buff =
+        scanInfo.ReductionVarToScanBuffs.lookup(RI.PrivateVariable);
+    assert(buff && "reduction variable has no scan buffer");
+    // NOTE(review): the buffers are allocated with i32 elements in
+    // emitScanBasedDirectiveDeclsIR; this assumes ElementType matches.
+    llvm::Type *srcTy = RI.ElementType;
+    auto val = Builder.CreateInBoundsGEP(srcTy, buff, OMPLast, "arrayOffset");
+    auto src = Builder.CreateLoad(srcTy, val);
+    auto dest = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        RI.Variable, srcTy->getPointerTo(defaultAS));
+
+    Builder.CreateStore(src, dest);
+  }
+}
+
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::emitScanReduction(
+    const LocationDescription &Loc, InsertPointTy &FinalizeIP,
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos) {
+  // In-place log-step scan over the buffers: buf[i] op= buf[i - 2^k].
+  llvm::Value *spanDiff = scanInfo.span;
+
+  if (!updateToLocation(Loc))
+    return Loc.IP;
+  auto curFn = Builder.GetInsertBlock()->getParent();
+  // for (int k = 0; k < ceil(log2(n)); ++k), executed at least once.
+  llvm::BasicBlock *LoopBB =
+      BasicBlock::Create(curFn->getContext(), "omp.outer.log.scan.body");
+  llvm::BasicBlock *ExitBB =
+      BasicBlock::Create(curFn->getContext(), "omp.outer.log.scan.exit");
+  llvm::Function *F = llvm::Intrinsic::getOrInsertDeclaration(
+      Builder.GetInsertBlock()->getModule(),
+      (llvm::Intrinsic::ID)llvm::Intrinsic::log2, Builder.getDoubleTy());
+  llvm::BasicBlock *InputBB = Builder.GetInsertBlock();
+  ConstantInt *One = ConstantInt::get(spanDiff->getType(), 1);
+  llvm::Value *span = Builder.CreateAdd(spanDiff, One);
+  llvm::Value *Arg = Builder.CreateUIToFP(span, Builder.getDoubleTy());
+  llvm::Value *LogVal = emitNoUnwindRuntimeCall(F, Arg, "");
+  F = llvm::Intrinsic::getOrInsertDeclaration(
+      Builder.GetInsertBlock()->getModule(),
+      (llvm::Intrinsic::ID)llvm::Intrinsic::ceil, Builder.getDoubleTy());
+  LogVal = emitNoUnwindRuntimeCall(F, LogVal, "");
+  LogVal = Builder.CreateFPToUI(LogVal, Builder.getInt32Ty());
+  llvm::Value *NMin1 =
+      Builder.CreateNUWSub(span, llvm::ConstantInt::get(span->getType(), 1));
+  Builder.SetInsertPoint(InputBB);
+  Builder.CreateBr(LoopBB);
+  emitBlock(LoopBB, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(LoopBB);
+
+  auto *Counter = Builder.CreatePHI(Builder.getInt32Ty(), 2);
+  // pow2k = 1; kept in the integer type of span so it compares with NMin1.
+  auto *Pow2K = Builder.CreatePHI(span->getType(), 2);
+  Counter->addIncoming(llvm::ConstantInt::get(Builder.getInt32Ty(), 0),
+                       InputBB);
+  Pow2K->addIncoming(llvm::ConstantInt::get(span->getType(), 1), InputBB);
+  // for (i = n - 1; i >= pow2k; --i)
+  //   buf[i + 1] op= buf[i + 1 - pow2k];
+  llvm::BasicBlock *InnerLoopBB =
+      BasicBlock::Create(curFn->getContext(), "omp.inner.log.scan.body");
+  llvm::BasicBlock *InnerExitBB =
+      BasicBlock::Create(curFn->getContext(), "omp.inner.log.scan.exit");
+  llvm::Value *CmpI = Builder.CreateICmpUGE(NMin1, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  emitBlock(InnerLoopBB, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(InnerLoopBB);
+  auto *IVal = Builder.CreatePHI(span->getType(), 2);
+  IVal->addIncoming(NMin1, LoopBB);
+  unsigned int defaultAS = M.getDataLayout().getProgramAddressSpace();
+  for (size_t i = 0; i < reductionInfos.size(); i++) {
+    auto &reductionVal = reductionInfos[i].PrivateVariable;
+    auto buff = scanInfo.ReductionVarToScanBuffs[reductionVal];
+    auto destTy = reductionInfos[i].ElementType;
+    Value *IV = Builder.CreateAdd(IVal, ConstantInt::get(IVal->getType(), 1));
+    auto lhsPtr = Builder.CreateInBoundsGEP(destTy, buff, IV, "arrayOffset");
+    auto offsetIval = Builder.CreateNUWSub(IV, Pow2K);
+    auto rhsPtr =
+        Builder.CreateInBoundsGEP(destTy, buff, offsetIval, "arrayOffset");
+    auto lhs = Builder.CreateLoad(destTy, lhsPtr);
+    auto rhs = Builder.CreateLoad(destTy, rhsPtr);
+    auto lhsAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
+        lhsPtr, rhs->getType()->getPointerTo(defaultAS));
+    llvm::Value *result;
+    InsertPointOrErrorTy AfterIP =
+        reductionInfos[i].ReductionGen(Builder.saveIP(), lhs, rhs, result);
+    if (!AfterIP)
+      return AfterIP.takeError();
+    Builder.CreateStore(result, lhsAddr);
+  }
+  llvm::Value *NextIVal = Builder.CreateNUWSub(
+      IVal, llvm::ConstantInt::get(IVal->getType(), 1));
+  IVal->addIncoming(NextIVal, Builder.GetInsertBlock());
+  CmpI = Builder.CreateICmpUGE(NextIVal, Pow2K);
+  Builder.CreateCondBr(CmpI, InnerLoopBB, InnerExitBB);
+  emitBlock(InnerExitBB, Builder.GetInsertBlock()->getParent());
+  llvm::Value *Next = Builder.CreateNUWAdd(
+      Counter, llvm::ConstantInt::get(Counter->getType(), 1));
+  Counter->addIncoming(Next, Builder.GetInsertBlock());
+  // pow2k <<= 1; repeat while ++k < LogVal (ULT, not NE, so n == 1 ends).
+  llvm::Value *NextPow2K = Builder.CreateShl(Pow2K, 1, "", /*HasNUW=*/true);
+  Pow2K->addIncoming(NextPow2K, Builder.GetInsertBlock());
+  llvm::Value *Cmp = Builder.CreateICmpULT(Next, LogVal);
+  Builder.CreateCondBr(Cmp, LoopBB, ExitBB);
+  emitBlock(ExitBB, Builder.GetInsertBlock()->getParent());
+  Builder.SetInsertPoint(ExitBB);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      createBarrier(Builder.saveIP(), llvm::omp::OMPD_barrier);
+  if (!afterIP)
+    return afterIP;
+  Builder.restoreIP(FinalizeIP);
+  emitScanBasedDirectiveFinalsIR(reductionInfos);
+  FinalizeIP = Builder.saveIP();
+  return afterIP;
+}
+
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::emitScanBasedDirectiveIR(
+    llvm::function_ref<Expected<CanonicalLoopInfo *>()> FirstGen,
+    llvm::function_ref<Expected<CanonicalLoopInfo *>(LocationDescription loc)>
+        SecondGen) {
+  // Emit the input-phase loop, then the scan-phase loop right after it.
+  SmallVector<llvm::CanonicalLoopInfo *> ret;
+  {
+    // Emit loop with input phase:
+    // #pragma omp ...
+    // for (i: 0..<num_iters>) {
+    //   <input phase>;
+    //   buffer[i] = red;
+    // }
+    scanInfo.OMPFirstScanLoop = true;
+    auto result = FirstGen();
+    if (!result)
+      return result.takeError();
+    Builder.restoreIP((*result)->getAfterIP());
+    ret.push_back(*result);
+  }
+  {
+    scanInfo.OMPFirstScanLoop = false;
+    auto result = SecondGen(Builder.saveIP());
+    if (!result)
+      return result.takeError();
+    Builder.restoreIP((*result)->getAfterIP());
+    ret.push_back(*result);
+  }
+  return ret;
+}
+
+void OpenMPIRBuilder::createScanBBs() {
+  // Create the detached basic blocks that structure one scan loop body. They
+  // are placed into the function later (see createCanonicalScanLoops and
+  // createScan). NOTE(review): OMPScanExitBlock is never inserted in the code
+  // visible here; if it stays unused it is leaked on every call.
+  LLVMContext &Ctx = Builder.GetInsertBlock()->getParent()->getContext();
+  scanInfo.OMPScanExitBlock = BasicBlock::Create(Ctx, "omp.exit.inscan.bb");
+  scanInfo.OMPScanDispatch = BasicBlock::Create(Ctx, "omp.inscan.dispatch");
+  scanInfo.OMPAfterScanBlock = BasicBlock::Create(Ctx, "omp.after.scan.bb");
+  scanInfo.OMPBeforeScanBlock =
+      BasicBlock::Create(Ctx, "omp.before.scan.bb");
+  scanInfo.OMPScanLoopExit = BasicBlock::Create(Ctx, "omp.scan.loop.exit");
+}
+
 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
     DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
     BasicBlock *PostInsertBefore, const Twine &Name) {
@@ -4015,10 +4284,72 @@ OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
   return CL;
 }
 
+Expected<SmallVector<llvm::CanonicalLoopInfo *>>
+OpenMPIRBuilder::createCanonicalScanLoops(
+    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
+    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
+    InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
+  auto *IndVarTy = cast<IntegerType>(Start->getType());
+  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
+  assert(IndVarTy == Step->getType() && "Step type mismatch");
+  LocationDescription ComputeLoc =
+      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
+  updateToLocation(ComputeLoc);
+  // Lowers an inscan loop as two canonical loops (emitScanBasedDirectiveIR).
+  ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
+  // Span is published via scanInfo.span to size the scan buffers.
+  // Distance between Start and Stop (UB - LB when IsSigned).
+  Value *Span;
+  // NOTE(review): createScan uses Span at AllocaIP; confirm it dominates it.
+  // NOTE(review): no zero-trip-count check is emitted here; confirm an empty
+  // range (Stop < Start) is handled before the scan buffers are sized.
+  // NOTE(review): FirstGen/SecondGen below always pass InScan = true.
+  if (IsSigned) {
+    // Ensure that increment is positive. If not, negate and invert LB and UB.
+    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
+    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
+    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
+    Span = Builder.CreateSub(UB, LB, "", false, true);
+  } else {
+    Span = Builder.CreateSub(Stop, Start, "", true);
+  }
+  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
+    if (InScan) {
+      scanInfo.iv = IV;
+      createScanBBs();
+      auto terminator = Builder.GetInsertBlock()->getTerminator();
+      scanInfo.continueBlocks.push_back(terminator->getSuccessor(0));
+      terminator->setSuccessor(0, scanInfo.OMPScanDispatch);
+      emitBlock(scanInfo.OMPBeforeScanBlock,
+                Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(scanInfo.OMPScanLoopExit);
+      // The scan loop exit forwards to the body's original successor.
+      emitBlock(scanInfo.OMPScanLoopExit,
+                Builder.GetInsertBlock()->getParent());
+      Builder.CreateBr(scanInfo.continueBlocks.back());
+      emitBlock(scanInfo.OMPScanDispatch,
+                Builder.GetInsertBlock()->getParent());
+      Builder.SetInsertPoint(
+          scanInfo.OMPBeforeScanBlock->getFirstInsertionPt());
+    }
+    return BodyGenCB(Builder.saveIP(), IV);
+  };
+  const auto &&FirstGen = [&]() {
+    return createCanonicalLoop(Loc, BodyGen, Start, Stop, Step, IsSigned,
+                               InclusiveStop, ComputeIP, Name, true);
+  };
+  const auto &&SecondGen = [&](LocationDescription loc) {
+    return createCanonicalLoop(loc, BodyGen, Start, Stop, Step, IsSigned,
+                               InclusiveStop, ComputeIP, Name, true);
+  };
+  scanInfo.span = Span;
+  auto result = emitScanBasedDirectiveIR(FirstGen, SecondGen);
+  return result;
+}
+
 Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
     const LocationDescription &Loc, Value *Start, Value *Stop, Value *Step,
     bool IsSigned, bool InclusiveStop, const Twine &Name) {
-
   // Consider the following difficulties (assuming 8-bit signed integers):
   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
   //      DO I = 1, 100, 50
@@ -4078,7 +4409,7 @@ Value *OpenMPIRBuilder::calculateCanonicalLoopTripCount(
 Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
-    InsertPointTy ComputeIP, const Twine &Name) {
+    InsertPointTy ComputeIP, const Twine &Name, bool InScan) {
   LocationDescription ComputeLoc =
       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
 
@@ -4089,6 +4420,9 @@ Expected<CanonicalLoopInfo *> OpenMPIRBuilder::createCanonicalLoop(
     Builder.restoreIP(CodeGenIP);
     Value *Span = Builder.CreateMul(IV, Step);
     Value *IndVar = Builder.CreateAdd(Span, Start);
+    if (InScan) {
+      scanInfo.iv = IndVar;
+    }
     return BodyGenCB(Builder.saveIP(), IndVar);
   };
   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..d5aabf8b75ae1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -86,7 +86,9 @@ class OpenMPLoopInfoStackFrame
     : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  // For constructs like scan, one Loop info frame can contain multiple
+  // Canonical Loops
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
 };
 
 /// Custom error class to signal translation errors that don't need reporting,
@@ -232,8 +234,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
           op.getReductionSyms())
         result = todo("reduction");
     if (op.getReductionMod() &&
-        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
-      result = todo("reduction with modifier");
+        op.getReductionMod().value() == omp::ReductionModifier::task)
+      result = todo("reduction with task modifier");
   };
   auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -383,15 +385,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
 /// Find the loop information structure for the loop nest being translated. It
 /// will return a `null` value unless called from the translation function for
 /// a loop wrapper operation after successfully translating its body.
-static llvm::CanonicalLoopInfo *
-findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
-  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+static SmallVector<llvm::CanonicalLoopInfo *>
+findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos; // Empty if none found.
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        loopInfo = frame.loopInfo;
+        loopInfos = frame.loopInfos; // Copy from the first frame visited.
         return WalkResult::interrupt();
       });
-  return loopInfo;
+  return loopInfos;
 }
 
 /// Converts the given region that appears within an OpenMP dialect operation to
@@ -2258,26 +2260,61 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-
-  llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-      ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-          convertToScheduleKind(schedule), chunk, isSimd,
-          scheduleMod == omp::ScheduleModifier::monotonic,
-          scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType);
-
-  if (failed(handleError(wsloopIP, opInst)))
-    return failure();
-
-  // Process the reductions if required.
-  if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
-                                        allocaIP, reductionDecls,
-                                        privateReductionVariables, isByRef)))
-    return failure();
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  auto beforeLoopFinishIp = loopInfos.front()->getAfterIP();
+  auto afterLoopFinishIp = loopInfos.back()->getAfterIP();
+  bool isInScanRegion =
+      wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                     mlir::omp::ReductionModifier::inscan);
+  if (isInScanRegion) {
+    builder.restoreIP(beforeLoopFinishIp);
+    SmallVector<OwningReductionGen> owningReductionGens;
+    SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
+    SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
+    collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
+                         owningReductionGens, owningAtomicReductionGens,
+                         privateReductionVariables, reductionInfos);
+    llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
+        ompBuilder->emitScanReduction(builder.saveIP(), afterLoopFinishIp,
+                                      reductionInfos);
+    if (failed(handleError(redIP, opInst)))
+      return failure();
 
+    builder.restoreIP(*redIP);
+    builder.CreateBr(cont);
+  }
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+        ompBuilder->applyWorkshareLoop(
+            ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+            convertToScheduleKind(schedule), chunk, isSimd,
+            scheduleMod == omp::ScheduleModifier::monotonic,
+            scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+            workshareLoopType);
+
+    if (failed(handleError(wsloopIP, opInst)))
+      return failure();
+  }
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+  if (isInScanRegion) {
+    SmallVector<Region *> reductionRegions;
+    llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
+                    [](omp::DeclareReductionOp reductionDecl) {
+                      return &reductionDecl.getCleanupRegion();
+                    });
+    if (failed(inlineOmpRegionCleanup(
+            reductionRegions, privateReductionVariables, moduleTranslation,
+            builder, "omp.reduction.cleanup")))
+      return failure();
+  } else {
+    // Process the reductions if required.
+    if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
+                                          allocaIP, reductionDecls,
+                                          privateReductionVariables, isByRef)))
+      return failure();
+  }
   return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
                             privateVarsInfo.llvmVars,
                             privateVarsInfo.privatizers);
@@ -2467,6 +2504,59 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
   llvm_unreachable("Unknown ClauseOrderKind kind");
 }
 
+static LogicalResult
+convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
+               LLVM::ModuleTranslation &moduleTranslation) {
+  if (failed(checkImplementationStatus(opInst)))
+    return failure();
+  auto scanOp = cast<omp::ScanOp>(opInst);
+  bool isInclusive = scanOp.hasInclusiveVars();
+  SmallVector<llvm::Value *> llvmScanVars;
+  mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars();
+  if (!isInclusive)
+    mlirScanVars = scanOp.getExclusiveVars();
+  for (auto val : mlirScanVars) {
+    llvm::Value *llvmVal = moduleTranslation.lookupValue(val);
+    llvmScanVars.push_back(llvmVal);
+  }
+  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
+      findAllocaInsertPoint(builder, moduleTranslation);
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+      moduleTranslation.getOpenMPBuilder()->createScan(
+          ompLoc, allocaIP, llvmScanVars, "scan", isInclusive);
+  if (failed(handleError(afterIP, opInst)))
+    return failure();
+
+  builder.restoreIP(*afterIP);
+
+  // TODO: The argument of LoopNestOp is stored into the index variable and
+  // this variable is used across the scan operation. That makes the MLIR
+  // invalid: "Intra-iteration dependences from a statement in the structured
+  // block sequence that precede a scan directive to a statement in the
+  // structured block sequence that follows a scan directive must not exist,
+  // except for dependences for the list items specified in an inclusive or
+  // exclusive clause." As a workaround, the LoopNestOp argument is stored
+  // into the index variable again after the scan so the generated code is
+  // valid.
+  Operation *parentOp = scanOp->getParentOp();
+  auto loopOp = dyn_cast<omp::LoopNestOp>(parentOp);
+  if (loopOp) {
+    Block &firstBlock = scanOp->getParentRegion()->front();
+    Operation &ins = firstBlock.front();
+    if (auto storeOp = dyn_cast<LLVM::StoreOp>(ins)) {
+      // Only re-store when the stored value is the loop's first IV.
+      auto src = moduleTranslation.lookupValue(storeOp->getOperand(0));
+      if (src == moduleTranslation.lookupValue(
+                     (loopOp.getRegion().getArguments())[0])) {
+        auto dest = moduleTranslation.lookupValue(storeOp->getOperand(1));
+        builder.CreateStore(src, dest);
+      }
+    }
+  }
+  return success();
+}
+
 /// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2540,12 +2630,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     return failure();
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
-  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
-  ompBuilder->applySimd(loopInfo, alignedVars,
-                        simdOp.getIfExpr()
-                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                            : nullptr,
-                        order, simdlen, safelen);
+  SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+      findCurrentLoopInfos(moduleTranslation);
+  for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+    ompBuilder->applySimd(
+        loopInfo, alignedVars,
+        simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                           : nullptr,
+        order, simdlen, safelen);
+  }
 
   return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
                             privateVarsInfo.llvmVars,
@@ -2612,16 +2705,52 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                                                        ompLoc.DL);
       computeIP = loopInfos.front()->getPreheaderIP();
     }
+    if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>()) {
+      bool isInScanRegion =
+          wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
+                                         mlir::omp::ReductionModifier::inscan);
+      // TODOSCAN: Take care of loop and add asserts if required
+      if (isInScanRegion) {
+        llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
+            ompBuilder->createCanonicalScanLoops(
+                loc, bodyGen, lowerBound, upperBound, step,
+                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP, "loop",
+                isInScanRegion);
+
+        if (failed(handleError(loopResults, *loopOp)))
+          return failure();
+        auto beforeLoop = loopResults->front();
+        auto afterLoop = loopResults->back();
+        moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+            [&](OpenMPLoopInfoStackFrame &frame) {
+              frame.loopInfos.push_back(beforeLoop);
+              frame.loopInfos.push_back(afterLoop);
+              return WalkResult::interrupt();
+            });
+        builder.restoreIP(afterLoop->getAfterIP());
+        return success();
+      } else {
+        llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+            ompBuilder->createCanonicalLoop(
+                loc, bodyGen, lowerBound, upperBound, step,
+                /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
 
-    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
-        ompBuilder->createCanonicalLoop(
-            loc, bodyGen, lowerBound, upperBound, step,
-            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+        if (failed(handleError(loopResult, *loopOp)))
+          return failure();
 
-    if (failed(handleError(loopResult, *loopOp)))
-      return failure();
+        loopInfos.push_back(*loopResult);
+      }
+    } else {
+      llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
+          ompBuilder->createCanonicalLoop(
+              loc, bodyGen, lowerBound, upperBound, step,
+              /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
+
+      if (failed(handleError(loopResult, *loopOp)))
+        return failure();
 
-    loopInfos.push_back(*loopResult);
+      loopInfos.push_back(*loopResult);
+    }
   }
 
   // Collapse loops. Store the insertion point because LoopInfos may get
@@ -2633,7 +2762,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
   // after applying transformations.
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
-        frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+        frame.loopInfos.push_back(
+            ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}));
         return WalkResult::interrupt();
       });
 
@@ -4212,18 +4342,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
       bool loopNeedsBarrier = false;
       llvm::Value *chunk = nullptr;
 
-      llvm::CanonicalLoopInfo *loopInfo =
-          findCurrentLoopInfo(moduleTranslation);
-      llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
-          ompBuilder->applyWorkshareLoop(
-              ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
-              convertToScheduleKind(schedule), chunk, isSimd,
-              scheduleMod == omp::ScheduleModifier::monotonic,
-              scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType);
-
-      if (!wsloopIP)
-        return wsloopIP.takeError();
+      SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
+          findCurrentLoopInfos(moduleTranslation);
+      for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
+        llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
+            ompBuilder->applyWorkshareLoop(
+                ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
+                convertToScheduleKind(schedule), chunk, isSimd,
+                scheduleMod == omp::ScheduleModifier::monotonic,
+                scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
+                workshareLoopType);
+
+        if (!wsloopIP)
+          return wsloopIP.takeError();
+      }
     }
 
     if (failed(cleanupPrivateVars(builder, moduleTranslation,
@@ -5202,6 +5334,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
           .Case([&](omp::WsloopOp) {
             return convertOmpWsloop(*op, builder, moduleTranslation);
           })
+          .Case([&](omp::ScanOp) {
+            return convertOmpScan(*op, builder, moduleTranslation);
+          })
           .Case([&](omp::SimdOp) {
             return convertOmpSimd(*op, builder, moduleTranslation);
           })



More information about the Mlir-commits mailing list