[llvm-branch-commits] [mlir] [MLIR][OpenMP] Support target SPMD (PR #127821)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Feb 24 04:57:02 PST 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/127821

From 03651527c28805cb1d786c5674d4bf5644bd1642 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 19 Feb 2025 14:41:12 +0000
Subject: [PATCH] [MLIR][OpenMP] Support target SPMD

This patch implements MLIR to LLVM IR translation of host-evaluated loop
bounds, completing initial support for `target teams distribute parallel do
[simd]` and `target teams distribute [simd]`.
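
For illustration, below is a minimal sketch of the construct shape whose
host-evaluated loop bounds are now translated (it mirrors the new test added
in this patch; SSA names are illustrative only). The loop bounds and step are
passed through the `host_eval` clause so that the host can compute the SPMD
trip count before launching the kernel:

  // Loop bounds/step enter the region via host_eval; the host computes the
  // kernel trip count from them before the __tgt_target_kernel call.
  omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
    omp.teams {
      omp.parallel {
        omp.distribute {
          omp.wsloop {
            omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
              omp.yield
            }
          } {omp.composite}
        } {omp.composite}
        omp.terminator
      } {omp.composite}
      omp.terminator
    }
    omp.terminator
  }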
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 83 ++++++++++++----
 .../Target/LLVMIR/openmp-target-spmd.mlir     | 96 +++++++++++++++++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 24 -----
 3 files changed, 159 insertions(+), 44 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/openmp-target-spmd.mlir

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index c9c074707d3b1..f4387921a3d64 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -173,15 +173,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
-  auto checkHostEval = [](auto op, LogicalResult &result) {
-    // Host evaluated clauses are supported, except for loop bounds.
-    for (BlockArgument arg :
-         cast<omp::BlockArgOpenMPOpInterface>(*op).getHostEvalBlockArgs())
-      for (Operation *user : arg.getUsers())
-        if (isa<omp::LoopNestOp>(user))
-          result = op.emitError("not yet implemented: host evaluation of loop "
-                                "bounds in omp.target operation");
-  };
   auto checkInReduction = [&todo](auto op, LogicalResult &result) {
     if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
         op.getInReductionSyms())
@@ -318,7 +309,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkBare(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
-        checkHostEval(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
         checkPrivate(op, result);
@@ -4158,9 +4148,13 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                                   Value &numTeamsLower, Value &numTeamsUpper,
-                                   Value &threadLimit) {
+static void
+extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
+                       Value &numTeamsLower, Value &numTeamsUpper,
+                       Value &threadLimit,
+                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -4185,11 +4179,26 @@ static void extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::LoopNestOp loopOp) {
-            // TODO: Extract bounds and step values. Currently, this cannot be
-            // reached because translation would have been stopped earlier as a
-            // result of `checkImplementationStatus` detecting and reporting
-            // this situation.
-            llvm_unreachable("unsupported host_eval use");
+            auto processBounds =
+                [&](OperandRange opBounds,
+                    llvm::SmallVectorImpl<Value> *outBounds) -> bool {
+              bool found = false;
+              for (auto [i, lb] : llvm::enumerate(opBounds)) {
+                if (lb == blockArg) {
+                  found = true;
+                  if (outBounds)
+                    (*outBounds)[i] = hostEvalVar;
+                }
+              }
+              return found;
+            };
+            bool found =
+                processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
+            found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
+                    found;
+            found = processBounds(loopOp.getLoopSteps(), steps) || found;
+            (void)found;
+            assert(found && "unsupported host_eval use");
           })
           .Default([](Operation *) {
             llvm_unreachable("unsupported host_eval use");
@@ -4326,6 +4335,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
+  attrs.ExecFlags = targetOp.getKernelExecFlags();
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4343,9 +4353,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                        LLVM::ModuleTranslation &moduleTranslation,
                        omp::TargetOp targetOp,
                        llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
+  omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
+      targetOp.getInnermostCapturedOmpOp());
+  unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
+
   Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
+      steps(numLoops);
   extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                         teamsThreadLimit);
+                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
   if (Value targetThreadLimit = targetOp.getThreadLimit())
@@ -4365,7 +4381,34 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  // TODO: Populate attrs.LoopTripCount if it is target SPMD.
+  if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+    llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+    attrs.LoopTripCount = nullptr;
+
+    // To calculate the trip count, we multiply together the trip counts of
+    // every collapsed canonical loop. We don't need to create the loop nests
+    // here, since we're only interested in the trip count.
+    for (auto [loopLower, loopUpper, loopStep] :
+         llvm::zip_equal(lowerBounds, upperBounds, steps)) {
+      llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
+      llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
+      llvm::Value *step = moduleTranslation.lookupValue(loopStep);
+
+      llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+      llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
+          loc, lowerBound, upperBound, step, /*IsSigned=*/true,
+          loopOp.getLoopInclusive());
+
+      if (!attrs.LoopTripCount) {
+        attrs.LoopTripCount = tripCount;
+        continue;
+      }
+
+      // TODO: Enable UndefinedBehaviorSanitizer to diagnose overflow here.
+      attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
+                                              {}, /*HasNUW=*/true);
+    }
+  }
 }
 
 static LogicalResult
diff --git a/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
new file mode 100644
index 0000000000000..7930554cbe11a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-spmd.mlir
@@ -0,0 +1,96 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST:         %omp_loop.tripcount = {{.*}}
+// HOST-NEXT:    br label %[[ENTRY:.*]]
+// HOST:       [[ENTRY]]:
+// HOST-NEXT:    %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST:       [[OFFLOAD_FAILED]]:
+// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[TARGET_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[TEAMS_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[PARALLEL_OUTLINE]]
+// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST:         call void @__kmpc_dist_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 34, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+
+//--- device.mlir
+
+module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
+  llvm.func @main(%x : i32) {
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
+      omp.teams {
+        omp.parallel {
+          omp.distribute {
+            omp.wsloop {
+              omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+                omp.yield
+              }
+            } {omp.composite}
+          } {omp.composite}
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// DEVICE:      @[[KERNEL_NAME:.*]]_exec_mode = weak protected constant i8 2
+// DEVICE:      @llvm.compiler.used = appending global [1 x ptr] [ptr @[[KERNEL_NAME]]_exec_mode], section "llvm.metadata"
+// DEVICE:      @[[KERNEL_NAME]]_kernel_environment = weak_odr protected constant %struct.KernelEnvironmentTy {
+// DEVICE-SAME: %struct.ConfigurationEnvironmentTy { i8 0, i8 1, i8 [[EXEC_MODE:2]], {{.*}}},
+// DEVICE-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
+// DEVICE:      define weak_odr protected amdgpu_kernel void @[[KERNEL_NAME]]({{.*}})
+// DEVICE:        %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL_NAME]]_kernel_environment, {{.*}})
+// DEVICE:        call void @[[TARGET_OUTLINE:.*]]({{.*}})
+// DEVICE:        call void @__kmpc_target_deinit()
+
+// DEVICE:      define internal void @[[TARGET_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_parallel_51(ptr {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], ptr {{.*}}, ptr {{.*}}, i64 {{.*}})
+
+// DEVICE:      define internal void @[[PARALLEL_OUTLINE]]({{.*}})
+// DEVICE:        call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// DEVICE:      define internal void @[[DISTRIBUTE_OUTLINE]]({{.*}})
+// DEVICE:        call void @__kmpc_distribute_for_static_loop{{.*}}({{.*}})
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index d1c745af9bff5..f907bb3f94a2a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -319,30 +319,6 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
 
 // -----
 
-llvm.func @target_host_eval(%x : i32) {
-  // expected-error at below {{not yet implemented: host evaluation of loop bounds in omp.target operation}}
-  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
-  omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) {
-    omp.teams {
-      omp.parallel {
-        omp.distribute {
-          omp.wsloop {
-            omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
-              omp.yield
-            }
-          } {omp.composite}
-        } {omp.composite}
-        omp.terminator
-      } {omp.composite}
-      omp.terminator
-    }
-    omp.terminator
-  }
-  llvm.return
-}
-
-// -----
-
 omp.declare_reduction @add_f32 : f32
 init {
 ^bb0(%arg: f32):


