[flang] [llvm] [mlir] [llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop (PR #139386)

via llvm-commits llvm-commits at lists.llvm.org
Sat May 10 08:28:02 PDT 2025


https://github.com/NimishMishra created https://github.com/llvm/llvm-project/pull/139386

This patch adds support for LLVM translation of the linear clause on omp.wsloop (linear modifiers are not yet supported).

>From e4f3cb2553f8ef03a3ad347cf14a187e31064153 Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Sat, 10 May 2025 19:34:16 +0530
Subject: [PATCH 1/2] [flang][OpenMP] Support MLIR lowering of linear clause
 for omp.wsloop

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    | 34 +++++++++++
 flang/lib/Lower/OpenMP/ClauseProcessor.h      |  1 +
 .../lib/Lower/OpenMP/DataSharingProcessor.cpp |  5 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  4 +-
 flang/test/Lower/OpenMP/wsloop-linear.f90     | 57 +++++++++++++++++++
 5 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/wsloop-linear.f90

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  // Record every variable named in a `linear` clause along with its step.
+  auto visitLinear = [&](const omp::clause::Linear &clause,
+                         const parser::CharBlock &) {
+    using StepModifier = omp::clause::Linear::StepComplexModifier;
+    using LinearModifier = omp::clause::Linear::LinearModifier;
+
+    const auto &objectList = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &obj : objectList)
+      result.linearVars.push_back(converter.getSymbolAddress(*obj.sym()));
+    const size_t numVars = objectList.size();
+    if (numVars == 0)
+      return;
+
+    if (const auto &step = std::get<std::optional<StepModifier>>(clause.t)) {
+      // An explicit step expression is shared by all variables of the clause.
+      mlir::Value stepVal =
+          fir::getBase(converter.genExprValue(toEvExpr(*step), stmtCtx));
+      result.linearStepVars.append(numVars, stepVal);
+    } else if (std::get<std::optional<LinearModifier>>(clause.t)) {
+      mlir::Location loc = converter.getCurrentLocation();
+      TODO(loc, "Linear modifiers not yet implemented");
+    } else {
+      // Neither a step nor a modifier was given: default to a step of 1.
+      fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+      mlir::Value one = firOpBuilder.createIntegerConstant(
+          converter.getCurrentLocation(), firOpBuilder.getI32Type(), 1);
+      result.linearStepVars.append(numVars, one);
+    }
+  };
+  return findRepeatableClause<omp::clause::Linear>(visitLinear);
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine

>From 616d6377bbb94bdc742023d71c63ba89df293e3c Mon Sep 17 00:00:00 2001
From: Nimish Mishra <neelam.nimish at gmail.com>
Date: Sat, 10 May 2025 20:51:39 +0530
Subject: [PATCH 2/2] [mlir][llvm][OpenMP] Support translation for linear
 clause in omp.wsloop

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  15 ++
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     |   3 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 185 +++++++++++++++++-
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |  88 +++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  13 --
 5 files changed, 289 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the LLVM value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = IterVar; }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// caller to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() const { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB = nullptr;
+  llvm::BasicBlock *linearExitBB = nullptr;
+  llvm::BasicBlock *linearLastIterExitBB = nullptr;
+
+public:
+  // Allocate space for linear variables. Values that do not translate to an
+  // alloca are skipped, so there may be fewer entries than clause variables.
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    assert(linearSteps.size() == linearPreconditionVars.size() &&
+           "every translated linear variable needs a matching step");
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars: start value + induction var * step.
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearPreconditionVars[index]);
+      // NOTE(review): assumes IV, step and variable share one integer type.
+      auto *mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto *addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars. `lastIter` is loaded as an i32 flag below;
+  // outlineLinearFinalizationBB must have run first (its blocks are used here).
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad =
+        builder.CreateLoad(builder.getInt32Ty(), lastIter);
+    llvm::Value *isLast = builder.CreateCmp(
+        llvm::CmpInst::ICMP_NE, loopLastIterLoad, builder.getInt32(0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  // with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, const std::string &BBName,
+                      size_t varIndex) {
+    // Snapshot the users first; replaceUsesOfWith edits the use list.
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr");
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
+                                                loopInfo->getLastIter());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
+      linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
+                                           index);
+    builder.restoreIP(oldIP);
+  }
+
   // Set the correct branch target for task cancellation
   popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
 
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 32f0ba5b105ff..9ad9e93301239 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -358,6 +358,94 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
 
 // -----
 
