[llvm] [mlir] [MLIR][OpenMP] Add MLIR Lowering Support for dist_schedule (PR #152736)

Jack Styles via llvm-commits llvm-commits at lists.llvm.org
Fri Aug 8 07:57:12 PDT 2025


https://github.com/Stylie777 created https://github.com/llvm/llvm-project/pull/152736

`dist_schedule` was previously supported in Flang/Clang but not implemented in MLIR; instead, users got a "not yet implemented" error. This patch adds support for lowering the `dist_schedule` clause to LLVM IR when it is used on an `omp.distribute` operation. It also supports `dist_schedule` when the loop nest is embedded within a workshare loop.

Some rework was required to ensure that MLIR/LLVM emits the correct schedule type for the clause, because the runtime library uses a different schedule type for `dist_schedule` than for other OpenMP directives and clauses.

Also add LLVM loop metadata, and update the implementation to support `dist_schedule` on a distribute construct that contains a workshare loop.

>From 464cd87b160522f301af40d45cad330356a8c464 Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Tue, 8 Jul 2025 08:57:18 +0100
Subject: [PATCH] [MLIR][OpenMP] Add support for `dist_schedule`

`dist_schedule` was previously supported in Flang/Clang but not
implemented in MLIR; instead, users got a "not yet implemented"
error. This patch adds support for lowering the `dist_schedule` clause
to LLVM IR when it is used on an `omp.distribute` operation. It also
supports `dist_schedule` when the loop nest is embedded within a
workshare loop.

Some rework was required to ensure that MLIR/LLVM emits the correct
schedule type for the clause, because the runtime library uses a
different schedule type for `dist_schedule` than for other OpenMP
directives and clauses.

Also add LLVM loop metadata, and update the implementation to support
`dist_schedule` on a distribute construct that contains a workshare loop.
---
 llvm/include/llvm/Frontend/OpenMP/OMP.td      |   4 +-
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  25 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 198 ++++++++++++------
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |   1 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  36 ++--
 .../OpenMPToLLVM/convert-to-llvmir.mlir       |  19 ++
 .../Target/LLVMIR/openmp-dist_schedule.mlir   |  30 +++
 .../openmp-dist_schedule_with_wsloop.mlir     |  99 +++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  13 --
 9 files changed, 332 insertions(+), 93 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
 create mode 100644 mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 79f25bb05f20e..4117e112367c6 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -458,7 +458,8 @@ def OMP_SCHEDULE_Dynamic : EnumVal<"dynamic", 3, 1> {}
 def OMP_SCHEDULE_Guided : EnumVal<"guided", 4, 1> {}
 def OMP_SCHEDULE_Auto : EnumVal<"auto", 5, 1> {}
 def OMP_SCHEDULE_Runtime : EnumVal<"runtime", 6, 1> {}
-def OMP_SCHEDULE_Default : EnumVal<"default", 7, 0> { let isDefault = 1; }
+def OMP_SCHEDULE_Distribute : EnumVal<"distribute", 7, 1> {}
+def OMP_SCHEDULE_Default : EnumVal<"default", 8, 0> { let isDefault = 1; }
 def OMPC_Schedule : Clause<[Spelling<"schedule">]> {
   let clangClass = "OMPScheduleClause";
   let flangClass = "OmpScheduleClause";
@@ -469,6 +470,7 @@ def OMPC_Schedule : Clause<[Spelling<"schedule">]> {
     OMP_SCHEDULE_Guided,
     OMP_SCHEDULE_Auto,
     OMP_SCHEDULE_Runtime,
+    OMP_SCHEDULE_Distribute,
     OMP_SCHEDULE_Default
   ];
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index f70659120e1e6..395df392babde 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1096,11 +1096,13 @@ class OpenMPIRBuilder {
   /// \param NeedsBarrier Indicates whether a barrier must be inserted after
   ///                     the loop.
   /// \param LoopType Type of workshare loop.
+  /// \param HasDistSchedule True if a dist_schedule clause is being lowered.
+  /// \param DistScheduleSchedType Distribute schedule type, or None if absent.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyStaticWorkshareLoop(
       DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-      omp::WorksharingLoopType LoopType, bool NeedsBarrier);
+      omp::WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule = false, omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
 
   /// Modifies the canonical loop a statically-scheduled workshare loop with a
   /// user-specified chunk size.
@@ -1113,13 +1115,20 @@ class OpenMPIRBuilder {
   /// \param NeedsBarrier Indicates whether a barrier must be inserted after the
   ///                     loop.
   /// \param ChunkSize    The user-specified chunk size.
+  /// \param SchedType    Schedule type passed to the first init call.
+  /// \param DistScheduleChunkSize The dist_schedule chunk size. If \p nullptr,
+  ///                     a chunk size of 0 is used for the dist_schedule call.
+  /// \param DistScheduleSchedType Distribute schedule type, or None if absent.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyStaticChunkedWorkshareLoop(DebugLoc DL,
                                                        CanonicalLoopInfo *CLI,
                                                        InsertPointTy AllocaIP,
                                                        bool NeedsBarrier,
-                                                       Value *ChunkSize);
+                                                       Value *ChunkSize,
+                                                       omp::OMPScheduleType SchedType = omp::OMPScheduleType::UnorderedStaticChunked,
+                                                       Value *DistScheduleChunkSize = nullptr,
+                                                       omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
 
   /// Modifies the canonical loop to be a dynamically-scheduled workshare loop.
   ///
@@ -1139,6 +1148,8 @@ class OpenMPIRBuilder {
   ///                     the loop.
   /// \param Chunk    The size of loop chunk considered as a unit when
   ///                 scheduling. If \p nullptr, defaults to 1.
+  /// \param DistScheduleChunk    The dist_schedule chunk size, if any.
+  ///                 NOTE(review): currently unused by the implementation.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyDynamicWorkshareLoop(DebugLoc DL,
@@ -1146,7 +1157,8 @@ class OpenMPIRBuilder {
                                                  InsertPointTy AllocaIP,
                                                  omp::OMPScheduleType SchedType,
                                                  bool NeedsBarrier,
-                                                 Value *Chunk = nullptr);
+                                                 Value *Chunk = nullptr,
+                                                 Value *DistScheduleChunk = nullptr);
 
   /// Create alternative version of the loop to support if clause
   ///
@@ -1197,6 +1209,9 @@ class OpenMPIRBuilder {
   ///                         present.
   /// \param LoopType Information about type of loop worksharing.
   ///                 It corresponds to type of loop workshare OpenMP pragma.
+  /// \param HasDistSchedule Whether a dist_schedule clause applies to the loop.
+  /// \param DistScheduleChunkSize The dist_schedule chunk size, or nullptr if
+  ///                              the clause has no chunk size.
   ///
   /// \returns Point where to insert code after the workshare construct.
   LLVM_ABI InsertPointOrErrorTy applyWorkshareLoop(
@@ -1207,7 +1222,9 @@ class OpenMPIRBuilder {
       bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
       bool HasOrderedClause = false,
       omp::WorksharingLoopType LoopType =
-          omp::WorksharingLoopType::ForStaticLoop);
+          omp::WorksharingLoopType::ForStaticLoop,
+      bool HasDistSchedule = false,
+      Value *DistScheduleChunkSize = nullptr);
 
   /// Tile a loop nest.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ea027e48fa2f1..18da0d772912f 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -136,6 +136,8 @@ static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
   case OMPScheduleType::NomergeOrderedRuntime:
   case OMPScheduleType::NomergeOrderedAuto:
   case OMPScheduleType::NomergeOrderedTrapezoidal:
+  case OMPScheduleType::OrderedDistributeChunked:
+  case OMPScheduleType::OrderedDistribute:
     break;
   default:
     return false;
@@ -170,7 +172,7 @@ static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
 /// arguments.
 static OMPScheduleType
 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
-                          bool HasSimdModifier) {
+                          bool HasSimdModifier, bool HasDistScheduleChunks) {
   // Currently, the default schedule it static.
   switch (ClauseKind) {
   case OMP_SCHEDULE_Default:
@@ -187,6 +189,9 @@ getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
   case OMP_SCHEDULE_Runtime:
     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                            : OMPScheduleType::BaseRuntime;
+  case OMP_SCHEDULE_Distribute:
+    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
+                                 : OMPScheduleType::BaseDistribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
 }
@@ -255,9 +260,10 @@ getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
 static OMPScheduleType
 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                           bool HasSimdModifier, bool HasMonotonicModifier,
-                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
-  OMPScheduleType BaseSchedule =
-      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
+                          bool HasNonmonotonicModifier, bool HasOrderedClause,
+                          bool HasDistScheduleChunks) {
+  OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
+      ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
   OMPScheduleType OrderedSchedule =
       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
@@ -4637,7 +4643,8 @@ static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    WorksharingLoopType LoopType, bool NeedsBarrier) {
+    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
+    OMPScheduleType DistScheduleSchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -4693,15 +4700,26 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
 
   // Call the "init" function and update the trip count of the loop with the
   // value it produced.
-  SmallVector<Value *, 10> Args(
-      {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound, PUpperBound});
-  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
-    Value *PDistUpperBound =
-        Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
-    Args.push_back(PDistUpperBound);
+  auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
+                        PUpperBound, IVTy, PStride, One, Zero,
+                        StaticInit](Value *SchedulingType, auto &Builder) {
+    SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
+                                   PLowerBound, PUpperBound});
+    if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
+      Value *PDistUpperBound =
+          Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
+      Args.push_back(PDistUpperBound);
+    }
+    Args.append({PStride, One, Zero});
+    Builder.CreateCall(StaticInit, Args);
+  };
+  BuildInitCall(SchedulingType, Builder);
+  if (HasDistSchedule &&
+      LoopType != WorksharingLoopType::DistributeStaticLoop) {
+    Constant *DistSchedulingType = ConstantInt::get(
+        I32Type, static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
+    BuildInitCall(DistSchedulingType, Builder);
   }
-  Args.append({PStride, One, Zero});
-  Builder.CreateCall(StaticInit, Args);
   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
@@ -4740,14 +4758,42 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   return AfterIP;
 }
 
