[Mlir-commits] [llvm] [mlir] [MLIR][OpenMP] Add MLIR Lowering Support for dist_schedule (PR #152736)

Jack Styles llvmlistbot at llvm.org
Tue Aug 26 01:08:43 PDT 2025


https://github.com/Stylie777 updated https://github.com/llvm/llvm-project/pull/152736

>From 08ed236a5b93394b27e4a5ceebd442a3d4be90fc Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Tue, 8 Jul 2025 08:57:18 +0100
Subject: [PATCH 1/2] [MLIR][OpenMP] Add support for `dist_schedule`

`dist_schedule` was previously supported in Flang/Clang but was not
implemented in MLIR; instead, a user would get a "not yet implemented"
error. This patch adds support for lowering the `dist_schedule` clause
to LLVM IR when it is used in an `omp.distribute` section. Support is
also added for `dist_schedule` when the loop nest is embedded within a
Workshare Loop.

Some rework was required to ensure that MLIR/LLVM emits the correct
Schedule Type for the clause, as it uses a different schedule type from
other OpenMP directives/clauses in the runtime library.

Add llvm loop metadata

Update implementation to support processing in workshare loop.
---
 llvm/include/llvm/Frontend/OpenMP/OMP.td      |   4 +-
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  47 ++--
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 200 ++++++++++++------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |   1 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  36 ++--
 .../OpenMPToLLVM/convert-to-llvmir.mlir       |  19 ++
 .../Target/LLVMIR/openmp-dist_schedule.mlir   |  30 +++
 .../openmp-dist_schedule_with_wsloop.mlir     |  99 +++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  13 --
 9 files changed, 347 insertions(+), 102 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
 create mode 100644 mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 79f25bb05f20e..4117e112367c6 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -458,7 +458,8 @@ def OMP_SCHEDULE_Dynamic : EnumVal<"dynamic", 3, 1> {}
 def OMP_SCHEDULE_Guided : EnumVal<"guided", 4, 1> {}
 def OMP_SCHEDULE_Auto : EnumVal<"auto", 5, 1> {}
 def OMP_SCHEDULE_Runtime : EnumVal<"runtime", 6, 1> {}
-def OMP_SCHEDULE_Default : EnumVal<"default", 7, 0> { let isDefault = 1; }
+def OMP_SCHEDULE_Distribute : EnumVal<"distribute", 7, 1> {}
+def OMP_SCHEDULE_Default : EnumVal<"default", 8, 0> { let isDefault = 1; }
 def OMPC_Schedule : Clause<[Spelling<"schedule">]> {
   let clangClass = "OMPScheduleClause";
   let flangClass = "OmpScheduleClause";
@@ -469,6 +470,7 @@ def OMPC_Schedule : Clause<[Spelling<"schedule">]> {
     OMP_SCHEDULE_Guided,
     OMP_SCHEDULE_Auto,
     OMP_SCHEDULE_Runtime,
+    OMP_SCHEDULE_Distribute,
     OMP_SCHEDULE_Default
   ];
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f70659120e1e6..41c2e2156736b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1096,11 +1096,17 @@ class OpenMPIRBuilder {
   /// \param NeedsBarrier Indicates whether a barrier must be inserted after
   ///                     the loop.
   /// \param LoopType Type of workshare loop.
+  /// \param HasDistSchedule Defines if the clause being lowered is
+  /// dist_schedule as this is handled slightly differently
+  /// \param DistScheduleSchedType Defines the Schedule Type for the Distribute
+  /// loop. Defaults to None if no Distribute loop is present.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyStaticWorkshareLoop(
       DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-      omp::WorksharingLoopType LoopType, bool NeedsBarrier);
+      omp::WorksharingLoopType LoopType, bool NeedsBarrier,
+      bool HasDistSchedule = false,
+      omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
 
   /// Modifies the canonical loop a statically-scheduled workshare loop with a
   /// user-specified chunk size.
@@ -1113,13 +1119,22 @@ class OpenMPIRBuilder {
   /// \param NeedsBarrier Indicates whether a barrier must be inserted after the
   ///                     loop.
   /// \param ChunkSize    The user-specified chunk size.
+  /// \param SchedType    Optional type of scheduling to be passed to the init
+  /// function.
+  /// \param DistScheduleChunkSize    The size of dist_schedule chunk considered
+  /// as a unit when
+  ///                 scheduling. If \p nullptr, defaults to 1.
+  /// \param DistScheduleSchedType Defines the Schedule Type for the Distribute
+  /// loop. Defaults to None if no Distribute loop is present.
   ///
   /// \returns Point where to insert code after the workshare construct.
-  InsertPointOrErrorTy applyStaticChunkedWorkshareLoop(DebugLoc DL,
-                                                       CanonicalLoopInfo *CLI,
-                                                       InsertPointTy AllocaIP,
-                                                       bool NeedsBarrier,
-                                                       Value *ChunkSize);
+  InsertPointOrErrorTy applyStaticChunkedWorkshareLoop(
+      DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+      bool NeedsBarrier, Value *ChunkSize,
+      omp::OMPScheduleType SchedType =
+          omp::OMPScheduleType::UnorderedStaticChunked,
+      Value *DistScheduleChunkSize = nullptr,
+      omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
 
   /// Modifies the canonical loop to be a dynamically-scheduled workshare loop.
   ///
@@ -1139,14 +1154,15 @@ class OpenMPIRBuilder {
   ///                     the loop.
   /// \param Chunk    The size of loop chunk considered as a unit when
   ///                 scheduling. If \p nullptr, defaults to 1.
+  /// \param DistScheduleChunk    The size of dist_schedule chunk considered as
+  /// a unit when
+  ///                 scheduling. If \p nullptr, defaults to 1.
   ///
   /// \returns Point where to insert code after the workshare construct.
-  InsertPointOrErrorTy applyDynamicWorkshareLoop(DebugLoc DL,
-                                                 CanonicalLoopInfo *CLI,
-                                                 InsertPointTy AllocaIP,
-                                                 omp::OMPScheduleType SchedType,
-                                                 bool NeedsBarrier,
-                                                 Value *Chunk = nullptr);
+  InsertPointOrErrorTy applyDynamicWorkshareLoop(
+      DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+      omp::OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk = nullptr,
+      Value *DistScheduleChunk = nullptr);
 
   /// Create alternative version of the loop to support if clause
   ///
@@ -1197,6 +1213,10 @@ class OpenMPIRBuilder {
   ///                         present.
   /// \param LoopType Information about type of loop worksharing.
   ///                 It corresponds to type of loop workshare OpenMP pragma.
+  /// \param HasDistSchedule Defines if the clause being lowered is
+  /// dist_schedule as this is handled slightly differently
+  ///
+  /// \param ChunkSize The chunk size for dist_schedule loop
   ///
   /// \returns Point where to insert code after the workshare construct.
   LLVM_ABI InsertPointOrErrorTy applyWorkshareLoop(
@@ -1207,7 +1227,8 @@ class OpenMPIRBuilder {
       bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
       bool HasOrderedClause = false,
       omp::WorksharingLoopType LoopType =
-          omp::WorksharingLoopType::ForStaticLoop);
+          omp::WorksharingLoopType::ForStaticLoop,
+      bool HasDistSchedule = false, Value *DistScheduleChunkSize = nullptr);
 
   /// Tile a loop nest.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ea027e48fa2f1..1860ade264740 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -136,6 +136,8 @@ static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
   case OMPScheduleType::NomergeOrderedRuntime:
   case OMPScheduleType::NomergeOrderedAuto:
   case OMPScheduleType::NomergeOrderedTrapezoidal:
+  case OMPScheduleType::OrderedDistributeChunked:
+  case OMPScheduleType::OrderedDistribute:
     break;
   default:
     return false;
@@ -170,7 +172,7 @@ static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
 /// arguments.
 static OMPScheduleType
 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
-                          bool HasSimdModifier) {
+                          bool HasSimdModifier, bool HasDistScheduleChunks) {
   // Currently, the default schedule it static.
   switch (ClauseKind) {
   case OMP_SCHEDULE_Default:
@@ -187,6 +189,9 @@ getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
   case OMP_SCHEDULE_Runtime:
     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                            : OMPScheduleType::BaseRuntime;
+  case OMP_SCHEDULE_Distribute:
+    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
+                                 : OMPScheduleType::BaseDistribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
 }
@@ -255,9 +260,10 @@ getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
 static OMPScheduleType
 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                           bool HasSimdModifier, bool HasMonotonicModifier,
-                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
-  OMPScheduleType BaseSchedule =
-      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
+                          bool HasNonmonotonicModifier, bool HasOrderedClause,
+                          bool HasDistScheduleChunks) {
+  OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
+      ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
   OMPScheduleType OrderedSchedule =
       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
@@ -4637,7 +4643,8 @@ static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    WorksharingLoopType LoopType, bool NeedsBarrier) {
+    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
+    OMPScheduleType DistScheduleSchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -4693,15 +4700,26 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
 
   // Call the "init" function and update the trip count of the loop with the
   // value it produced.
-  SmallVector<Value *, 10> Args(
-      {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound, PUpperBound});
-  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
-    Value *PDistUpperBound =
-        Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
-    Args.push_back(PDistUpperBound);
+  auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
+                        PUpperBound, IVTy, PStride, One, Zero,
+                        StaticInit](Value *SchedulingType, auto &Builder) {
+    SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
+                                   PLowerBound, PUpperBound});
+    if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
+      Value *PDistUpperBound =
+          Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
+      Args.push_back(PDistUpperBound);
+    }
+    Args.append({PStride, One, Zero});
+    Builder.CreateCall(StaticInit, Args);
+  };
+  BuildInitCall(SchedulingType, Builder);
+  if (HasDistSchedule &&
+      LoopType != WorksharingLoopType::DistributeStaticLoop) {
+    Constant *DistScheduleSchedType = ConstantInt::get(
+        I32Type, static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
+    BuildInitCall(DistScheduleSchedType, Builder);
   }
-  Args.append({PStride, One, Zero});
-  Builder.CreateCall(StaticInit, Args);
   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
@@ -4740,14 +4758,44 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   return AfterIP;
 }
 
+static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                                   LoopInfo &LI);
+static void addLoopMetadata(CanonicalLoopInfo *Loop,
+                            ArrayRef<Metadata *> Properties);
+
+static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
+                                          LLVMContext &Ctx, Loop *Loop,
+                                          LoopInfo &LoopInfo,
+                                          SmallVector<Metadata *> &LoopMDList) {
+  SmallSet<BasicBlock *, 8> Reachable;
+
+  // Get the basic blocks from the loop in which memref instructions
+  // can be found.
+  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
+  // preferably without running any passes.
+  for (BasicBlock *Block : Loop->getBlocks()) {
+    if (Block == CLI->getCond() || Block == CLI->getHeader())
+      continue;
+    Reachable.insert(Block);
+  }
+
+  // Add access group metadata to memory-access instructions.
+  MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
+  for (BasicBlock *BB : Reachable)
+    addAccessGroupMetadata(BB, AccessGroup, LoopInfo);
+  // TODO:  If the loop has existing parallel access metadata, have
+  // to combine two lists.
+  LoopMDList.push_back(MDNode::get(
+      Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
+}
+
 OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
-                                                 CanonicalLoopInfo *CLI,
-                                                 InsertPointTy AllocaIP,
-                                                 bool NeedsBarrier,
-                                                 Value *ChunkSize) {
+OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
+    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
+    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
-  assert(ChunkSize && "Chunk size is required");
+  assert(ChunkSize || DistScheduleChunkSize && "Chunk size is required");
 
   LLVMContext &Ctx = CLI->getFunction()->getContext();
   Value *IV = CLI->getIndVar();
@@ -4761,6 +4809,18 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
   Constant *One = ConstantInt::get(InternalIVTy, 1);
 
+  Function *F = CLI->getFunction();
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
+  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
+  LoopAnalysis LIA;
+  LoopInfo &&LI = LIA.run(*F, FAM);
+  Loop *L = LI.getLoopFor(CLI->getHeader());
+  SmallVector<Metadata *> LoopMDList;
+  if (ChunkSize || DistScheduleChunkSize)
+    applyParallelAccessesMetadata(CLI, Ctx, L, LI, LoopMDList);
+  addLoopMetadata(CLI, LoopMDList);
+
   // Declare useful OpenMP runtime functions.
   FunctionCallee StaticInit =
       getKmpcForStaticInitForType(InternalIVTy, M, *this);
@@ -4783,13 +4843,18 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Builder.SetCurrentDebugLocation(DL);
 
   // TODO: Detect overflow in ubsan or max-out with current tripcount.
-  Value *CastedChunkSize =
-      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
+  Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
+      ChunkSize ? ChunkSize : Zero, InternalIVTy, "chunksize");
+  Value *CastestDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
+      DistScheduleChunkSize ? DistScheduleChunkSize : Zero, InternalIVTy,
+      "distschedulechunksize");
   Value *CastedTripCount =
       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
 
