[Mlir-commits] [mlir] c83bdc7 - [MLIR][OpenMP] Normalize lowering of omp.loop_nest (#127217)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 24 05:38:31 PST 2025


Author: Sergio Afonso
Date: 2025-02-24T13:38:27Z
New Revision: c83bdc7c111f8100d8fdf9e7a87d05b5a1a3ae94

URL: https://github.com/llvm/llvm-project/commit/c83bdc7c111f8100d8fdf9e7a87d05b5a1a3ae94
DIFF: https://github.com/llvm/llvm-project/commit/c83bdc7c111f8100d8fdf9e7a87d05b5a1a3ae94.diff

LOG: [MLIR][OpenMP] Normalize lowering of omp.loop_nest (#127217)

This patch refactors the translation of `omp.loop_nest` operations into
LLVM IR so that they are handled similarly to other operations. Before this
change, the responsibility of translating the loop nest fell to each
loop wrapper, causing code duplication. This patch centralizes the
handling of the loop. As a side effect, it fixes the lowering of
non-inclusive `omp.simd` loops, which were previously always treated as
inclusive.

As a result, the handling of composite constructs is now expected to be
performed collaboratively by the translation functions of each operation
involved. At the moment, only `do/for simd` is supported, by ignoring the
SIMD information, and this behavior is preserved.

The translation of loop wrapper operations needs access to the
`llvm::CanonicalLoopInfo` loop information structure in order to apply
transformations to it. This structure is now created in the nested call to
`convertOmpLoopNest`, so it needs to be passed up to all associated loop
wrapper translation functions. This is done by creating an
`OpenMPLoopInfoStackFrame` within `convertHostOrTargetOperation`,
associated with the outermost loop wrapper. This stack frame is updated by
`convertOmpLoopNest`, making the result available to all loop wrappers
after their bodies have been translated.

Added: 
    

Modified: 
    mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Target/LLVMIR/openmp-llvm.mlir
    mlir/test/Target/LLVMIR/openmp-simd-private.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index e86d576bdb241..eb59ef8c62266 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -290,13 +290,12 @@ class ModuleTranslation {
   /// Calls `callback` for every ModuleTranslation stack frame of type `T`
   /// starting from the top of the stack.
   template <typename T>
-  WalkResult
-  stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
+  WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
     static_assert(std::is_base_of<StackFrame, T>::value,
                   "expected T derived from StackFrame");
     if (!callback)
       return WalkResult::skip();
-    for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
+    for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
       if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
         WalkResult result = callback(*ptr);
         if (result.wasInterrupted())

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 758cdfece6f80..6883d78cd317d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -75,6 +75,16 @@ class OpenMPAllocaStackFrame
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
 };
 
+/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
+/// collapsed canonical loop information corresponding to an \c omp.loop_nest
+/// operation.
+class OpenMPLoopInfoStackFrame
+    : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
+  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+};
+
 /// Custom error class to signal translation errors that don't need reporting,
 /// since encountering them will have already triggered relevant error messages.
 ///
@@ -335,13 +345,13 @@ static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
 /// normal operations in the builder.
 static llvm::OpenMPIRBuilder::InsertPointTy
 findAllocaInsertPoint(llvm::IRBuilderBase &builder,
-                      const LLVM::ModuleTranslation &moduleTranslation) {
+                      LLVM::ModuleTranslation &moduleTranslation) {
   // If there is an alloca insertion point on stack, i.e. we are in a nested
   // operation and a specific point was provided by some surrounding operation,
   // use it.
   llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
   WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
-      [&](const OpenMPAllocaStackFrame &frame) {
+      [&](OpenMPAllocaStackFrame &frame) {
         allocaInsertPoint = frame.allocaInsertPoint;
         return WalkResult::interrupt();
       });
@@ -372,6 +382,20 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
       &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
 }
 
+/// Find the loop information structure for the loop nest being translated. It
+/// will return a `null` value unless called from the translation function for
+/// a loop wrapper operation after successfully translating its body.
+static llvm::CanonicalLoopInfo *
+findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::CanonicalLoopInfo *loopInfo = nullptr;
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        loopInfo = frame.loopInfo;
+        return WalkResult::interrupt();
+      });
+  return loopInfo;
+}
+
 /// Converts the given region that appears within an OpenMP dialect operation to
 /// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
 /// region, and a branch from any block with an successor-less OpenMP terminator
@@ -522,7 +546,7 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
 /// This must be called after block arguments of parent wrappers have already
 /// been mapped to LLVM IR values.
 static LogicalResult
-convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
+convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
                       LLVM::ModuleTranslation &moduleTranslation) {
   // Map block arguments directly to the LLVM value associated to the
   // corresponding operand. This is semantically equivalent to this wrapper not
@@ -544,34 +568,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
         return success();
       })
       .Default([&](Operation *op) {
-        return op->emitError() << "cannot ignore nested wrapper";
+        return op->emitError() << "cannot ignore wrapper";
       });
 }
 
-/// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
-/// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
-/// entry block arguments defined by these operations to outside values.
-///
-/// It must be called after block arguments of \c parentOp have already been
-/// mapped themselves.
-static LogicalResult
-convertIgnoredWrappers(omp::LoopNestOp loopOp,
-                       omp::LoopWrapperInterface parentOp,
-                       LLVM::ModuleTranslation &moduleTranslation) {
-  SmallVector<omp::LoopWrapperInterface> wrappers;
-  loopOp.gatherWrappers(wrappers);
-
-  // Process wrappers nested inside of `parentOp` from outermost to innermost.
-  for (auto it =
-           std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
-       it != wrappers.rend(); ++it) {
-    if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
-      return failure();
-  }
-
-  return success();
-}
-
 /// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
 static LogicalResult
 convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1889,6 +1889,7 @@ convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
 static LogicalResult
 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto wsloopOp = cast<omp::WsloopOp>(opInst);
   if (failed(checkImplementationStatus(opInst)))
     return failure();
@@ -1969,90 +1970,25 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                                reductionVariableMap, isByRef, deferredStores)))
     return failure();
 
-  // TODO: Replace this with proper composite translation support.
-  // Currently, all nested wrappers are ignored, so 'do/for simd' will be
-  // treated the same as a standalone 'do/for'. This is allowed by the spec,
-  // since it's equivalent to always using a SIMD length of 1.
-  if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
-    return failure();
-
-  // Set up the source location value for OpenMP runtime.
-  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-
-  // Generator of the canonical loop body.
-  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
-  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
-  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
-                     llvm::Value *iv) -> llvm::Error {
-    // Make sure further conversions know about the induction variable.
-    moduleTranslation.mapValue(
-        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
-
-    // Capture the body insertion point for use in nested loops. BodyIP of the
-    // CanonicalLoopInfo always points to the beginning of the entry block of
-    // the body.
-    bodyInsertPoints.push_back(ip);
-
-    if (loopInfos.size() != loopOp.getNumLoops() - 1)
-      return llvm::Error::success();
-
-    // Convert the body of the loop.
-    builder.restoreIP(ip);
-    return convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
-                               moduleTranslation)
-        .takeError();
-  };
-
-  // Delegate actual loop construction to the OpenMP IRBuilder.
-  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
-  // loop, i.e. it has a positive step, uses signed integer semantics.
-  // Reconsider this code when the nested loop operation clearly supports more
-  // cases.
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
-  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
-    llvm::Value *lowerBound =
-        moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
-    llvm::Value *upperBound =
-        moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
-    llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
-
-    // Make sure loop trip count are emitted in the preheader of the outermost
-    // loop at the latest so that they are all available for the new collapsed
-    // loop will be created below.
-    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
-    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
-    if (i != 0) {
-      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
-      computeIP = loopInfos.front()->getPreheaderIP();
-    }
-
-    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
-        ompBuilder->createCanonicalLoop(
-            loc, bodyGen, lowerBound, upperBound, step,
-            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
-
-    if (failed(handleError(loopResult, *loopOp)))
-      return failure();
-
-    loopInfos.push_back(*loopResult);
-  }
-
-  // Collapse loops. Store the insertion point because LoopInfos may get
-  // invalidated.
-  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
-  llvm::CanonicalLoopInfo *loopInfo =
-      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
-  allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
-
   // TODO: Handle doacross loops when the ordered clause has a parameter.
   bool isOrdered = wsloopOp.getOrdered().has_value();
   std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
   bool isSimd = wsloopOp.getScheduleSimd();
+  bool loopNeedsBarrier = !wsloopOp.getNowait();
+
+  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+      wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
+
+  if (failed(handleError(regionBlock, opInst)))
+    return failure();
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
-          ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
+          ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
           convertToScheduleKind(schedule), chunk, isSimd,
           scheduleMod == omp::ScheduleModifier::monotonic,
           scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -2060,12 +1996,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