+static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                                   LoopInfo &LI);
+static void addLoopMetadata(CanonicalLoopInfo *Loop,
+                            ArrayRef<Metadata *> Properties);
+
+/// Tag memory-accessing instructions in the body of \p CLI with a new distinct
+/// access group and append a matching "llvm.loop.parallel_accesses" property
+/// to \p LoopMDList for the caller to attach to the loop.
+static void
+applyParallelAccessesMetadata(CanonicalLoopInfo *CLI, LLVMContext &Ctx,
+                              Loop *L, LoopInfo &LI,
+                              SmallVectorImpl<Metadata *> &LoopMDList) {
+  MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
+  // Visit the loop blocks in which memref instructions can be found; the
+  // canonical loop's condition and header blocks are skipped.
+  // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
+  // preferably without running any passes.
+  for (BasicBlock *Block : L->getBlocks()) {
+    if (Block == CLI->getCond() || Block == CLI->getHeader())
+      continue;
+    addAccessGroupMetadata(Block, AccessGroup, LI);
+  }
+
+  // TODO: If the loop has existing parallel access metadata, have to combine
+  // the two lists.
+  LoopMDList.push_back(MDNode::get(
+      Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
+}
+
 OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
-                                                 CanonicalLoopInfo *CLI,
-                                                 InsertPointTy AllocaIP,
-                                                 bool NeedsBarrier,
-                                                 Value *ChunkSize) {
+OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
+    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
+    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
-  assert(ChunkSize && "Chunk size is required");
+  assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");
 
   LLVMContext &Ctx = CLI->getFunction()->getContext();
   Value *IV = CLI->getIndVar();
@@ -4761,6 +4807,18 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
   Constant *One = ConstantInt::get(InternalIVTy, 1);
 
+  Function *F = CLI->getFunction();
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
+  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
+  LoopAnalysis LIA;
+  LoopInfo &&LI = LIA.run(*F, FAM);
+  Loop *L = LI.getLoopFor(CLI->getHeader());
+  SmallVector<Metadata *> LoopMDList;
+  if (ChunkSize || DistScheduleChunkSize)
+    applyParallelAccessesMetadata(CLI, Ctx, L, LI, LoopMDList);
+  addLoopMetadata(CLI, LoopMDList);
+
   // Declare useful OpenMP runtime functions.
   FunctionCallee StaticInit =
       getKmpcForStaticInitForType(InternalIVTy, M, *this);
@@ -4783,13 +4841,18 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Builder.SetCurrentDebugLocation(DL);
 
   // TODO: Detect overflow in ubsan or max-out with current tripcount.
-  Value *CastedChunkSize =
-      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
+  Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
+      ChunkSize ? ChunkSize : Zero, InternalIVTy, "chunksize");
+  Value *CastestDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
+      DistScheduleChunkSize ? DistScheduleChunkSize : Zero, InternalIVTy,
+      "distschedulechunksize");
   Value *CastedTripCount =
       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
 
