[Mlir-commits] [flang] [llvm] [mlir] [MLIR][OpenMP] Add MLIR Lowering Support for dist_schedule (PR #152736)

Jack Styles llvmlistbot at llvm.org
Wed Oct 29 09:44:46 PDT 2025


https://github.com/Stylie777 updated https://github.com/llvm/llvm-project/pull/152736

From 96fdc045fa5cac25dbdd5680a7c132ed2bbfffcf Mon Sep 17 00:00:00 2001
From: Jack Styles <jack.styles at arm.com>
Date: Tue, 8 Jul 2025 08:57:18 +0100
Subject: [PATCH] [MLIR][OpenMP] Add support for `dist_schedule`

`dist_schedule` was previously supported in Flang/Clang but was not
implemented in MLIR; a user would instead get a "not yet implemented"
error. This patch adds support for lowering the `dist_schedule` clause
to LLVM IR when it is used in an `omp.distribute` or `omp.wsloop`
section.

Some rework was required to ensure that MLIR/LLVM emits the correct
schedule type for the clause, as the runtime library uses a different
schedule type for it than for other OpenMP directives/clauses.

This patch also ensures that, when `dist_schedule` or a chunked
schedule clause is used, the correct `llvm.loop.parallel_accesses`
metadata is added.
---
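Note (not part of the commit message): as a rough illustration of what the
clause means at the source level, `dist_schedule(static, chunk_size)` divides
the iteration space into chunks of `chunk_size` iterations and assigns those
chunks to the teams of the league round-robin, in team-number order. The
sketch below models that mapping in plain C++. `assignIterationsToTeams` is a
hypothetical helper written for this note, not an LLVM or OpenMP runtime API,
and it treats a zero chunk size or team count as "no assignment".

```cpp
#include <cstddef>
#include <vector>

// Hypothetical model of dist_schedule(static, chunkSize): for each iteration
// in [0, tripCount), return the number of the team that would execute it.
// Chunks of chunkSize consecutive iterations are handed out round-robin.
std::vector<unsigned> assignIterationsToTeams(std::size_t tripCount,
                                              std::size_t chunkSize,
                                              unsigned numTeams) {
  std::vector<unsigned> owner;
  // Assumption of this sketch: invalid inputs produce an empty result.
  if (chunkSize == 0 || numTeams == 0)
    return owner;
  owner.reserve(tripCount);
  for (std::size_t i = 0; i < tripCount; ++i) {
    std::size_t chunkIndex = i / chunkSize;
    owner.push_back(static_cast<unsigned>(chunkIndex % numTeams));
  }
  return owner;
}
```

For example, `assignIterationsToTeams(10, 2, 3)` gives the owners
`0 0 1 1 2 2 0 0 1 1`: five chunks of two iterations spread over three teams.
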
 flang/docs/OpenMPSupport.md                   |  22 +-
 llvm/include/llvm/Frontend/OpenMP/OMP.td      |   4 +-
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  47 ++--
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 189 +++++++++++-----
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |   1 +
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  36 ++-
 .../OpenMPToLLVM/convert-to-llvmir.mlir       |  19 ++
 .../Target/LLVMIR/openmp-dist_schedule.mlir   |  34 +++
 .../openmp-dist_schedule_with_wsloop.mlir     | 205 ++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  13 --
 10 files changed, 470 insertions(+), 100 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
 create mode 100644 mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
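
Note: the schedule-type selection that this patch adds to `applyWorkshareLoop`
and `applyStaticChunkedWorkshareLoop` can be modelled roughly as follows. This
is a simplified, self-contained sketch rather than the OMPIRBuilder code.
`SchedType` is a stand-in enum whose enumerator names mirror the ones used in
the diff but whose values are not the runtime's, and both function names are
hypothetical.

```cpp
// Stand-in for the subset of llvm::omp::OMPScheduleType used below.
// The numeric values here are arbitrary and do NOT match the OpenMP runtime.
enum class SchedType {
  None,
  UnorderedStaticChunked,
  OrderedDistribute,
  OrderedDistributeChunked
};

// Mirrors how the patch derives the schedule type for the distribute loop:
// None without dist_schedule, otherwise chunked or unchunked depending on
// whether a dist_schedule chunk size was given.
SchedType pickDistScheduleSchedType(bool hasDistSchedule,
                                    bool hasDistScheduleChunk) {
  if (!hasDistSchedule)
    return SchedType::None;
  return hasDistScheduleChunk ? SchedType::OrderedDistributeChunked
                              : SchedType::OrderedDistribute;
}

// Mirrors the condition guarding the second static-init call in the chunked
// path: it is only emitted when a distribute schedule exists and the loop
// being lowered is not itself the distribute loop.
bool needsSecondInitCall(SchedType loopSchedType, SchedType distSchedType) {
  return distSchedType != SchedType::None &&
         loopSchedType != SchedType::OrderedDistributeChunked &&
         loopSchedType != SchedType::OrderedDistribute;
}
```

Under this model, a chunked `omp.wsloop` nested in an `omp.distribute` with
`dist_schedule` gets two init calls, while a standalone `omp.distribute` gets
one.
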

diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index 81f5f9f6dee5b..8eea39c6ba91b 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -42,10 +42,10 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
 | target update construct                                    | P      | device clause not supported |
 | declare target directive                                   | P      | |
 | teams construct                                            | Y      | |
-| distribute construct                                       | P      | dist_schedule clause not supported |
-| distribute simd construct                                  | P      | dist_schedule and linear clauses are not supported |
-| distribute parallel loop construct                         | P      | dist_schedule clause not supported |
-| distribute parallel loop simd construct                    | P      | dist_schedule and linear clauses are not supported |
+| distribute construct                                       | P      | |
+| distribute simd construct                                  | P      | linear clause is not supported |
+| distribute parallel loop construct                         | P      | |
+| distribute parallel loop simd construct                    | P      | linear clause is not supported |
 | depend clause                                              | Y      | |
 | declare reduction construct                                | N      | |
 | atomic construct extensions                                | Y      | |
@@ -53,13 +53,13 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
 | cancellation point construct                               | Y      | |
 | parallel do simd construct                                 | P      | linear clause not supported |
 | target teams construct                                     | P      | device clause not supported |
-| teams distribute construct                                 | P      | dist_schedule clause not supported |
-| teams distribute simd construct                            | P      | dist_schedule and linear clauses are not supported |
-| target teams distribute construct                          | P      | device and dist_schedule clauses are not supported |
-| teams distribute parallel loop construct                   | P      | dist_schedule clause not supported |
-| target teams distribute parallel loop construct            | P      | device and dist_schedule clauses are not supported |
-| teams distribute parallel loop simd construct              | P      | dist_schedule and linear clauses are not supported |
-| target teams distribute parallel loop simd construct       | P      | device, dist_schedule and linear clauses are not supported |
+| teams distribute construct                                 | P      | |
+| teams distribute simd construct                            | P      | linear clause is not supported |
+| target teams distribute construct                          | P      | device clause is not supported |
+| teams distribute parallel loop construct                   | P      | |
+| target teams distribute parallel loop construct            | P      | device clause is not supported |
+| teams distribute parallel loop simd construct              | P      | linear clause is not supported |
+| target teams distribute parallel loop simd construct       | P      | device and linear clauses are not supported |
 
 ## Extensions
 ### ATOMIC construct
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td
index 61a1a05f6e904..fdcae9916b95b 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -485,7 +485,8 @@ def OMP_SCHEDULE_Dynamic : EnumVal<"dynamic", 3, 1> {}
 def OMP_SCHEDULE_Guided : EnumVal<"guided", 4, 1> {}
 def OMP_SCHEDULE_Auto : EnumVal<"auto", 5, 1> {}
 def OMP_SCHEDULE_Runtime : EnumVal<"runtime", 6, 1> {}
-def OMP_SCHEDULE_Default : EnumVal<"default", 7, 0> { let isDefault = 1; }
+def OMP_SCHEDULE_Distribute : EnumVal<"distribute", 7, 1> {}
+def OMP_SCHEDULE_Default : EnumVal<"default", 8, 0> { let isDefault = 1; }
 def OMPC_Schedule : Clause<[Spelling<"schedule">]> {
   let clangClass = "OMPScheduleClause";
   let flangClass = "OmpScheduleClause";
@@ -496,6 +497,7 @@ def OMPC_Schedule : Clause<[Spelling<"schedule">]> {
     OMP_SCHEDULE_Guided,
     OMP_SCHEDULE_Auto,
     OMP_SCHEDULE_Runtime,
+    OMP_SCHEDULE_Distribute,
     OMP_SCHEDULE_Default
   ];
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 5331cb5abdc6f..fc6f59e5671e7 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1110,11 +1110,17 @@ class OpenMPIRBuilder {
   /// \param NeedsBarrier Indicates whether a barrier must be inserted after
   ///                     the loop.
   /// \param LoopType Type of workshare loop.
+  /// \param HasDistSchedule Whether the clause being lowered is
+  /// dist_schedule, which is handled slightly differently.
+  /// \param DistScheduleSchedType Defines the Schedule Type for the Distribute
+  /// loop. Defaults to None if no Distribute loop is present.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointOrErrorTy applyStaticWorkshareLoop(
       DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-      omp::WorksharingLoopType LoopType, bool NeedsBarrier);
+      omp::WorksharingLoopType LoopType, bool NeedsBarrier,
+      bool HasDistSchedule = false,
+      omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
 
   /// Modifies the canonical loop a statically-scheduled workshare loop with a
   /// user-specified chunk size.
@@ -1127,13 +1133,22 @@ class OpenMPIRBuilder {
   /// \param NeedsBarrier Indicates whether a barrier must be inserted after the
   ///                     loop.
   /// \param ChunkSize    The user-specified chunk size.
+  /// \param SchedType    Optional type of scheduling to be passed to the init
+  /// function.
+  /// \param DistScheduleChunkSize    The size of dist_schedule chunk considered
+  /// as a unit when
+  ///                 scheduling. If \p nullptr, defaults to 1.
+  /// \param DistScheduleSchedType Defines the Schedule Type for the Distribute
+  /// loop. Defaults to None if no Distribute loop is present.
   ///
   /// \returns Point where to insert code after the workshare construct.
-  InsertPointOrErrorTy applyStaticChunkedWorkshareLoop(DebugLoc DL,
-                                                       CanonicalLoopInfo *CLI,
-                                                       InsertPointTy AllocaIP,
-                                                       bool NeedsBarrier,
-                                                       Value *ChunkSize);
+  InsertPointOrErrorTy applyStaticChunkedWorkshareLoop(
+      DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+      bool NeedsBarrier, Value *ChunkSize,
+      omp::OMPScheduleType SchedType =
+          omp::OMPScheduleType::UnorderedStaticChunked,
+      Value *DistScheduleChunkSize = nullptr,
+      omp::OMPScheduleType DistScheduleSchedType = omp::OMPScheduleType::None);
 
   /// Modifies the canonical loop to be a dynamically-scheduled workshare loop.
   ///
@@ -1153,14 +1168,15 @@ class OpenMPIRBuilder {
   ///                     the loop.
   /// \param Chunk    The size of loop chunk considered as a unit when
   ///                 scheduling. If \p nullptr, defaults to 1.
+  /// \param DistScheduleChunk    The size of dist_schedule chunk considered as
+  /// a unit when
+  ///                 scheduling. If \p nullptr, defaults to 1.
   ///
   /// \returns Point where to insert code after the workshare construct.
-  InsertPointOrErrorTy applyDynamicWorkshareLoop(DebugLoc DL,
-                                                 CanonicalLoopInfo *CLI,
-                                                 InsertPointTy AllocaIP,
-                                                 omp::OMPScheduleType SchedType,
-                                                 bool NeedsBarrier,
-                                                 Value *Chunk = nullptr);
+  InsertPointOrErrorTy applyDynamicWorkshareLoop(
+      DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+      omp::OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk = nullptr,
+      Value *DistScheduleChunk = nullptr);
 
   /// Create alternative version of the loop to support if clause
   ///
@@ -1212,6 +1228,10 @@ class OpenMPIRBuilder {
   /// \param LoopType Information about type of loop worksharing.
   ///                 It corresponds to type of loop workshare OpenMP pragma.
   /// \param NoLoop If true, no-loop code is generated.
+  /// \param HasDistSchedule Whether the clause being lowered is
+  /// dist_schedule, which is handled slightly differently.
+  ///
+  /// \param DistScheduleChunkSize The chunk size for the dist_schedule loop.
   ///
   /// \returns Point where to insert code after the workshare construct.
   LLVM_ABI InsertPointOrErrorTy applyWorkshareLoop(
@@ -1223,7 +1243,8 @@ class OpenMPIRBuilder {
       bool HasOrderedClause = false,
       omp::WorksharingLoopType LoopType =
           omp::WorksharingLoopType::ForStaticLoop,
-      bool NoLoop = false);
+      bool NoLoop = false,
+      bool HasDistSchedule = false, Value *DistScheduleChunkSize = nullptr);
 
   /// Tile a loop nest.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 286ed039b1214..500b4e50978f4 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -14,6 +14,7 @@
 
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Analysis/AssumptionCache.h"
@@ -136,6 +137,8 @@ static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
   case OMPScheduleType::NomergeOrderedRuntime:
   case OMPScheduleType::NomergeOrderedAuto:
   case OMPScheduleType::NomergeOrderedTrapezoidal:
+  case OMPScheduleType::OrderedDistributeChunked:
+  case OMPScheduleType::OrderedDistribute:
     break;
   default:
     return false;
@@ -182,7 +185,7 @@ static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
 /// arguments.
 static OMPScheduleType
 getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
-                          bool HasSimdModifier) {
+                          bool HasSimdModifier, bool HasDistScheduleChunks) {
   // Currently, the default schedule it static.
   switch (ClauseKind) {
   case OMP_SCHEDULE_Default:
@@ -199,6 +202,9 @@ getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
   case OMP_SCHEDULE_Runtime:
     return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                            : OMPScheduleType::BaseRuntime;
+  case OMP_SCHEDULE_Distribute:
+    return HasDistScheduleChunks ? OMPScheduleType::BaseDistributeChunked
+                                 : OMPScheduleType::BaseDistribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
 }
@@ -267,9 +273,10 @@ getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
 static OMPScheduleType
 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                           bool HasSimdModifier, bool HasMonotonicModifier,
-                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
-  OMPScheduleType BaseSchedule =
-      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
+                          bool HasNonmonotonicModifier, bool HasOrderedClause,
+                          bool HasDistScheduleChunks) {
+  OMPScheduleType BaseSchedule = getOpenMPBaseScheduleType(
+      ClauseKind, HasChunks, HasSimdModifier, HasDistScheduleChunks);
   OMPScheduleType OrderedSchedule =
       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
@@ -4674,7 +4681,8 @@ static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    WorksharingLoopType LoopType, bool NeedsBarrier) {
+    WorksharingLoopType LoopType, bool NeedsBarrier, bool HasDistSchedule,
+    OMPScheduleType DistScheduleSchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -4730,15 +4738,26 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
 
   // Call the "init" function and update the trip count of the loop with the
   // value it produced.
-  SmallVector<Value *, 10> Args(
-      {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound, PUpperBound});
-  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
-    Value *PDistUpperBound =
-        Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
-    Args.push_back(PDistUpperBound);
+  auto BuildInitCall = [LoopType, SrcLoc, ThreadNum, PLastIter, PLowerBound,
+                        PUpperBound, IVTy, PStride, One, Zero,
+                        StaticInit](Value *SchedulingType, auto &Builder) {
+    SmallVector<Value *, 10> Args({SrcLoc, ThreadNum, SchedulingType, PLastIter,
+                                   PLowerBound, PUpperBound});
+    if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
+      Value *PDistUpperBound =
+          Builder.CreateAlloca(IVTy, nullptr, "p.distupperbound");
+      Args.push_back(PDistUpperBound);
+    }
+    Args.append({PStride, One, Zero});
+    Builder.CreateCall(StaticInit, Args);
+  };
+  BuildInitCall(SchedulingType, Builder);
+  if (HasDistSchedule &&
+      LoopType != WorksharingLoopType::DistributeStaticLoop) {
+    Constant *DistScheduleSchedType = ConstantInt::get(
+        I32Type, static_cast<int>(omp::OMPScheduleType::OrderedDistribute));
+    BuildInitCall(DistScheduleSchedType, Builder);
   }
-  Args.append({PStride, One, Zero});
-  Builder.CreateCall(StaticInit, Args);
   Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
   Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
   Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
@@ -4777,14 +4796,44 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   return AfterIP;
 }
 
+static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                                   LoopInfo &LI);
+static void addLoopMetadata(CanonicalLoopInfo *Loop,
+                            ArrayRef<Metadata *> Properties);
+
+static void applyParallelAccessesMetadata(CanonicalLoopInfo *CLI,
+                                          LLVMContext &Ctx, Loop *Loop,
+                                          LoopInfo &LoopInfo,
+                                          SmallVector<Metadata *> &LoopMDList) {
+  SmallSet<BasicBlock *, 8> Reachable;
+
+  // Get the basic blocks from the loop in which memref instructions
+  // can be found.
+  // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
+  // preferably without running any passes.
+  for (BasicBlock *Block : Loop->getBlocks()) {
+    if (Block == CLI->getCond() || Block == CLI->getHeader())
+      continue;
+    Reachable.insert(Block);
+  }
+
+  // Add access group metadata to memory-access instructions.
+  MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
+  for (BasicBlock *BB : Reachable)
+    addAccessGroupMetadata(BB, AccessGroup, LoopInfo);
+  // TODO:  If the loop has existing parallel access metadata, have
+  // to combine two lists.
+  LoopMDList.push_back(MDNode::get(
+      Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
+}
+
 OpenMPIRBuilder::InsertPointOrErrorTy
-OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
-                                                 CanonicalLoopInfo *CLI,
-                                                 InsertPointTy AllocaIP,
-                                                 bool NeedsBarrier,
-                                                 Value *ChunkSize) {
+OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
+    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+    bool NeedsBarrier, Value *ChunkSize, OMPScheduleType SchedType,
+    Value *DistScheduleChunkSize, OMPScheduleType DistScheduleSchedType) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
-  assert(ChunkSize && "Chunk size is required");
+  assert((ChunkSize || DistScheduleChunkSize) && "Chunk size is required");
 
   LLVMContext &Ctx = CLI->getFunction()->getContext();
   Value *IV = CLI->getIndVar();
@@ -4798,6 +4847,18 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Constant *Zero = ConstantInt::get(InternalIVTy, 0);
   Constant *One = ConstantInt::get(InternalIVTy, 1);
 
+  Function *F = CLI->getFunction();
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
+  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
+  LoopAnalysis LIA;
+  LoopInfo &&LI = LIA.run(*F, FAM);
+  Loop *L = LI.getLoopFor(CLI->getHeader());
+  SmallVector<Metadata *> LoopMDList;
+  if (ChunkSize || DistScheduleChunkSize)
+    applyParallelAccessesMetadata(CLI, Ctx, L, LI, LoopMDList);
+  addLoopMetadata(CLI, LoopMDList);
+
   // Declare useful OpenMP runtime functions.
   FunctionCallee StaticInit =
       getKmpcForStaticInitForType(InternalIVTy, M, *this);
@@ -4820,13 +4881,18 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Builder.SetCurrentDebugLocation(DL);
 
   // TODO: Detect overflow in ubsan or max-out with current tripcount.
-  Value *CastedChunkSize =
-      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
+  Value *CastedChunkSize = Builder.CreateZExtOrTrunc(
+      ChunkSize ? ChunkSize : Zero, InternalIVTy, "chunksize");
+  Value *CastedDistScheduleChunkSize = Builder.CreateZExtOrTrunc(
+      DistScheduleChunkSize ? DistScheduleChunkSize : Zero, InternalIVTy,
+      "distschedulechunksize");
   Value *CastedTripCount =
       Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
 
-  Constant *SchedulingType = ConstantInt::get(
-      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
+  Constant *SchedulingType =
+      ConstantInt::get(I32Type, static_cast<int>(SchedType));
+  Constant *DistSchedulingType =
+      ConstantInt::get(I32Type, static_cast<int>(DistScheduleSchedType));
   Builder.CreateStore(Zero, PLowerBound);
   Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
   Builder.CreateStore(OrigUpperBound, PUpperBound);
@@ -4838,12 +4904,25 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
   Value *ThreadNum = getOrCreateThreadID(SrcLoc);
-  Builder.CreateCall(StaticInit,
-                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
-                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
-                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
-                      /*pstride=*/PStride, /*incr=*/One,
-                      /*chunk=*/CastedChunkSize});
+  auto BuildInitCall =
+      [StaticInit, SrcLoc, ThreadNum, PLastIter, PLowerBound, PUpperBound,
+       PStride, One](Value *SchedulingType, Value *ChunkSize, auto &Builder) {
+        Builder.CreateCall(
+            StaticInit, {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
+                         /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
+                         /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
+                         /*pstride=*/PStride, /*incr=*/One,
+                         /*chunk=*/ChunkSize});
+      };
+  BuildInitCall(SchedulingType, CastedChunkSize, Builder);
+  if (DistScheduleSchedType != OMPScheduleType::None &&
+      SchedType != OMPScheduleType::OrderedDistributeChunked &&
+      SchedType != OMPScheduleType::OrderedDistribute) {
+    // Emit a second init function call for the dist_schedule clause on the
+    // Distribute construct. This should only be done if a Workshare Loop is
+    // nested within a Distribute construct.
+    BuildInitCall(DistSchedulingType, CastedDistScheduleChunkSize, Builder);
+  }
 
   // Load values written by the "init" function.
   Value *FirstChunkStart =
@@ -5170,31 +5249,47 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyWorkshareLoop(
     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
     bool HasSimdModifier, bool HasMonotonicModifier,
     bool HasNonmonotonicModifier, bool HasOrderedClause,
-    WorksharingLoopType LoopType, bool NoLoop) {
+    WorksharingLoopType LoopType, bool NoLoop, bool HasDistSchedule,
+    Value *DistScheduleChunkSize) {
   if (Config.isTargetDevice())
     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType, NoLoop);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
-      HasNonmonotonicModifier, HasOrderedClause);
+      HasNonmonotonicModifier, HasOrderedClause, DistScheduleChunkSize);
 
   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                    OMPScheduleType::ModifierOrdered;
+  OMPScheduleType DistScheduleSchedType = OMPScheduleType::None;
+  if (HasDistSchedule) {
+    DistScheduleSchedType = DistScheduleChunkSize
+                                ? OMPScheduleType::OrderedDistributeChunked
+                                : OMPScheduleType::OrderedDistribute;
+  }
   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
   case OMPScheduleType::BaseStatic:
-    assert(!ChunkSize && "No chunk size with static-chunked schedule");
-    if (IsOrdered)
+  case OMPScheduleType::BaseDistribute:
+    assert((!ChunkSize || !DistScheduleChunkSize) &&
+           "No chunk size with static-chunked schedule");
+    if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier);
+    if (DistScheduleChunkSize)
+      return applyStaticChunkedWorkshareLoop(
+          DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
+          DistScheduleChunkSize, DistScheduleSchedType);
+    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, LoopType, NeedsBarrier,
+                                    HasDistSchedule);
 
   case OMPScheduleType::BaseStaticChunked:
-    if (IsOrdered)
+  case OMPScheduleType::BaseDistributeChunked:
+    if (IsOrdered && !HasDistSchedule)
       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                        NeedsBarrier, ChunkSize);
     // FIXME: Monotonicity ignored?
-    return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
-                                           ChunkSize);
+    return applyStaticChunkedWorkshareLoop(
+        DL, CLI, AllocaIP, NeedsBarrier, ChunkSize, EffectiveScheduleType,
+        DistScheduleChunkSize, DistScheduleSchedType);
 
   case OMPScheduleType::BaseRuntime:
   case OMPScheduleType::BaseAuto:
@@ -5270,7 +5365,8 @@ OpenMPIRBuilder::InsertPointOrErrorTy
 OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                            InsertPointTy AllocaIP,
                                            OMPScheduleType SchedType,
-                                           bool NeedsBarrier, Value *Chunk) {
+                                           bool NeedsBarrier, Value *Chunk,
+                                           Value *DistScheduleChunk) {
   assert(CLI->isValid() && "Requires a valid canonical loop");
   assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
          "Require dedicated allocate IP");
@@ -5335,6 +5431,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Builder.CreateCall(DynamicInit,
                      {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                       UpperBound, /* step */ One, Chunk});
+  // TODO: Do we need an init call here if dist_schedule is present?
 
   // An outer loop around the existing one.
   BasicBlock *OuterCond = BasicBlock::Create(
@@ -5787,8 +5884,8 @@ static void addLoopMetadata(CanonicalLoopInfo *Loop,
 }
 
 /// Attach llvm.access.group metadata to the memref instructions of \p Block
-static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
-                            LoopInfo &LI) {
+static void addAccessGroupMetadata(BasicBlock *Block, MDNode *AccessGroup,
+                                   LoopInfo &LI) {
   for (Instruction &I : *Block) {
     if (I.mayReadOrWriteMemory()) {
       // TODO: This instruction may already have access group from
@@ -5978,16 +6075,8 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
   // dependences of 'safelen' iterations are possible.
   // If clause order(concurrent) is specified then the memory instructions
   // are marked parallel even if 'safelen' is finite.
-  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
-    // Add access group metadata to memory-access instructions.
-    MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
-    for (BasicBlock *BB : Reachable)
-      addSimdMetadata(BB, AccessGroup, LI);
-    // TODO:  If the loop has existing parallel access metadata, have
-    // to combine two lists.
-    LoopMDList.push_back(MDNode::get(
-        Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
-  }
+  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent))
+    applyParallelAccessesMetadata(CanonicalLoop, Ctx, L, LI, LoopMDList);
 
   // FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
   // versions so we can't add the loop attributes in that case.
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1b069c62a8be9..0d6b2870c625a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -617,6 +617,7 @@ parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
     break;
   case ClauseScheduleKind::Auto:
   case ClauseScheduleKind::Runtime:
+  case ClauseScheduleKind::Distribute:
     chunkSize = std::nullopt;
   }
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index f28454075f1d3..cf300973a9d23 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -61,6 +61,8 @@ convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
     return llvm::omp::OMP_SCHEDULE_Auto;
   case omp::ClauseScheduleKind::Runtime:
     return llvm::omp::OMP_SCHEDULE_Runtime;
+  case omp::ClauseScheduleKind::Distribute:
+    return llvm::omp::OMP_SCHEDULE_Distribute;
   }
   llvm_unreachable("unhandled schedule clause argument");
 }
@@ -319,10 +321,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getDevice())
       result = todo("device");
   };
-  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
-    if (op.getDistScheduleChunkSize())
-      result = todo("dist_schedule with chunk_size");
-  };
   auto checkHint = [](auto op, LogicalResult &) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
@@ -387,7 +385,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::DistributeOp op) {
         checkAllocate(op, result);
-        checkDistSchedule(op, result);
         checkOrder(op, result);
       })
       .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
@@ -2484,6 +2481,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
   }
 
+  omp::DistributeOp distributeOp = nullptr;
+  llvm::Value *distScheduleChunk = nullptr;
+  bool hasDistSchedule = false;
+  if (llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())) {
+    distributeOp = cast<omp::DistributeOp>(opInst.getParentOp());
+    hasDistSchedule = distributeOp.getDistScheduleStatic();
+    if (distributeOp.getDistScheduleChunkSize()) {
+      llvm::Value *chunkVar = moduleTranslation.lookupValue(
+          distributeOp.getDistScheduleChunkSize());
+      distScheduleChunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
+    }
+  }
+
   PrivateVarsInfo privateVarsInfo(wsloopOp);
 
   SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -2611,7 +2621,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
           convertToScheduleKind(schedule), chunk, isSimd,
           scheduleMod == omp::ScheduleModifier::monotonic,
           scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-          workshareLoopType, noLoopMode);
+          workshareLoopType, noLoopMode, hasDistSchedule, distScheduleChunk);
 
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
@@ -4997,15 +5007,18 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
     if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
       // TODO: Add support for clauses which are valid for DISTRIBUTE
       // constructs. Static schedule is the default.
-      auto schedule = omp::ClauseScheduleKind::Static;
-      bool isOrdered = false;
+      bool hasDistSchedule = distributeOp.getDistScheduleStatic();
+      auto schedule = hasDistSchedule ? omp::ClauseScheduleKind::Distribute
+                                      : omp::ClauseScheduleKind::Static;
+      // dist_schedule clauses are ordered - otherwise this should be false
+      bool isOrdered = hasDistSchedule;
       std::optional<omp::ScheduleModifier> scheduleMod;
       bool isSimd = false;
       llvm::omp::WorksharingLoopType workshareLoopType =
           llvm::omp::WorksharingLoopType::DistributeStaticLoop;
       bool loopNeedsBarrier = false;
-      llvm::Value *chunk = nullptr;
-
+      llvm::Value *chunk = moduleTranslation.lookupValue(
+          distributeOp.getDistScheduleChunkSize());
       llvm::CanonicalLoopInfo *loopInfo =
           findCurrentLoopInfo(moduleTranslation);
       llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
@@ -5014,12 +5027,11 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
               convertToScheduleKind(schedule), chunk, isSimd,
               scheduleMod == omp::ScheduleModifier::monotonic,
               scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
-              workshareLoopType);
+              workshareLoopType, false, hasDistSchedule, chunk);
 
       if (!wsloopIP)
         return wsloopIP.takeError();
     }
-
     if (failed(cleanupPrivateVars(builder, moduleTranslation,
                                   distributeOp.getLoc(), privVarsInfo.llvmVars,
                                   privVarsInfo.privatizers)))
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
index f2fbe91a41ecd..b122f425f0752 100644
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -615,3 +615,22 @@ omp.declare_mapper @my_mapper : !llvm.struct<"_QFdeclare_mapperTmy_type", (i32)>
   // CHECK: omp.declare_mapper.info map_entries(%{{.*}}, %{{.*}} : !llvm.ptr, !llvm.ptr)
   omp.declare_mapper.info map_entries(%3, %2 : !llvm.ptr, !llvm.ptr)
 }
+
+// CHECK-LABEL: llvm.func @omp_dist_schedule(%arg0: i32) {
+func.func @omp_dist_schedule(%arg0: i32) {
+  %c1_i32 = arith.constant 1 : i32
+  // CHECK: %1 = llvm.mlir.constant(1024 : i32) : i32
+  %c1024_i32 = arith.constant 1024 : i32
+  %c16_i32 = arith.constant 16 : i32
+  %c8_i32 = arith.constant 8 : i32
+  omp.teams num_teams( to %c8_i32 : i32) thread_limit(%c16_i32 : i32) {
+    // CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+    omp.distribute dist_schedule_static dist_schedule_chunk_size(%c1024_i32 : i32) {
+      omp.loop_nest (%arg1) : i32 = (%c1_i32) to (%arg0) inclusive step (%c1_i32) {
+        omp.terminator
+      }
+    }
+    omp.terminator
+  }
+  return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
new file mode 100644
index 0000000000000..e3142590de639
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule.mlir
@@ -0,0 +1,34 @@
+// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @distribute_dist_schedule_chunk_size(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+  // CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+  // We want to make sure that the next call is not another init builder.
+  // CHECK: %omp_firstchunk.lb = load i32, ptr %p.lowerbound, align 4
+  %1 = llvm.mlir.constant(1024: i32) : i32
+  omp.distribute dist_schedule_static dist_schedule_chunk_size(%1 : i32) {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  llvm.return
+}
+
+// When a chunk size is present, we need to make sure the correct parallel accesses metadata is added
+// CHECK: !2 = !{!"llvm.loop.parallel_accesses", !3}
+// CHECK-NEXT: !3 = distinct !{}
+
+// -----
+
+llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
+  // CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+  // We want to make sure that the next call is not another init builder.
+  // CHECK: %18 = load i32, ptr %p.lowerbound, align 4
+  omp.distribute dist_schedule_static {
+    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      omp.yield
+    }
+  }
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
new file mode 100644
index 0000000000000..dad32b48e5419
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-dist_schedule_with_wsloop.mlir
@@ -0,0 +1,205 @@
+// Test that dist_schedule gets correctly translated with the correct schedule type and chunk size where appropriate while using workshare loops.
+
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked(%n: i32, %teams: i32, %threads: i32, %dcs: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %scs = llvm.mlir.constant(64 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+        omp.wsloop schedule(static = %scs : i32) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 %3)
+
+llvm.func @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %dcs = llvm.mlir.constant(1024 : i64) : i64
+  %scs = llvm.mlir.constant(64 : i64) : i64
+  %n64 = llvm.zext %n : i32 to i64
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) {
+        omp.wsloop schedule(static = %scs : i64) {
+          omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64)
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024)
+
+// -----
+
+llvm.func @distribute_wsloop_dist_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %dcs = llvm.mlir.constant(1024 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i32) {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 1024)
+
+llvm.func @distribute_wsloop_dist_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %dcs = llvm.mlir.constant(1024 : i64) : i64
+  %n64 = llvm.zext %n : i32 to i64
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static dist_schedule_chunk_size(%dcs : i64) {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_dist_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0)
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 91, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 1024)
+
+// -----
+
+llvm.func @distribute_wsloop_schedule_chunked(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+  %scs = llvm.mlir.constant(64 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static = %scs : i32) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_schedule_chunked..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 64)
+// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i32 1, i32 0)
+
+llvm.func @distribute_wsloop_schedule_chunked_i64(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %scs = llvm.mlir.constant(64 : i64) : i64
+  %n64 = llvm.zext %n : i32 to i64
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static = %scs : i64) {
+          omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK: define internal void @distribute_wsloop_schedule_chunked_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 33, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 64)
+// CHECK: call void @__kmpc_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.stride, i64 1, i64 0)
+
+// -----
+
+llvm.func @distribute_wsloop_no_chunks(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i32) : i32
+  %1 = llvm.mlir.constant(1 : i32) : i32
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i32 = (%0) to (%n) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_no_chunks..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i32 1, i32 0)
+// CHECK: call void @__kmpc_dist_for_static_init_4u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i32 1, i32 0)
+
+llvm.func @distribute_wsloop_no_chunks_i64(%n: i32, %teams: i32, %threads: i32) {
+  %0 = llvm.mlir.constant(0 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i64) : i64
+  %n64 = llvm.zext %n : i32 to i64
+
+  omp.teams num_teams(to %teams : i32) thread_limit(%threads : i32) {
+    omp.parallel {
+      omp.distribute dist_schedule_static {
+        omp.wsloop schedule(static) {
+          omp.loop_nest (%i) : i64 = (%0) to (%n64) step (%1) {
+            omp.yield
+          }
+        } {omp.composite}
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  llvm.return
+}
+// CHECK: define internal void @distribute_wsloop_no_chunks_i64..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #0 {
+// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 34, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound, ptr %p.stride, i64 1, i64 0)
+// CHECK: call void @__kmpc_dist_for_static_init_8u(ptr @1, i32 %omp_global_thread_num9, i32 92, ptr %p.lastiter, ptr %p.lowerbound, ptr %p.upperbound, ptr %p.distupperbound10, ptr %p.stride, i64 1, i64 0)
\ No newline at end of file
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index af6d254cfd3c3..731a6322736d4 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -39,19 +39,6 @@ llvm.func @distribute_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr
 
 // -----
 
-llvm.func @distribute_dist_schedule(%lb : i32, %ub : i32, %step : i32, %x : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause dist_schedule with chunk_size in omp.distribute operation}}
-  // expected-error@below {{LLVM Translation failed for operation: omp.distribute}}
-  omp.distribute dist_schedule_static dist_schedule_chunk_size(%x : i32) {
-    omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-      omp.yield
-    }
-  }
-  llvm.return
-}
-
-// -----
-
 llvm.func @distribute_order(%lb : i32, %ub : i32, %step : i32) {
   // expected-error@below {{not yet implemented: Unhandled clause order in omp.distribute operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.distribute}}