-  // Continue building IR after the loop. Note that the LoopInfo returned by
-  // `collapseLoops` points inside the outermost loop and is intended for
-  // potential further loop transformations. Use the insertion point stored
-  // before collapsing loops instead.
-  builder.restoreIP(afterIP);
-
   // Process the reductions if required.
   if (failed(createReductionsAndCleanup(wsloopOp, builder, moduleTranslation,
                                         allocaIP, reductionDecls,
@@ -2274,8 +2204,20 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
 static LogicalResult
 convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto simdOp = cast<omp::SimdOp>(opInst);
-  auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
+
+  // TODO: Replace this with proper composite translation support.
+  // Currently, simd information on composite constructs is ignored, so e.g.
+  // 'do/for simd' will be treated the same as a standalone 'do/for'. This is
+  // allowed by the spec, since it's equivalent to using a SIMD length of 1.
+  if (simdOp.isComposite()) {
+    if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
+      return failure();
+
+    return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
+                                   builder, moduleTranslation);
+  }
 
   if (failed(checkImplementationStatus(opInst)))
     return failure();
@@ -2308,6 +2250,61 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
           .failed())
     return failure();
 
+  llvm::ConstantInt *simdlen = nullptr;
+  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
+    simdlen = builder.getInt64(simdlenVar.value());
+
+  llvm::ConstantInt *safelen = nullptr;
+  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
+    safelen = builder.getInt64(safelenVar.value());
+
+  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
+  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
+  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
+  mlir::OperandRange operands = simdOp.getAlignedVars();
+  for (size_t i = 0; i < operands.size(); ++i) {
+    llvm::Value *alignment = nullptr;
+    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
+    llvm::Type *ty = llvmVal->getType();
+
+    auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
+    alignment = builder.getInt64(intAttr.getInt());
+    assert(ty->isPointerTy() && "Invalid type for aligned variable");
+    assert(alignment && "Invalid alignment value");
+    auto curInsert = builder.saveIP();
+    builder.SetInsertPoint(sourceBlock);
+    llvmVal = builder.CreateLoad(ty, llvmVal);
+    builder.restoreIP(curInsert);
+    alignedVars[llvmVal] = alignment;
+  }
+
+  llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+      simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
+
+  if (failed(handleError(regionBlock, opInst)))
+    return failure();
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+  llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+  ompBuilder->applySimd(loopInfo, alignedVars,
+                        simdOp.getIfExpr()
+                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
+                            : nullptr,
+                        order, simdlen, safelen);
+
+  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
+                            llvmPrivateVars, privateDecls);
+}
+
+/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
+static LogicalResult
+convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
+                   LLVM::ModuleTranslation &moduleTranslation) {
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+  auto loopOp = cast<omp::LoopNestOp>(opInst);
+
+  // Set up the source location value for OpenMP runtime.
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
 
   // Generator of the canonical loop body.
@@ -2329,9 +2326,13 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
 
     // Convert the body of the loop.
     builder.restoreIP(ip);
-    return convertOmpOpRegions(loopOp.getRegion(), "omp.simd.region", builder,
-                               moduleTranslation)
-        .takeError();
+    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
+        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
+    if (!regionBlock)
+      return regionBlock.takeError();
+
+    builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+    return llvm::Error::success();
   };
 
   // Delegate actual loop construction to the OpenMP IRBuilder.
@@ -2339,7 +2340,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   // loop, i.e. it has a positive step, uses signed integer semantics.
   // Reconsider this code when the nested loop operation clearly supports more
   // cases.
-  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
     llvm::Value *lowerBound =
         moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