-  Constant *SchedulingType = ConstantInt::get(
-      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
+  Constant *SchedulingType =
+      ConstantInt::get(I32Type, static_cast<int>(SchedType));
+  Constant *DistSchedulingType =
+      ConstantInt::get(I32Type, static_cast<int>(DistScheduleSchedType));
   Builder.CreateStore(Zero, PLowerBound);
   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
   Builder.CreateStore(OrigUpperBound, PUpperBound);
@@ -4801,12 +4866,25 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
-  Builder.CreateCall(StaticInit,
-                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
-                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
-                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
-                      /*pstride=*/PStride, /*incr=*/One,
-                      /*chunk=*/CastedChunkSize});
+  auto BuildInitCall =
+      [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound,
+       PStride, One](Value *SchedulingType, Value *ChunkSize, auto &Builder) {
+        Builder.CreateCall(
+            StaticInit, {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
+                         /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
+                         /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
+                         /*pstride=*/PStride, /*incr=*/One,
+                         /*chunk=*/ChunkSize});
+      };
+  BuildInitCall(SchedulingType, CastedChunkSize, Builder);
+  if (DistScheduleSchedType != OMPScheduleType::None &&
+      SchedType != OMPScheduleType::OrderedDistributeChunked &&
+      SchedType != OMPScheduleType::OrderedDistribute) {
+    // We want to emit a second init function call for the dist_schedule clause
+    // to the Distribute construct. This should only be done however if a
+    // Workshare Loop is nested within a Distribute Construct
+    BuildInitCall(DistSchedulingType, CastestDistScheduleChunkSize, Builder);
+  }
 
   // Load values written by the "init" function.
   Value *FirstChunkStart =