-  Constant *SchedulingType = ConstantInt::get(
-      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
+  Constant *SchedulingType =
+      ConstantInt::get(I32Type, static_cast<int>(SchedType));
+  Constant *DistSchedulingType =
+      ConstantInt::get(I32Type, static_cast<int>(DistScheduleSchedType));
   Builder.CreateStore(Zero, PLowerBound);
   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
   Builder.CreateStore(OrigUpperBound, PUpperBound);
@@ -4801,12 +4864,25 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
-  Builder.CreateCall(StaticInit,
-                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
-                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
-                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
-                      /*pstride=*/PStride, /*incr=*/One,
-                      /*chunk=*/CastedChunkSize});
+  auto BuildInitCall =
+      [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound,
+       PStride, One](Value *SchedulingType, Value *ChunkSize, auto &Builder) {
+        Builder.CreateCall(
+            StaticInit, {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
+                         /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
+                         /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
+                         /*pstride=*/PStride, /*incr=*/One,
+                         /*chunk=*/ChunkSize});
+      };
+  BuildInitCall(SchedulingType, CastedChunkSize, Builder);
+  if (DistScheduleSchedType != OMPScheduleType::None &&
+      SchedType != OMPScheduleType::OrderedDistributeChunked &&
+      SchedType != OMPScheduleType::OrderedDistribute) {
+    // Emit a second init call for the dist_schedule clause of the enclosing
+    // distribute construct. This is only needed when a workshare loop is
+    // nested within a distribute construct.
+    BuildInitCall(DistSchedulingType, CastestDistScheduleChunkSize, Builder);
+  }
 
   // Load values written by the "init" function.
   Value *FirstChunkStart =