@@ -2361,7 +2361,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
         ompBuilder->createCanonicalLoop(
             loc, bodyGen, lowerBound, upperBound, step,
-            /*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
+            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
 
     if (failed(handleError(loopResult, *loopOp)))
       return failure();
@@ -2369,49 +2369,25 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     loopInfos.push_back(*loopResult);
   }
 
-  // Collapse loops.
-  llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
-  llvm::CanonicalLoopInfo *loopInfo =
-      ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
-
-  llvm::ConstantInt *simdlen = nullptr;
-  if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
-    simdlen = builder.getInt64(simdlenVar.value());
-
-  llvm::ConstantInt *safelen = nullptr;
-  if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
-    safelen = builder.getInt64(safelenVar.value());
-
-  llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
-  llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
-  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
-  std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
-  mlir::OperandRange operands = simdOp.getAlignedVars();
-  for (size_t i = 0; i < operands.size(); ++i) {
-    llvm::Value *alignment = nullptr;
-    llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
-    llvm::Type *ty = llvmVal->getType();
-    if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
-      alignment = builder.getInt64(intAttr.getInt());
-      assert(ty->isPointerTy() && "Invalid type for aligned variable");
-      assert(alignment && "Invalid alignment value");
-      auto curInsert = builder.saveIP();
-      builder.SetInsertPoint(sourceBlock->getTerminator());
-      llvmVal = builder.CreateLoad(ty, llvmVal);
-      builder.restoreIP(curInsert);
-      alignedVars[llvmVal] = alignment;
-    }
-  }
-  ompBuilder->applySimd(loopInfo, alignedVars,
-                        simdOp.getIfExpr()
-                            ? moduleTranslation.lookupValue(simdOp.getIfExpr())
-                            : nullptr,
-                        order, simdlen, safelen);
+  // Collapse loops. Store the insertion point because LoopInfos may get
+  // invalidated.
+  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
+      loopInfos.front()->getAfterIP();
+
+  // Update the stack frame created for this loop to point to the resulting loop
+  // after applying transformations.
+  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
+      [&](OpenMPLoopInfoStackFrame &frame) {
+        frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
+        return WalkResult::interrupt();
+      });
 
+  // Continue building IR after the loop. Note that the LoopInfo returned by
+  // `collapseLoops` points inside the outermost loop and is intended for
+  // potential further loop transformations. Use the insertion point stored
+  // before collapsing loops instead.
   builder.restoreIP(afterIP);
-
-  return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
-                            llvmPrivateVars, privateDecls);
+  return success();
 }
 
 /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
@@ -4705,135 +4681,157 @@ static bool isTargetDeviceOp(Operation *op) {
   return false;
 }
 
-/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
-/// (including OpenMP runtime calls).
+/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
+/// OpenMP runtime calls).
 static LogicalResult
 convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                              LLVM::ModuleTranslation &moduleTranslation) {
-
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
 
-  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
-      .Case([&](omp::BarrierOp op) -> LogicalResult {
-        if (failed(checkImplementationStatus(*op)))
-          return failure();
+  // For each loop, introduce one stack frame to hold loop information. Ensure
+  // this is only done for the outermost loop wrapper to prevent introducing
+  // multiple stack frames for a single loop. Initially set to null, the loop
+  // information structure is initialized during translation of the nested
+  // omp.loop_nest operation, making it available to translation of all loop
+  // wrappers after their body has been successfully translated.
+  bool isOutermostLoopWrapper =
+      isa_and_present<omp::LoopWrapperInterface>(op) &&
+      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
 
-        llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
-            ompBuilder->createBarrier(builder.saveIP(),
-                                      llvm::omp::OMPD_barrier);
-        return handleError(afterIP, *op);
-      })
-      .Case([&](omp::TaskyieldOp op) {
-        if (failed(checkImplementationStatus(*op)))
-          return failure();
+  if (isOutermostLoopWrapper)
+    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
 
-        ompBuilder->createTaskyield(builder.saveIP());
-        return success();
-      })
-      .Case([&](omp::FlushOp op) {
-        if (failed(checkImplementationStatus(*op)))
-          return failure();
-
-        // No support in Openmp runtime function (__kmpc_flush) to accept
-        // the argument list.
-        // OpenMP standard states the following:
-        //  "An implementation may implement a flush with a list by ignoring
-        //   the list, and treating it the same as a flush without a list."
-        //
-        // The argument list is discarded so that, flush with a list is treated
-        // same as a flush without a list.
-        ompBuilder->createFlush(builder.saveIP());
-        return success();
-      })
-      .Case([&](omp::ParallelOp op) {
-        return convertOmpParallel(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::MaskedOp) {
-        return convertOmpMasked(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::MasterOp) {
-        return convertOmpMaster(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::CriticalOp) {
-        return convertOmpCritical(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::OrderedRegionOp) {
-        return convertOmpOrderedRegion(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::OrderedOp) {
-        return convertOmpOrdered(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::WsloopOp) {
-        return convertOmpWsloop(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::SimdOp) {
-        return convertOmpSimd(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::AtomicReadOp) {
-        return convertOmpAtomicRead(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::AtomicWriteOp) {
-        return convertOmpAtomicWrite(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::AtomicUpdateOp op) {
-        return convertOmpAtomicUpdate(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::AtomicCaptureOp op) {
-        return convertOmpAtomicCapture(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::SectionsOp) {
-        return convertOmpSections(*op, builder, moduleTranslation);
-      })
-      .Case([&](omp::SingleOp op) {
-        return convertOmpSingle(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::TeamsOp op) {
-        return convertOmpTeams(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::TaskOp op) {
-        return convertOmpTaskOp(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::TaskgroupOp op) {
-        return convertOmpTaskgroupOp(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::TaskwaitOp op) {
-        return convertOmpTaskwaitOp(op, builder, moduleTranslation);
-      })
-      .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
-            omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
-            omp::CriticalDeclareOp>([](auto op) {
-        // `yield` and `terminator` can be just omitted. The block structure
-        // was created in the region that handles their parent operation.
-        // `declare_reduction` will be used by reductions and is not
-        // converted directly, skip it.
-        // `declare_mapper` and `declare_mapper.info` are handled whenever they
-        // are referred to through a `map` clause.
-        // `critical.declare` is only used to declare names of critical
-        // sections which will be used by `critical` ops and hence can be
-        // ignored for lowering. The OpenMP IRBuilder will create unique
-        // name for critical section names.
-        return success();
-      })
-      .Case([&](omp::ThreadprivateOp) {
-        return convertOmpThreadprivate(*op, builder, moduleTranslation);
-      })
-      .Case<omp::TargetDataOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
-            omp::TargetUpdateOp>([&](auto op) {
-        return convertOmpTargetData(op, builder, moduleTranslation);
-      })
-      .Case([&](omp::TargetOp) {
-        return convertOmpTarget(*op, builder, moduleTranslation);
-      })
-      .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
-          [&](auto op) {
-            // No-op, should be handled by relevant owning operations e.g.
-            // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp etc.
-            // and then discarded
+  auto result =
+      llvm::TypeSwitch<Operation *, LogicalResult>(op)
+          .Case([&](omp::BarrierOp op) -> LogicalResult {
+            if (failed(checkImplementationStatus(*op)))
+              return failure();
+
+            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
+                ompBuilder->createBarrier(builder.saveIP(),
+                                          llvm::omp::OMPD_barrier);
+            return handleError(afterIP, *op);
+          })
+          .Case([&](omp::TaskyieldOp op) {
+            if (failed(checkImplementationStatus(*op)))
+              return failure();
+
+            ompBuilder->createTaskyield(builder.saveIP());
             return success();
           })
-      .Default([&](Operation *inst) {
-        return inst->emitError() << "not yet implemented: " << inst->getName();
-      });
+          .Case([&](omp::FlushOp op) {
+            if (failed(checkImplementationStatus(*op)))
+              return failure();
+
+            // No support in Openmp runtime function (__kmpc_flush) to accept
+            // the argument list.
+            // OpenMP standard states the following:
+            //  "An implementation may implement a flush with a list by ignoring
+            //   the list, and treating it the same as a flush without a list."
+            //
+            // The argument list is discarded so that, flush with a list is
+            // treated same as a flush without a list.
+            ompBuilder->createFlush(builder.saveIP());
+            return success();
+          })
+          .Case([&](omp::ParallelOp op) {
+            return convertOmpParallel(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::MaskedOp) {
+            return convertOmpMasked(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::MasterOp) {
+            return convertOmpMaster(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::CriticalOp) {
+            return convertOmpCritical(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::OrderedRegionOp) {
+            return convertOmpOrderedRegion(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::OrderedOp) {
+            return convertOmpOrdered(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::WsloopOp) {
+            return convertOmpWsloop(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::SimdOp) {
+            return convertOmpSimd(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::AtomicReadOp) {
+            return convertOmpAtomicRead(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::AtomicWriteOp) {
+            return convertOmpAtomicWrite(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::AtomicUpdateOp op) {
+            return convertOmpAtomicUpdate(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::AtomicCaptureOp op) {
+            return convertOmpAtomicCapture(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::SectionsOp) {
+            return convertOmpSections(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::SingleOp op) {
+            return convertOmpSingle(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::TeamsOp op) {
+            return convertOmpTeams(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::TaskOp op) {
+            return convertOmpTaskOp(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::TaskgroupOp op) {
+            return convertOmpTaskgroupOp(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::TaskwaitOp op) {
+            return convertOmpTaskwaitOp(op, builder, moduleTranslation);
+          })
+          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
+                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
+                omp::CriticalDeclareOp>([](auto op) {
+            // `yield` and `terminator` can be just omitted. The block structure
+            // was created in the region that handles their parent operation.
+            // `declare_reduction` will be used by reductions and is not
+            // converted directly, skip it.
+            // `declare_mapper` and `declare_mapper.info` are handled whenever
+            // they are referred to through a `map` clause.
+            // `critical.declare` is only used to declare names of critical
+            // sections which will be used by `critical` ops and hence can be
+            // ignored for lowering. The OpenMP IRBuilder will create unique
+            // name for critical section names.
+            return success();
+          })
+          .Case([&](omp::ThreadprivateOp) {
+            return convertOmpThreadprivate(*op, builder, moduleTranslation);
+          })
+          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
+                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
+            return convertOmpTargetData(op, builder, moduleTranslation);
+          })
+          .Case([&](omp::TargetOp) {
+            return convertOmpTarget(*op, builder, moduleTranslation);
+          })
+          .Case([&](omp::LoopNestOp) {
+            return convertOmpLoopNest(*op, builder, moduleTranslation);
+          })
+          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
+              [&](auto op) {
+                // No-op, should be handled by relevant owning operations e.g.
+                // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
+                // etc. and then discarded
+                return success();
+              })
+          .Default([&](Operation *inst) {
+            return inst->emitError()
+                   << "not yet implemented: " << inst->getName();
+          });
+
+  if (isOutermostLoopWrapper)
+    moduleTranslation.stackPop();
+
+  return result;
 }
 
 static LogicalResult

diff  --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 5f42240cf978e..cf18c07dd605b 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -700,7 +700,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) {
 // CHECK-LABEL: @simd_simple_multiple
 llvm.func @simd_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   omp.simd {
-    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+    omp.loop_nest (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) inclusive step (%step1, %step2) {
       %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
       // The form of the emitted IR is controlled by OpenMPIRBuilder and
       // tested there. Just check that the right metadata is added and collapsed

diff  --git a/mlir/test/Target/LLVMIR/openmp-simd-private.mlir b/mlir/test/Target/LLVMIR/openmp-simd-private.mlir
index 40f46103a0ab4..15dfb6dbc0e87 100644
--- a/mlir/test/Target/LLVMIR/openmp-simd-private.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-simd-private.mlir
@@ -16,6 +16,9 @@ omp.private {type = private} @i_privatizer : i32
 // CHECK:         br label %[[ENTRY:.*]]
 
 // CHECK:       [[ENTRY]]:
+// CHECK:         br label %[[OMP_SIMD_REGION:.*]]
+
+// CHECK:       [[OMP_SIMD_REGION]]:
 // CHECK:         br label %[[OMP_LOOP_PREHEADER:.*]]
 
 // CHECK:       [[OMP_LOOP_PREHEADER]]:
@@ -32,9 +35,9 @@ omp.private {type = private} @i_privatizer : i32
 // CHECK:       [[OMP_LOOP_BODY]]:
 // CHECK:         %[[IV_UPDATE:.*]] = mul i32 %[[OMP_LOOP_IV]], 1
 // CHECK:         %[[IV_UPDATE_2:.*]] = add i32 %[[IV_UPDATE]], 1
-// CHECK:         br label %[[OMP_SIMD_REGION:.*]]
+// CHECK:         br label %[[OMP_LOOP_NEST_REGION:.*]]
 
-// CHECK:       [[OMP_SIMD_REGION]]:
+// CHECK:       [[OMP_LOOP_NEST_REGION]]:
 // CHECK:         store i32 %[[IV_UPDATE_2]], ptr %[[PRIV_I]], align 4
 // CHECK:         %[[DUMMY_VAL:.*]] = load float, ptr %[[DUMMY]], align 4
 // CHECK:         %[[PRIV_I_VAL:.*]] = load i32, ptr %[[PRIV_I]], align 4
@@ -83,7 +86,7 @@ omp.private {type = private} @dummy_privatizer : f32
 // CHECK:         %[[PRIV_DUMMY:.*]] = alloca float, align 4
 // CHECK:         %[[PRIV_I:.*]] = alloca i32, align 4
 
-// CHECK:       omp.simd.region:
+// CHECK:       omp.loop_nest.region:
 // CHECK-NOT:     br label
 // CHECK:         store i32 %{{.*}}, ptr %[[PRIV_I]], align 4
 // CHECK:        %{{.*}} = load float, ptr %[[PRIV_DUMMY]], align 4


        


More information about the Mlir-commits mailing list