+// CHECK-LABEL: wsloop_linear
+
+// CHECK: {{.*}} = alloca i32, i64 1, align 4
+// CHECK: %[[Y:.*]] = alloca i32, i64 1, align 4
+// CHECK: %[[X:.*]] = alloca i32, i64 1, align 4
+
+// CHECK: entry:
+// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4
+// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4
+// CHECK: br label %omp_loop.preheader
+
+// CHECK: omp_loop.preheader:
+// CHECK: %[[LOAD:.*]] = load i32, ptr %[[X]], align 4
+// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4
+// CHECK: %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr @2)
+// CHECK: call void @__kmpc_barrier(ptr @1, i32 %omp_global_thread_num)
+
+// CHECK: omp_loop.body:
+// CHECK: %[[LOOP_IV:.*]] = add i32 %omp_loop.iv, {{.*}}
+// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4
+// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV]], 1
+// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_LOAD]], %[[MUL]]
+// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4
+// CHECK: br label %omp.loop_nest.region
+
+// CHECK: omp.loop_nest.region:
+// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4
+// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_LOAD]], 2
+// CHECK: store i32 %[[ADD]], ptr %[[Y]], align 4
+
+// CHECK: omp_loop.exit:
+// CHECK: call void @__kmpc_for_static_fini(ptr @2, i32 %omp_global_thread_num4)
+// CHECK: %omp_global_thread_num5 = call i32 @__kmpc_global_thread_num(ptr @2)
+// CHECK: call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num5)
+// CHECK: br label %omp_loop.linear_finalization
+
+// CHECK: omp_loop.linear_finalization:
+// CHECK: %[[LAST_ITER:.*]] = load i32, ptr %p.lastiter, align 4
+// CHECK: %[[CMP:.*]] = icmp ne i32 %[[LAST_ITER]], 0
+// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit
+
+// CHECK: omp_loop.linear_lastiter_exit:
+// CHECK: %[[LINEAR_RESULT_LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4
+// CHECK: store i32 %[[LINEAR_RESULT_LOAD]], ptr %[[X]], align 4
+// CHECK: br label %omp_loop.linear_exit
+
+// CHECK: omp_loop.linear_exit:
+// CHECK: %omp_global_thread_num6 = call i32 @__kmpc_global_thread_num(ptr @2)
+// CHECK: call void @__kmpc_barrier(ptr @1, i32 %omp_global_thread_num6)
+// CHECK: br label %omp_loop.after
+
+llvm.func @wsloop_linear() {
+  %0 = llvm.mlir.constant(1 : i64) : i64
+  %1 = llvm.alloca %0 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
+  %2 = llvm.mlir.constant(1 : i64) : i64
+  %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
+  %4 = llvm.mlir.constant(1 : i64) : i64
+  %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
+  %6 = llvm.mlir.constant(1 : i64) : i64
+  %7 = llvm.alloca %6 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
+  %8 = llvm.mlir.constant(2 : i32) : i32
+  %9 = llvm.mlir.constant(10 : i32) : i32
+  %10 = llvm.mlir.constant(1 : i32) : i32
+  %11 = llvm.mlir.constant(1 : i64) : i64
+  %12 = llvm.mlir.constant(1 : i64) : i64
+  %13 = llvm.mlir.constant(1 : i64) : i64
+  %14 = llvm.mlir.constant(1 : i64) : i64
+  omp.wsloop linear(%5 = %10 : !llvm.ptr) {
+    omp.loop_nest (%arg0) : i32 = (%10) to (%9) inclusive step (%10) {
+      llvm.store %arg0, %1 : i32, !llvm.ptr
+      %15 = llvm.load %5 : !llvm.ptr -> i32
+      %16 = llvm.add %15, %8 : i32
+      llvm.store %16, %3 : i32, !llvm.ptr
+      %17 = llvm.add %arg0, %10 : i32
+      %18 = llvm.icmp "sgt" %17, %9 : i32
+      llvm.cond_br %18, ^bb1, ^bb2
+    ^bb1:  // pred: ^bb0
+      llvm.store %17, %1 : i32, !llvm.ptr
+      llvm.br ^bb2
+    ^bb2:  // 2 preds: ^bb0, ^bb1
+      omp.yield
+     }
+   }
+  llvm.return
+}
+
+// -----
+
 // CHECK-LABEL: @wsloop_inclusive_1
 llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
   %0 = llvm.mlir.constant(42 : index) : i64
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 9a83b46efddca..98fccb1a80f67 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -511,19 +511,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // -----
 
-llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
-  // expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
-  omp.wsloop linear(%x = %step : !llvm.ptr) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}



More information about the llvm-commits mailing list