@@ -5130,31 +5206,47 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
     bool HasSimdModifier, bool HasMonotonicModifier,
     bool HasNonmonotonicModifier, bool HasOrderedClause,
-    WorksharingLoopType LoopType) {
+    WorksharingLoopType LoopType, bool HasDistSchedule,
+    Value *DistScheduleChunkSize) {
   if (Config.isTargetDevice())
     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
-      HasNonmonotonicModifier, HasOrderedClause);
+      HasNonmonotonicModifier, HasOrderedClause, DistScheduleChunkSize);
 
   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                    OMPScheduleType::ModifierOrdered;
+  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
+  if (HasDistSchedule) {
+    DistScheduleSchedType = DistScheduleChunkSize
+                                ? OMPScheduleType::OrderedDistributeChunked
+                                : OMPScheduleType::OrderedDistribute;
+  }
   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
   case OMPScheduleType::BaseStatic:
-    assert(!ChunkSize && "No chunk size with static-chunked schedule");
-    if (IsOrdered)
+  case OMPScheduleType::BaseDistribute:
+    // Unchunked static and dist_schedule loops carry no schedule chunk size.
+    assert(!ChunkSize && "No chunk size with static-chunked schedule");
+    if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier);
+    if (DistScheduleChunkSize)
+      return applyStaticChunkedWorkshareLoop(
+          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
+          DistScheduleChunkSize, DistScheduleSchedType);
+    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
+                                    HasDistSchedule, DistScheduleSchedType);
 
   case OMPScheduleType::BaseStaticChunked:
-    if (IsOrdered)
+  case OMPScheduleType::BaseDistributeChunked:
+    if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
-                                           ChunkSize);
+    return applyStaticChunkedWorkshareLoop(
+        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
+        DistScheduleChunkSize, DistScheduleSchedType);
 
   case OMPScheduleType::BaseRuntime:
   case OMPScheduleType::BaseAuto:
@@ -5230,7 +5322,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy
 OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                            InsertPointTy AllocaIP,
                                            OMPScheduleType SchedType,
-                                           bool NeedsBarrier, Value *Chunk) {
+                                           bool NeedsBarrier, Value *Chunk,
+                                           Value *DistScheduleChunk) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -5747,8 +5840,8 @@ static void addLoopMetadata(CanonicalLoopInfo *Loop,
 }
 
 /// Attach llvm.access.group metadata to the memref instructions of \p Block
