[flang] [llvm] [mlir] [llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop (PR #139386)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 10 08:28:31 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (NimishMishra)

<details>
<summary>Changes</summary>

This patch adds support for LLVM translation of the linear clause on omp.wsloop (except for linear modifiers).

---

Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff


10 Files Affected:

- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1) 
- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2) 
- (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2) 
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88) 
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Holds the LLVM IR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// caller to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variables
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/139386


More information about the llvm-commits mailing list