@@ -5130,31 +5208,47 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
     bool HasSimdModifier, bool HasMonotonicModifier,
     bool HasNonmonotonicModifier, bool HasOrderedClause,
-    WorksharingLoopType LoopType) {
+    WorksharingLoopType LoopType, bool HasDistSchedule,
+    Value *DistScheduleChunkSize) {
   if (Config.isTargetDevice())
     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
-      HasNonmonotonicModifier, HasOrderedClause);
+      HasNonmonotonicModifier, HasOrderedClause, DistScheduleChunkSize);
 
   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                    OMPScheduleType::ModifierOrdered;
+  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
+  if (HasDistSchedule) {
+    DistScheduleSchedType = DistScheduleChunkSize
+                                ? OMPScheduleType::OrderedDistributeChunked
+                                : OMPScheduleType::OrderedDistribute;
+  }
   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
   case OMPScheduleType::BaseStatic:
-    assert(!ChunkSize && "No chunk size with static-chunked schedule");
-    if (IsOrdered)
+  case OMPScheduleType::BaseDistribute:
+    assert(!ChunkSize || !DistScheduleChunkSize &&
+                             "No chunk size with static-chunked schedule");
+    if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier);
+    if (DistScheduleChunkSize)
+      return applyStaticChunkedWorkshareLoop(
+          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
+          DistScheduleChunkSize, DistScheduleSchedType);
+    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
+                                    HasDistSchedule);
 
   case OMPScheduleType::BaseStaticChunked:
-    if (IsOrdered)
+  case OMPScheduleType::BaseDistributeChunked:
+    if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
-                                           ChunkSize);
+    return applyStaticChunkedWorkshareLoop(
+        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
+        DistScheduleChunkSize, DistScheduleSchedType);
 
   case OMPScheduleType::BaseRuntime:
   case OMPScheduleType::BaseAuto:
@@ -5230,7 +5324,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy
 OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                            InsertPointTy AllocaIP,
                                            OMPScheduleType SchedType,
-                                           bool NeedsBarrier, Value *Chunk) {
+                                           bool NeedsBarrier, Value *Chunk,
+                                           Value *DistScheduleChunk) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -5747,8 +5842,8 @@ static void addLoopMetadata(CanonicalLoopInfo *Loop,
 }
 
 /// Attach llvm.access.group metadata to the memref instructions of \p Block
-static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
-                            LoopInfo &LI) {
+static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                                   LoopInfo &LI) {
   for (Instruction &I : *Block) {
     if (I.mayReadOrWriteMemory()) {
       // TODO: This instruction may already have access group from
@@ -5918,19 +6013,6 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
     createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, "simd");
   }
 
-  SmallSet<BasicBlock *, 8> Reachable;
-
-  // Get the basic blocks from the loop in which memref instructions
-  // can be found.
-  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
-  // preferably without running any passes.
-  for (BasicBlock *Block : L->getBlocks()) {
-    if (Block == CanonicalLoop->getCond() ||
-        Block == CanonicalLoop->getHeader())
-      continue;
-    Reachable.insert(Block);
-  }
-
   SmallVector<Metadata *> LoopMDList;
 
   // In presence of finite 'safelen', it may be unsafe to mark all
@@ -5938,16 +6020,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   // dependences of 'safelen' iterations are possible.
   // If clause order(concurrent) is specified then the memory instructions
   // are marked parallel even if 'safelen' is finite.
-  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
-    // Add access group metadata to memory-access instructions.
-    MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
-    for (BasicBlock *BB : Reachable)
-      addSimdMetadata(BB, AccessGroup, LI);
-    // TODO:  If the loop has existing parallel access metadata, have
-    // to combine two lists.
-    LoopMDList.push_back(MDNode::get(
-        Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
-  }
+  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
+    applyParallelAccessesMetadata(CanonicalLoop, Ctx, L, LI, LoopMDList);
 
   // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
   // versions so we can't add the loop attributes in that case.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767ef90b0..9e2031401403c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -386,6 +386,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
     break;
   case ClauseScheduleKind::Auto:
   case ClauseScheduleKind::Runtime:
+  case ClauseScheduleKind::Distribute:
     chunkSize = std::nullopt;
   }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2cdd502ad0275..b11af583f4c16 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -60,6 +60,8 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
     return llvm::omp::OMP_SCHEDULE_Auto;
   case omp::ClauseScheduleKind::Runtime:
     return llvm::omp::OMP_SCHEDULE_Runtime;
+  case omp::ClauseScheduleKind::Distribute:
+    return llvm::omp::OMP_SCHEDULE_Distribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
 }