-static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
-                            LoopInfo &LI) {
+static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                                   LoopInfo &LI) {
   for (Instruction &I : *Block) {
     if (I.mayReadOrWriteMemory()) {
       // TODO: This instruction may already have access group from
@@ -5918,19 +6011,6 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
     createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, "simd");
   }
 
-  SmallSet<BasicBlock *, 8> Reachable;
-
-  // Get the basic blocks from the loop in which memref instructions
-  // can be found.
-  // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
-  // preferably without running any passes.
-  for (BasicBlock *Block : L->getBlocks()) {
-    if (Block == CanonicalLoop->getCond() ||
-        Block == CanonicalLoop->getHeader())
-      continue;
-    Reachable.insert(Block);
-  }
-
   SmallVector<Metadata *> LoopMDList;
 
   // In presence of finite 'safelen', it may be unsafe to mark all
@@ -5938,16 +6018,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   // dependences of 'safelen' iterations are possible.
   // If clause order(concurrent) is specified then the memory instructions
   // are marked parallel even if 'safelen' is finite.
-  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
-    // Add access group metadata to memory-access instructions.
-    MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
-    for (BasicBlock *BB : Reachable)
-      addSimdMetadata(BB, AccessGroup, LI);
-    // TODO:  If the loop has existing parallel access metadata, have
-    // to combine two lists.
-    LoopMDList.push_back(MDNode::get(
-        Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
-  }
+  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
+    applyParallelAccessesMetadata(CanonicalLoop, Ctx, L, LI, LoopMDList);
 
   // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
   // versions so we can't add the loop attributes in that case.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767ef90b0..9e2031401403c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -386,6 +386,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
     break;
   case ClauseScheduleKind::Auto:
   case ClauseScheduleKind::Runtime:
+  case ClauseScheduleKind::Distribute:
     chunkSize = std::nullopt;
   }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2cdd502ad0275..e5f3ddd301006 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -60,6 +60,8 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
     return llvm::omp::OMP_SCHEDULE_Auto;
   case omp::ClauseScheduleKind::Runtime:
     return llvm::omp::OMP_SCHEDULE_Runtime;
+  case omp::ClauseScheduleKind::Distribute:
+    return llvm::omp::OMP_SCHEDULE_Distribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
 }
@@ -318,10 +320,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDevice())
       result = todo("device");
   };
-  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
-    if (op.getDistScheduleChunkSize())
-      result = todo("dist_schedule with chunk_size");
-  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -392,7 +390,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::DistributeOp op) {
         checkAllocate(op, result);
-        checkDistSchedule(op, result);
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
@@ -2490,6 +2487,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
+  // If this wsloop is nested directly inside an omp.distribute, forward that
+  // construct's dist_schedule clause (and chunk size, if any) to the builder.
+  llvm::Value *distScheduleChunk = nullptr;
+  bool hasDistSchedule = false;
+  if (auto distributeOp =
+          llvm::dyn_cast_if_present<omp::DistributeOp>(opInst.getParentOp())) {
+    hasDistSchedule = distributeOp.getDistScheduleStatic();
+    if (mlir::Value chunkSize = distributeOp.getDistScheduleChunkSize()) {
+      llvm::Value *chunkVar = moduleTranslation.lookupValue(chunkSize);
+      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
+    }
+  }
+
   PrivateVarsInfo privateVarsInfo(wsloopOp);
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2596,7 +2606,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
           convertToScheduleKind(schedule), chunk, isSimd,
           scheduleMod == omp::ScheduleModifier::monotonic,
           scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType);
+          workshareLoopType, hasDistSchedule, distScheduleChunk);
 
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
@@ -4836,15 +4846,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
     if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
       // TODO: Add support for clauses which are valid for DISTRIBUTE
       // constructs. Static schedule is the default.
-      auto schedule = omp::ClauseScheduleKind::Static;
-      bool isOrdered = false;
+      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
+      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
+                                      : omp::ClauseScheduleKind::Static;
+      // dist_schedule is lowered as an ordered schedule; plain static is not.
+      bool isOrdered = hasDistSchedule;
       std::optional<omp::ScheduleModifier> scheduleMod;
       bool isSimd = false;
       llvm::omp::WorksharingLoopType workshareLoopType =
           llvm::omp::WorksharingLoopType::DistributeStaticLoop;
       bool loopNeedsBarrier = false;
-      llvm::Value *chunk = nullptr;
-
+      llvm::Value *chunk = moduleTranslation.lookupValue(
+          distributeOp.getDistScheduleChunkSize());
       llvm::CanonicalLoopInfo *loopInfo =
           findCurrentLoopInfo(moduleTranslation);
       llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
@@ -4853,12 +4866,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
               convertToScheduleKind(schedule), chunk, isSimd,
               scheduleMod == omp::ScheduleModifier::monotonic,
               scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType);
+              workshareLoopType, hasDistSchedule, chunk);
 
       if (!wsloopIP)
         return wsloopIP.takeError();
     }