@@ -318,10 +320,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDevice())
       result = todo("device");
   };
-  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
-    if (op.getDistScheduleChunkSize())
-      result = todo("dist_schedule with chunk_size");
-  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -392,7 +390,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::DistributeOp op) {
         checkAllocate(op, result);
-        checkDistSchedule(op, result);
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
@@ -2490,6 +2487,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
+  omp::DistributeOp distributeOp = nullptr;
+  llvm::Value *distScheduleChunk = nullptr;
+  bool hasDistSchedule = false;
+  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
+    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
+    hasDistSchedule = distributeOp.getDistScheduleStatic();
+    if (distributeOp.getDistScheduleChunkSize()) {
+      llvm::Value *chunkVar = moduleTranslation.lookupValue(
+          distributeOp.getDistScheduleChunkSize());
+      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
+    }
+  }
+
   PrivateVarsInfo privateVarsInfo(wsloopOp);
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2596,7 +2606,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
           convertToScheduleKind(schedule), chunk, isSimd,
           scheduleMod == omp::ScheduleModifier::monotonic,
           scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType);
+          workshareLoopType, hasDistSchedule, distScheduleChunk);
 
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
@@ -4836,15 +4846,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
     if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
       // TODO: Add support for clauses which are valid for DISTRIBUTE
       // constructs. Static schedule is the default.
-      auto schedule = omp::ClauseScheduleKind::Static;
-      bool isOrdered = false;
+      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
+      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
+                                      : omp::ClauseScheduleKind::Static;
+      // dist_schedule clauses are ordered - otherwise this should be false
+      bool isOrdered = hasDistSchedule;
       std::optional<omp::ScheduleModifier> scheduleMod;
       bool isSimd = false;
       llvm::omp::WorksharingLoopType workshareLoopType =
           llvm::omp::WorksharingLoopType::DistributeStaticLoop;
       bool loopNeedsBarrier = false;
-      llvm::Value *chunk = nullptr;
-
+      llvm::Value *chunk = moduleTranslation.lookupValue(
+          distributeOp.getDistScheduleChunkSize());
       llvm::CanonicalLoopInfo *loopInfo =
           findCurrentLoopInfo(moduleTranslation);
       llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
@@ -4853,12 +4866,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
               convertToScheduleKind(schedule), chunk, isSimd,
               scheduleMod == omp::ScheduleModifier::monotonic,
               scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType);
+              workshareLoopType, hasDistSchedule, chunk);
 
       if (!wsloopIP)
         return wsloopIP.takeError();
     }
-
     if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                   distributeOp.getLoc(), privVarsInfo.llvmVars,
                                   privVarsInfo.privatizers)))
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index d69de998346b5..e180cdc2cb075 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -614,3 +614,22 @@ omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>
   // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr)
   omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr)
 }
+
+// CHECK-LABEL: llvm.func @omp_dist_schedule(%arg0: i32) {
+func.func @omp_dist_schedule(%arg0: i32) {
+  %c1_i32 = arith.constant 1 : i32
+  // CHECK: %1 = llvm.mlir.constant(1024 : i32) : i32
+  %c1024_i32 = arith.constant 1024 : i32
+  %c16_i32 = arith.constant 16 : i32
+  %c8_i32 = arith.constant 8 : i32
+  omp.teams num_teams( to %c8_i32 : i32) thread_limit(%c16_i32 : i32) {
+    // CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+    omp.distribute dist_schedule_static dist_schedule_chunk_size(%c1024_i32 : i32) {
+      omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%arg0) inclusive step (%c1_i32) {
+        omp.terminator
+      }
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
new file mode 100644
index 0000000000000..291c0d3e51d6c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
@@ -0,0 +1,30 @@
+// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @distribute_dist_schedule_chunk_size(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+  // CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+  %1 = llvm.mlir.constant(1024: i32) : i32
+  omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// When a chunk size is present, we need to make sure the correct parallel accesses metadata is added
+// CHECK: !2 = !{!"llvm.loop.parallel_accesses", !3}
+// CHECK-NEXT: !3 = distinct !{}
+
+// -----
+
+llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+  // CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+  omp.distribute dist_schedule_static {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
new file mode 100644
index 0000000000000..b25675c78a23c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
@@ -0,0 +1,99 @@
+// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate while using workshare loops.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @distribute_wsloop_dist_scheule_chunked_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %dcs = llvm.mlir.constant(1024 : i32) : i32
+  %scs = llvm.mlir.constant(64 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+        omp.wsloop schedule(static = %scs : i32) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_scheule_chunked_schedule_chunked..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+
+llvm.func @distribute_wsloop_dist_scheule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %dcs = llvm.mlir.constant(1024 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_scheule_chunked..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+
+llvm.func @distribute_wsloop_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %scs = llvm.mlir.constant(64 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static = %scs : i32) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_schedule_chunked..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+
+llvm.func @distribute_wsloop_no_chunks(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_no_chunks..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i32 1, i32 0)
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 2fa4470bb8300..b3b1e853014f9 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -39,19 +39,6 @@ llvm.func @distribute_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr
 
 // -----
 
-llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause dist_schedule with chunk_size in omp.distribute operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.distribute}}
-  omp.distribute dist_schedule_static dist_schedule_chunk_size(%x : i32) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @distribute_order(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: Unhandled clause order in omp.distribute operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.distribute}}

>From bbcb90283671581c846e1d4d84729e8de8f3acb4 Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Tue, 26 Aug 2025 08:58:32 +0100
Subject: [PATCH 2/2] Handle dist_schedule in unique workshare loop.

Note that at this stage the wsloop is inserted at the wrong
location; the insertion point needs to be corrected so that the
generated `.ll` output matches Clang's.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 18 +-----
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 61 +++++--------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 34 +++++------
 3 files changed, 32 insertions(+), 81 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 41c2e2156736b..742aa6e7646f1 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1098,15 +1098,12 @@ class OpenMPIRBuilder {
   /// \param LoopType Type of workshare loop.
   /// \param HasDistSchedule Defines if the clause being lowered is
   /// dist_schedule as this is handled slightly differently
-  /// \param DistScheduleSchedType Defines the Schedule Type for the Distribute
-  /// loop. Defaults to None if no Distribute loop is present.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyStaticWorkshareLoop(
       DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
       omp::WorksharingLoopType LoopType, bool NeedsBarrier,
-      bool HasDistSchedule = false,
-      omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
+      bool HasDistSchedule = false);
 
   /// Modifies the canonical loop a statically-scheduled workshare loop with a
   /// user-specified chunk size.
@@ -1121,20 +1118,13 @@ class OpenMPIRBuilder {
   /// \param ChunkSize    The user-specified chunk size.
   /// \param SchedType    Optional type of scheduling to be passed to the init
   /// function.
-  /// \param DistScheduleChunkSize    The size of dist_shcedule chunk considered
-  /// as a unit when
-  ///                 scheduling. If \p nullptr, defaults to 1.
-  /// \param DistScheduleSchedType Defines the Schedule Type for the Distribute
-  /// loop. Defaults to None if no Distribute loop is present.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyStaticChunkedWorkshareLoop(
       DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
       bool NeedsBarrier, Value *ChunkSize,
       omp::OMPScheduleType SchedType =
-          omp::OMPScheduleType::UnorderedStaticChunked,
-      Value *DistScheduleChunkSize = nullptr,
-      omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
+          omp::OMPScheduleType::UnorderedStaticChunked);
 
   /// Modifies the canonical loop to be a dynamically-scheduled workshare loop.
   ///
@@ -1216,8 +1206,6 @@ class OpenMPIRBuilder {
   /// \param HasDistSchedule Defines if the clause being lowered is
   /// dist_schedule as this is handled slightly differently
   ///
-  /// \param ChunkSize The chunk size for dist_schedule loop
-  ///
   /// \returns Point where to insert code after the workshare construct.
   LLVM_ABI InsertPointOrErrorTy applyWorkshareLoop(
       DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
@@ -1228,7 +1216,7 @@ class OpenMPIRBuilder {
       bool HasOrderedClause = false,
       omp::WorksharingLoopType LoopType =
           omp::WorksharingLoopType::ForStaticLoop,
-      bool HasDistSchedule = false, Value *DistScheduleChunkSize = nullptr);
+      bool HasDistSchedule = false);
 
   /// Tile a loop nest.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1860ade264740..3b8616e46ead5 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -172,7 +172,7 @@ static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
 /// arguments.
 static OMPScheduleType
 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
-                          bool HasSimdModifier, bool HasDistScheduleChunks) {
+                          bool HasSimdModifier) {
   // Currently, the default schedule it static.
   switch (ClauseKind) {
   case OMP_SCHEDULE_Default:
@@ -190,7 +190,7 @@ getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                            : OMPScheduleType::BaseRuntime;
   case OMP_SCHEDULE_Distribute:
-    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
+    return HasChunks ? OMPScheduleType::BaseDistributeChunked
                                  : OMPScheduleType::BaseDistribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
@@ -260,10 +260,9 @@ getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
 static OMPScheduleType
 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                           bool HasSimdModifier, bool HasMonotonicModifier,
-                          bool HasNonmonotonicModifier, bool HasOrderedClause,
-                          bool HasDistScheduleChunks) {
+                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
   OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
-      ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
+      ClauseKind, HasChunks, HasSimdModifier);
   OMPScheduleType OrderedSchedule =
       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
@@ -4643,8 +4642,7 @@ static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
-    OMPScheduleType DistScheduleSchedType) {
+    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -4714,12 +4712,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
     Builder.CreateCall(StaticInit, Args);
   };
   BuildInitCall(SchedulingType, Builder);