-
     if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                   distributeOp.getLoc(), privVarsInfo.llvmVars,
                                   privVarsInfo.privatizers)))
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index d69de998346b5..e180cdc2cb075 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -614,3 +614,22 @@ omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>
   // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr)
   omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr)
 }
+
+// CHECK-LABEL: llvm.func @omp_dist_schedule(%arg0: i32) {
+func.func @omp_dist_schedule(%arg0: i32) {
+  %c1_i32 = arith.constant 1 : i32
+  // CHECK: %[[CHUNK:[0-9]+]] = llvm.mlir.constant(1024 : i32) : i32
+  %c1024_i32 = arith.constant 1024 : i32
+  %c16_i32 = arith.constant 16 : i32
+  %c8_i32 = arith.constant 8 : i32
+  omp.teams num_teams( to %c8_i32 : i32) thread_limit(%c16_i32 : i32) {
+    // CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size(%[[CHUNK]] : i32) {
+    omp.distribute dist_schedule_static dist_schedule_chunk_size(%c1024_i32 : i32) {
+      omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%arg0) inclusive step (%c1_i32) {
+        omp.terminator
+      }
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
new file mode 100644
index 0000000000000..291c0d3e51d6c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
@@ -0,0 +1,30 @@
+// Test that dist_schedule is translated to the correct runtime schedule type, passing the chunk size when one is given.
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @distribute_dist_schedule_chunk_size(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+  // CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+  %1 = llvm.mlir.constant(1024: i32) : i32  // chunk size; expected as the last argument of the runtime init call above
+  omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// With a chunk size, llvm.loop.parallel_accesses metadata referencing a distinct access-group node must also be emitted.
+// CHECK: !{{[0-9]+}} = !{!"llvm.loop.parallel_accesses", ![[ACCESS_GROUP:[0-9]+]]}
+// CHECK-NEXT: ![[ACCESS_GROUP]] = distinct !{}
+
+// -----
+
+llvm.func @distribute_dist_schedule(%lower : i32, %upper : i32, %incr : i32, %unused : i32) {
+  // CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+  omp.distribute dist_schedule_static {  // no chunk size given: expect schedule type 92 and chunk 0 in the init call
+    omp.loop_nest (%i) : i32 = (%lower) to (%upper) step (%incr) {
+      omp.yield
+    }
+  }
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
new file mode 100644
index 0000000000000..b25675c78a23c
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
@@ -0,0 +1,99 @@
+// Test that dist_schedule is translated to the correct runtime schedule type and chunk size when omp.distribute is composed with omp.wsloop.
+
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %dcs = llvm.mlir.constant(1024 : i32) : i32
+  %scs = llvm.mlir.constant(64 : i32) : i32
+  // Both chunked: expect schedule type 33 with the wsloop chunk (64) and 91 with the dist_schedule chunk (1024).
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+        omp.wsloop schedule(static = %scs : i32) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK-LABEL: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+
+llvm.func @distribute_wsloop_dist_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %dcs = llvm.mlir.constant(1024 : i32) : i32
+  // Only dist_schedule chunked: expect schedule type 34 with chunk 0 and 91 with the dist_schedule chunk (1024).
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK-LABEL: define internal void @distribute_wsloop_dist_schedule_chunked..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+
+llvm.func @distribute_wsloop_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %scs = llvm.mlir.constant(64 : i32) : i32
+  // Only the wsloop schedule chunked: expect schedule type 33 with the wsloop chunk (64) and 92 with chunk 0.
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static = %scs : i32) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK-LABEL: define internal void @distribute_wsloop_schedule_chunked..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+
+llvm.func @distribute_wsloop_no_chunks(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  // Neither chunked: expect two __kmpc_dist_for_static_init_4u calls, with schedule types 34 and 92 and chunk 0.
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK-LABEL: define internal void @distribute_wsloop_no_chunks..omp_par(ptr %0) {
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i32 1, i32 0)
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 2fa4470bb8300..b3b1e853014f9 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -39,19 +39,6 @@ llvm.func @distribute_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr
 
 // -----
 
-llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
-  // expected-error at below {{not yet implemented: Unhandled clause dist_schedule with chunk_size in omp.distribute operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.distribute}}
-  omp.distribute dist_schedule_static dist_schedule_chunk_size(%x : i32) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @distribute_order(%lb : i32, %ub : i32, %step : i32) {
   // expected-error at below {{not yet implemented: Unhandled clause order in omp.distribute operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.distribute}}



More information about the llvm-commits mailing list