-  if (HasDistSchedule &&
-      LoopType != WorksharingLoopType::DistributeStaticLoop) {
-    Constant *DistScheduleSchedType = ConstantInt::get(
-        I32Type, static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
-    BuildInitCall(DistScheduleSchedType, Builder);
-  }
   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
@@ -4792,10 +4784,9 @@ static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
 OpenMPIRBuilder::InsertPointOrErrorTy
 OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
-    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
+    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
-  assert(ChunkSize || DistScheduleChunkSize && "Chunk size is required");
+  assert(ChunkSize && "Chunk size is required");
 
   LLVMContext &Ctx = CLI->getFunction()->getContext();
   Value *IV = CLI->getIndVar();
@@ -4817,7 +4808,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
   LoopInfo &&LI = LIA.run(*F, FAM);
   Loop *L = LI.getLoopFor(CLI->getHeader());
   SmallVector<Metadata *> LoopMDList;
-  if (ChunkSize || DistScheduleChunkSize)
+  if (ChunkSize)
     applyParallelAccessesMetadata(CLI, Ctx, L, LI, LoopMDList);
   addLoopMetadata(CLI, LoopMDList);
 
@@ -4839,22 +4830,17 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
   CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
-  Builder.restoreIP(CLI->getPreheaderIP());
+  Builder.restoreIP(CLI->getPreheaderIP()); // -> sets insert point to omploop! Why?
   Builder.SetCurrentDebugLocation(DL);
 
   // TODO: Detect overflow in ubsan or max-out with current tripcount.
   Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
       ChunkSize ? ChunkSize : Zero, InternalIVTy, "chunksize");
-  Value *CastestDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
-      DistScheduleChunkSize ? DistScheduleChunkSize : Zero, InternalIVTy,
-      "distschedulechunksize");
   Value *CastedTripCount =
       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
 
   Constant *SchedulingType =
       ConstantInt::get(I32Type, static_cast<int>(SchedType));
-  Constant *DistSchedulingType =
-      ConstantInt::get(I32Type, static_cast<int>(DistScheduleSchedType));
   Builder.CreateStore(Zero, PLowerBound);
   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
   Builder.CreateStore(OrigUpperBound, PUpperBound);
@@ -4877,14 +4863,6 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
                          /*chunk=*/ChunkSize});
       };
   BuildInitCall(SchedulingType, CastedChunkSize, Builder);
-  if (DistScheduleSchedType != OMPScheduleType::None &&
-      SchedType != OMPScheduleType::OrderedDistributeChunked &&
-      SchedType != OMPScheduleType::OrderedDistribute) {
-    // We want to emit a second init function call for the dist_schedule clause
-    // to the Distribute construct. This should only be done however if a
-    // Workshare Loop is nested within a Distribute Construct
-    BuildInitCall(DistSchedulingType, CastestDistScheduleChunkSize, Builder);
-  }
 
   // Load values written by the "init" function.
   Value *FirstChunkStart =
@@ -5208,35 +5186,27 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
     bool HasSimdModifier, bool HasMonotonicModifier,
     bool HasNonmonotonicModifier, bool HasOrderedClause,
-    WorksharingLoopType LoopType, bool HasDistSchedule,
-    Value *DistScheduleChunkSize) {
+    WorksharingLoopType LoopType, bool HasDistSchedule) {
   if (Config.isTargetDevice())
     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
-      HasNonmonotonicModifier, HasOrderedClause, DistScheduleChunkSize);
+      HasNonmonotonicModifier, HasOrderedClause);
 
   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                    OMPScheduleType::ModifierOrdered;
-  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
-  if (HasDistSchedule) {
-    DistScheduleSchedType = DistScheduleChunkSize
-                                ? OMPScheduleType::OrderedDistributeChunked
-                                : OMPScheduleType::OrderedDistribute;
-  }
   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
   case OMPScheduleType::BaseStatic:
   case OMPScheduleType::BaseDistribute:
-    assert(!ChunkSize || !DistScheduleChunkSize &&
+    assert(!ChunkSize &&
                              "No chunk size with static-chunked schedule");
     if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    if (DistScheduleChunkSize)
+    if (ChunkSize)
       return applyStaticChunkedWorkshareLoop(
-          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
-          DistScheduleChunkSize, DistScheduleSchedType);
+          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType);
     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
                                     HasDistSchedule);
 
@@ -5247,8 +5217,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
     return applyStaticChunkedWorkshareLoop(
-        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
-        DistScheduleChunkSize, DistScheduleSchedType);
+        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType);
 
   case OMPScheduleType::BaseRuntime:
   case OMPScheduleType::BaseAuto:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index b11af583f4c16..cfb0ab4a3ac0e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2464,6 +2464,7 @@ convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
 static LogicalResult
 convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                  LLVM::ModuleTranslation &moduleTranslation) {
+  printf("CONVERTING WSLOOP\n");
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto wsloopOp = cast<omp::WsloopOp>(opInst);
   if (failed(checkImplementationStatus(opInst)))
@@ -2487,19 +2488,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
-  omp::DistributeOp distributeOp = nullptr;
-  llvm::Value *distScheduleChunk = nullptr;
-  bool hasDistSchedule = false;
-  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
-    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
-    hasDistSchedule = distributeOp.getDistScheduleStatic();
-    if (distributeOp.getDistScheduleChunkSize()) {
-      llvm::Value *chunkVar = moduleTranslation.lookupValue(
-          distributeOp.getDistScheduleChunkSize());
-      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
-    }
-  }
-
   PrivateVarsInfo privateVarsInfo(wsloopOp);
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2600,13 +2588,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   }
 
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
+      printf("loopInfo Address: %p\n", loopInfo);
+      printf("Applying omp.wloop Workshare Loop\n");
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
           convertToScheduleKind(schedule), chunk, isSimd,
           scheduleMod == omp::ScheduleModifier::monotonic,
           scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType, hasDistSchedule, distScheduleChunk);
+          workshareLoopType);
 
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
@@ -3052,6 +3042,7 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
 
   // Update the stack frame created for this loop to point to the resulting loop
   // after applying transformations.
+  printf("Applying loopInfo\n");
   moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
       [&](OpenMPLoopInfoStackFrame &frame) {
         frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
@@ -4767,6 +4758,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
 static LogicalResult
 convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
+  printf("CONVERTING DISTRIBUTE\n");
   llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
   auto distributeOp = cast<omp::DistributeOp>(opInst);
   if (failed(checkImplementationStatus(opInst)))
@@ -4835,7 +4827,7 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
     llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
     llvm::Expected<llvm::BasicBlock *> regionBlock =
         convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
-                            builder, moduleTranslation);
+                            builder, moduleTranslation); // -> this is causing Schedule to be emitted first.
     if (!regionBlock)
       return regionBlock.takeError();
     builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -4843,13 +4835,12 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
     // Skip applying a workshare loop below when translating 'distribute
     // parallel do' (it's been already handled by this point while translating
     // the nested omp.wsloop).
-    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
+    if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper()) || distributeOp.getDistScheduleStatic()) {
       // TODO: Add support for clauses which are valid for DISTRIBUTE
       // constructs. Static schedule is the default.
       bool hasDistSchedule = distributeOp.getDistScheduleStatic();
       auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
                                       : omp::ClauseScheduleKind::Static;
-      // dist_schedule clauses are ordered - otherise this should be false
       bool isOrdered = hasDistSchedule;
       std::optional<omp::ScheduleModifier> scheduleMod;
       bool isSimd = false;
@@ -4859,14 +4850,17 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
       llvm::Value *chunk = moduleTranslation.lookupValue(
           distributeOp.getDistScheduleChunkSize());
       llvm::CanonicalLoopInfo *loopInfo =
-          findCurrentLoopInfo(moduleTranslation);
+          findCurrentLoopInfo(moduleTranslation); // Do we need a new loop info here?
+      printf("loopInfo Address: %p\n", loopInfo);
+      printf("InsertPoint Name : %s\n", builder.GetInsertBlock()->getName().str().c_str());
+      printf("Applying omp.ditribute Workshare Loop\n");
       llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
           ompBuilder->applyWorkshareLoop(
               ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
               convertToScheduleKind(schedule), chunk, isSimd,
               scheduleMod == omp::ScheduleModifier::monotonic,
               scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType, hasDistSchedule, chunk);
+              workshareLoopType, hasDistSchedule);
 
       if (!wsloopIP)
         return wsloopIP.takeError();
@@ -5907,7 +5901,7 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
       !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());
 
   if (isOutermostLoopWrapper)
-    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();
+    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>(); // -> Need another one of these when Distribute AND WSLoop is present?
 
   auto result =
       llvm::TypeSwitch<Operation *, LogicalResult>(op)



More information about the Mlir-commits mailing list