[llvm-branch-commits] [mlir] [MLIR][OpenMP] Add host_eval clause to omp.target (PR #116049)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Dec 4 05:16:21 PST 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116049

From 4d20a42bc7b9f93efab561dd931a31867a928829 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 8 Nov 2024 12:00:45 +0000
Subject: [PATCH 1/2] [MLIR][OpenMP] Add host_eval clause to omp.target

This patch adds the `host_eval` clause to the `omp.target` operation.
Additionally, it updates the operation's verifier to make sure that all uses of
block arguments defined by this clause fall within the few cases where they are
allowed.

MLIR to LLVM IR translation of this clause currently fails with a
not-yet-implemented error.
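
For illustration, a minimal sketch of the intended usage (assuming %x is an
i32 value defined on the host): %x is forwarded through host_eval to the
block argument %nt, whose only use is the num_teams clause of the nested
omp.teams operation, one of the uses accepted by the verifier.

  omp.target host_eval(%x -> %nt : i32) {
    // Using the forwarded value as 'num_teams' is a legal use.
    omp.teams num_teams(to %nt : i32) {
      omp.terminator
    }
    omp.terminator
  }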
---
 mlir/docs/Dialects/OpenMPDialect/_index.md    |  58 +++++-
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  34 +++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 167 +++++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |   5 +
 mlir/test/Dialect/OpenMP/invalid.mlir         |  70 +++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             |  38 +++-
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  14 ++
 7 files changed, 371 insertions(+), 15 deletions(-)

diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index 03d5b95217cce07..b651b3c06485c61 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -298,7 +298,8 @@ introduction of private copies of the same underlying variable defined outside
 the MLIR operation the clause is attached to. Currently, clauses with this
 property can be classified into three main categories:
   - Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
-  specification), `map`, `use_device_addr` and `use_device_ptr`.
+  specification: [see more](#host-evaluated-clauses-in-target-regions)), `map`,
+  `use_device_addr` and `use_device_ptr`.
   - Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
   - Privatization clauses: `private`.
 
@@ -523,3 +524,58 @@ omp.parallel ... {
   omp.terminator
 } {omp.composite}
 ```
+
+## Host-Evaluated Clauses in Target Regions
+
+The `omp.target` operation, which represents the OpenMP `target` construct, is
+marked with the `IsolatedFromAbove` trait. This means that, inside of its
+region, no MLIR values defined outside of the op itself can be used. This is
+consistent with the OpenMP specification of the `target` construct, which
+mandates that all host device values used inside of the `target` region must
+either be privatized (data-sharing) or mapped (data-mapping).
+
+Normally, clauses applied to a construct are evaluated before entering that
+construct. Further, in some cases, the OpenMP specification stipulates that
+clauses be evaluated _on the host device_ on entry to a parent `target`
+construct. In particular, the `num_teams` and `thread_limit` clauses of the
+`teams` construct must be evaluated on the host device if it's nested inside or
+combined with a `target` construct.
+
+Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
+the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
+`target teams distribute parallel {do,for}` in OpenMP), which requires
+specifying in advance what the total trip count of the loop is. Consequently, it
+is also beneficial to evaluate the trip count on the host device prior to the
+kernel launch.
+
+These host-evaluated values in MLIR would need to be placed outside of the
+`omp.target` region and also attached to the corresponding nested operations,
+which is not possible because of the `IsolatedFromAbove` trait. The solution
+implemented to address this problem has been to introduce the `host_eval`
+argument to the `omp.target` operation. It works similarly to a `map` clause,
+but its only intended use is to forward host-evaluated values to their
+corresponding operation inside of the region. Any use other than the ones
+previously described results in a verifier error.
+
+```mlir
+// Initialize %0, %1, %2, %3...
+omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
+  omp.teams num_teams(to %nt : i32) {
+    omp.parallel {
+      omp.distribute {
+        omp.wsloop {
+          omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+            // ...
+            omp.yield
+          }
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  omp.terminator
+}
+```
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index f6c7f19fffddf9e..4f9c772a5ee28d7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1213,9 +1213,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
-    OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
-    OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
-    OpenMP_NowaitClause, OpenMP_PrivateClause, OpenMP_ThreadLimitClause
+    OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause,
+    OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
+    OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
+    OpenMP_PrivateClause, OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "target construct";
   let description = [{
@@ -1257,17 +1258,34 @@ def TargetOp : OpenMP_Op<"target", traits = [
 
       return getMapVars()[mapInfoOpIdx];
     }
+
+    /// Returns the innermost OpenMP dialect operation captured by this target
+    /// construct. For an operation to be detected as captured, it must be
+    /// inside a (possibly multi-level) nest of OpenMP dialect operations'
+    /// regions where none of these levels contain other operations considered
+    /// not-allowed for these purposes (i.e. only terminator operations are
+    /// allowed from the OpenMP dialect, and other dialects' operations are
+    /// allowed as long as they don't have a memory write effect).
+    ///
+    /// If there are omp.loop_nest operations in the sequence of nested
+    /// operations, the top level one will be the one captured.
+    Operation *getInnermostCapturedOmpOp();
+
+    /// Checks whether this target region represents the MLIR equivalent to a
+    /// 'target teams distribute parallel {do, for} [simd]' OpenMP construct.
+    bool isTargetSPMDLoop();
   }] # clausesExtraClassDeclaration;
 
   let assemblyFormat = clausesAssemblyFormat # [{
-    custom<InReductionMapPrivateRegion>(
-        $region, $in_reduction_vars, type($in_reduction_vars),
-        $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms, $private_maps)
-        attr-dict
+    custom<HostEvalInReductionMapPrivateRegion>(
+        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
+        type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
+        $map_vars, type($map_vars), $private_vars, type($private_vars),
+        $private_syms, $private_maps) attr-dict
   }];
 
   let hasVerifier = 1;
+  let hasRegionVerifier = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index f626d18e9f4d691..139e57a0d4ce9b7 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -691,8 +691,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
   return parser.parseRegion(region, entryBlockArgs);
 }
 
-static ParseResult parseInReductionMapPrivateRegion(
+static ParseResult parseHostEvalInReductionMapPrivateRegion(
     OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
+    SmallVectorImpl<Type> &hostEvalTypes,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -702,6 +704,7 @@ static ParseResult parseInReductionMapPrivateRegion(
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
     DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
+  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
@@ -931,13 +934,15 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
-static void printInReductionMapPrivateRegion(
-    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+static void printHostEvalInReductionMapPrivateRegion(
+    OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
+    TypeRange hostEvalTypes, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
     ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
     DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
+  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
@@ -1719,7 +1724,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   // inReductionByref, inReductionSyms.
   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
                   makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
+                  clauses.device, clauses.hasDeviceAddrVars,
+                  clauses.hostEvalVars, clauses.ifExpr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1742,6 +1748,159 @@ LogicalResult TargetOp::verify() {
   return verifyPrivateVarsMapping(*this);
 }
 
+LogicalResult TargetOp::verifyRegions() {
+  auto teamsOps = getOps<TeamsOp>();
+  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
+    return emitError("target containing multiple 'omp.teams' nested ops");
+
+  // Check that host_eval values are only used in legal ways.
+  bool isTargetSPMD = isTargetSPMDLoop();
+  for (Value hostEvalArg :
+       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
+    for (Operation *user : hostEvalArg.getUsers()) {
+      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
+        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
+                                teamsOp.getNumTeamsUpper(),
+                                teamsOp.getThreadLimit()},
+                               hostEvalArg))
+          continue;
+
+        return emitOpError() << "host_eval argument only legal as 'num_teams' "
+                                "and 'thread_limit' in 'omp.teams'";
+      }
+      if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
+        if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
+          continue;
+
+        return emitOpError()
+               << "host_eval argument only legal as 'num_threads' in "
+                  "'omp.parallel' when representing target SPMD";
+      }
+      if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
+        if (isTargetSPMD &&
+            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
+             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
+             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
+          continue;
+
+        return emitOpError()
+               << "host_eval argument only legal as loop bounds and steps in "
+                  "'omp.loop_nest' when representing target SPMD";
+      }
+
+      return emitOpError() << "host_eval argument illegal use in '"
+                           << user->getName() << "' operation";
+    }
+  }
+  return success();
+}
+
+/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
+/// effects, but don't include a memory write effect.
+static bool siblingAllowedInCapture(Operation *op) {
+  if (!op)
+    return false;
+
+  bool isOmpDialect =
+      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
+      op->getDialect();
+
+  if (isOmpDialect)
+    return op->hasTrait<OpTrait::IsTerminator>();
+
+  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
+    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
+    memOp.getEffects(effects);
+    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+      return isa<MemoryEffects::Write>(effect.getEffect()) &&
+             isa<SideEffects::AutomaticAllocationScopeResource>(
+                 effect.getResource());
+    });
+  }
+  return true;
+}
+
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+  Dialect *ompDialect = (*this)->getDialect();
+  Operation *capturedOp = nullptr;
+
+  // Process in pre-order to check operations from outermost to innermost,
+  // ensuring we only enter the region of an operation if it meets the criteria
+  // for being captured. We stop the exploration of nested operations as soon as
+  // we process a region holding no operations to be captured.
+  walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op == *this)
+      return WalkResult::advance();
+
+    // Ignore operations of other dialects or omp operations with no regions,
+    // because these will only be checked if they are siblings of an omp
+    // operation that can potentially be captured.
+    bool isOmpDialect = op->getDialect() == ompDialect;
+    bool hasRegions = op->getNumRegions() > 0;
+    if (!isOmpDialect || !hasRegions)
+      return WalkResult::skip();
+
+    // Don't capture this op if it has a not-allowed sibling, and stop recursing
+    // into nested operations.
+    for (Operation &sibling : op->getParentRegion()->getOps())
+      if (&sibling != op && !siblingAllowedInCapture(&sibling))
+        return WalkResult::interrupt();
+
+    // Don't continue capturing nested operations if we reach an omp.loop_nest.
+    // Otherwise, process the contents of this operation.
+    capturedOp = op;
+    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
+                                     : WalkResult::advance();
+  });
+
+  return capturedOp;
+}
+
+bool TargetOp::isTargetSPMDLoop() {
+  // The expected MLIR representation for a target SPMD loop is:
+  // omp.target {
+  //   omp.teams {
+  //     omp.parallel {
+  //       omp.distribute {
+  //         omp.wsloop {
+  //           omp.loop_nest ... { ... }
+  //         } {omp.composite}
+  //       } {omp.composite}
+  //       omp.terminator
+  //     } {omp.composite}
+  //     omp.terminator
+  //   }
+  //   omp.terminator
+  // }
+
+  Operation *capturedOp = getInnermostCapturedOmpOp();
+  if (!isa_and_present<LoopNestOp>(capturedOp))
+    return false;
+
+  Operation *workshareOp = capturedOp->getParentOp();
+
+  // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
+  if (isa_and_present<SimdOp>(workshareOp))
+    workshareOp = workshareOp->getParentOp();
+
+  if (!isa_and_present<WsloopOp>(workshareOp))
+    return false;
+
+  Operation *distributeOp = workshareOp->getParentOp();
+  if (!isa_and_present<DistributeOp>(distributeOp))
+    return false;
+
+  Operation *parallelOp = distributeOp->getParentOp();
+  if (!isa_and_present<ParallelOp>(parallelOp))
+    return false;
+
+  Operation *teamsOp = parallelOp->getParentOp();
+  if (!isa_and_present<TeamsOp>(teamsOp))
+    return false;
+
+  return teamsOp->getParentOp() == (*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 35b0633a04a3522..16e31e9c21af841 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -174,6 +174,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
+  auto checkHostEval = [&todo](auto op, LogicalResult &result) {
+    if (!op.getHostEvalVars().empty())
+      result = todo("host_eval");
+  };
   auto checkIf = [&todo](auto op, LogicalResult &result) {
     if (op.getIfExpr())
       result = todo("if");
@@ -286,6 +290,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkAllocate(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
+        checkHostEval(op, result);
         checkIf(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2a19e4837f55042..161f5e6e1915a31 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2138,11 +2138,79 @@ func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
 
 // -----
 
+func.func @omp_target_multiple_teams() {
+  // expected-error @below {{target containing multiple 'omp.teams' nested ops}}
+  omp.target {
+    omp.teams {
+      omp.terminator
+    }
+    omp.teams {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval(%x : !llvm.ptr) {
+  // expected-error @below {{op host_eval argument illegal use in 'llvm.load' operation}}
+  omp.target host_eval(%x -> %arg0 : !llvm.ptr) {
+    %0 = llvm.load %arg0 : !llvm.ptr -> f32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval_teams(%x : i1) {
+  // expected-error @below {{op host_eval argument only legal as 'num_teams' and 'thread_limit' in 'omp.teams'}}
+  omp.target host_eval(%x -> %arg0 : i1) {
+    omp.teams if(%arg0) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval_parallel(%x : i32) {
+  // expected-error @below {{op host_eval argument only legal as 'num_threads' in 'omp.parallel' when representing target SPMD}}
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.parallel num_threads(%arg0 : i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval_loop(%x : i32) {
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD}}
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.wsloop {
+      omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 94c63dd8e9aa0e7..cde8e3b46237e83 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -770,7 +770,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%device, %if_cond, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2774,6 +2774,42 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?
   return
 }
 
+func.func @omp_target_host_eval(%x : i32) {
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32)
+  // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32)
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams num_teams(to %arg0 : i32) thread_limit(%arg0 : i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.teams
+  // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+  // CHECK: omp.distribute {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[HOST_ARG]]) to (%[[HOST_ARG]]) step (%[[HOST_ARG]]) {
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams {
+      omp.parallel num_threads(%arg0 : i32) {
+        omp.distribute {
+          omp.wsloop {
+            omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+              omp.yield
+            }
+          } {omp.composite}
+        } {omp.composite}
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: omp_loop
 func.func @omp_loop(%lb : index, %ub : index, %step : index) {
   // CHECK: omp.loop {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index de797ea2aa3649b..b8b851cdf97f2b8 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -278,6 +278,20 @@ llvm.func @target_has_device_addr(%x : !llvm.ptr) {
 
 // -----
 
+llvm.func @target_host_eval(%x : i32) {
+  // expected-error at below {{not yet implemented: Unhandled clause host_eval in omp.target operation}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams num_teams(to %arg0 : i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @target_if(%x : i1) {
   // expected-error at below {{not yet implemented: Unhandled clause if in omp.target operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.target}}

From 27ffa9f980cc7c6922f188417b708c6a9d8121f2 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 4 Dec 2024 13:11:43 +0000
Subject: [PATCH 2/2] More robust kernel type detection
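
The kernel type is now inferred as Generic, SPMD or Generic-SPMD from the
contents of the target region, instead of a single boolean check for the SPMD
loop pattern. A minimal sketch of the Generic-SPMD shape this detects,
mirroring the test added to ops.mlir (%x stands for any host-defined i32,
forwarded as loop bounds and step):

  omp.target host_eval(%x -> %v : i32) {
    omp.teams {
      // 'distribute' directly inside 'teams', with no 'parallel' + 'wsloop'
      // pair, is classified as Generic-SPMD.
      omp.distribute {
        omp.loop_nest (%iv) : i32 = (%v) to (%v) step (%v) {
          omp.yield
        }
      }
      omp.terminator
    }
    omp.terminator
  }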

---
 .../mlir/Dialect/OpenMP/OpenMPDialect.h       |   1 +
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   6 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 109 +++++++++++-------
 mlir/test/Dialect/OpenMP/invalid.mlir         |  28 ++++-
 mlir/test/Dialect/OpenMP/ops.mlir             |  18 ++-
 5 files changed, 117 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
index bee21432196e423..248ac2eb72c61a9 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
 
 #define GET_TYPEDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 4f9c772a5ee28d7..1050f89b6021116 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1271,9 +1271,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
     /// operations, the top level one will be the one captured.
     Operation *getInnermostCapturedOmpOp();
 
-    /// Checks whether this target region represents the MLIR equivalent to a
-    /// 'target teams distribute parallel {do, for} [simd]' OpenMP construct.
-    bool isTargetSPMDLoop();
+    /// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
+    /// contents of the target region.
+    llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
   }] # clausesExtraClassDeclaration;
 
   let assemblyFormat = clausesAssemblyFormat # [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 139e57a0d4ce9b7..88f4c9f15958237 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
 #include <cstddef>
 #include <iterator>
 #include <optional>
@@ -1754,7 +1755,7 @@ LogicalResult TargetOp::verifyRegions() {
     return emitError("target containing multiple 'omp.teams' nested ops");
 
   // Check that host_eval values are only used in legal ways.
-  bool isTargetSPMD = isTargetSPMDLoop();
+  llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
   for (Value hostEvalArg :
        cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
     for (Operation *user : hostEvalArg.getUsers()) {
@@ -1769,7 +1770,8 @@ LogicalResult TargetOp::verifyRegions() {
                                 "and 'thread_limit' in 'omp.teams'";
       }
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
-        if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
+        if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
+            hostEvalArg == parallelOp.getNumThreads())
           continue;
 
         return emitOpError()
@@ -1777,15 +1779,15 @@ LogicalResult TargetOp::verifyRegions() {
                   "'omp.parallel' when representing target SPMD";
       }
       if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
-        if (isTargetSPMD &&
+        if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
             (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
           continue;
 
-        return emitOpError()
-               << "host_eval argument only legal as loop bounds and steps in "
-                  "'omp.loop_nest' when representing target SPMD";
+        return emitOpError() << "host_eval argument only legal as loop bounds "
+                                "and steps in 'omp.loop_nest' when "
+                                "representing target SPMD or Generic-SPMD";
       }
 
       return emitOpError() << "host_eval argument illegal use in '"
@@ -1823,6 +1825,7 @@ static bool siblingAllowedInCapture(Operation *op) {
 Operation *TargetOp::getInnermostCapturedOmpOp() {
   Dialect *ompDialect = (*this)->getDialect();
   Operation *capturedOp = nullptr;
+  DominanceInfo domInfo;
 
   // Process in pre-order to check operations from outermost to innermost,
   // ensuring we only enter the region of an operation if it meets the criteria
@@ -1840,6 +1843,22 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
     if (!isOmpDialect || !hasRegions)
       return WalkResult::skip();
 
+    // This operation cannot be captured if it can be executed more than once
+    // (i.e. its block's successors can reach it) or if it's not guaranteed to
+    // be executed before all exits of the region (i.e. it doesn't dominate all
+    // blocks with no successors reachable from the entry block).
+    Region *parentRegion = op->getParentRegion();
+    Block *parentBlock = op->getBlock();
+
+    for (Block *successor : parentBlock->getSuccessors())
+      if (successor->isReachable(parentBlock))
+        return WalkResult::interrupt();
+
+    for (Block &block : *parentRegion)
+      if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
+          !domInfo.dominates(parentBlock, &block))
+        return WalkResult::interrupt();
+
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
     for (Operation &sibling : op->getParentRegion()->getOps())
@@ -1856,49 +1875,61 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   return capturedOp;
 }
 
-bool TargetOp::isTargetSPMDLoop() {
-  // The expected MLIR representation for a target SPMD loop is:
-  // omp.target {
-  //   omp.teams {
-  //     omp.parallel {
-  //       omp.distribute {
-  //         omp.wsloop {
-  //           omp.loop_nest ... { ... }
-  //         } {omp.composite}
-  //       } {omp.composite}
-  //       omp.terminator
-  //     } {omp.composite}
-  //     omp.terminator
-  //   }
-  //   omp.terminator
-  // }
+llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
+  using namespace llvm::omp;
 
+  // Make sure this region is capturing a loop. Otherwise, it's a generic
+  // kernel.
   Operation *capturedOp = getInnermostCapturedOmpOp();
   if (!isa_and_present<LoopNestOp>(capturedOp))
-    return false;
+    return OMP_TGT_EXEC_MODE_GENERIC;
 
-  Operation *workshareOp = capturedOp->getParentOp();
+  SmallVector<LoopWrapperInterface> wrappers;
+  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
+  assert(!wrappers.empty());
 
-  // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
-  if (isa_and_present<SimdOp>(workshareOp))
-    workshareOp = workshareOp->getParentOp();
+  // Ignore optional SIMD leaf construct.
+  auto *innermostWrapper = wrappers.begin();
+  if (isa<SimdOp>(innermostWrapper))
+    innermostWrapper = std::next(innermostWrapper);
 
-  if (!isa_and_present<WsloopOp>(workshareOp))
-    return false;
+  long numWrappers = std::distance(innermostWrapper, wrappers.end());
 
-  Operation *distributeOp = workshareOp->getParentOp();
-  if (!isa_and_present<DistributeOp>(distributeOp))
-    return false;
+  // Detect Generic-SPMD: target-teams-distribute[-simd].
+  if (numWrappers == 1) {
+    if (!isa<DistributeOp>(innermostWrapper))
+      return OMP_TGT_EXEC_MODE_GENERIC;
 
-  Operation *parallelOp = distributeOp->getParentOp();
-  if (!isa_and_present<ParallelOp>(parallelOp))
-    return false;
+    Operation *teamsOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<TeamsOp>(teamsOp))
+      return OMP_TGT_EXEC_MODE_GENERIC;
 
-  Operation *teamsOp = parallelOp->getParentOp();
-  if (!isa_and_present<TeamsOp>(teamsOp))
-    return false;
+    if (teamsOp->getParentOp() == *this)
+      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
+  }
+
+  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
+  if (numWrappers == 2) {
+    if (!isa<WsloopOp>(innermostWrapper))
+      return OMP_TGT_EXEC_MODE_GENERIC;
+
+    innermostWrapper = std::next(innermostWrapper);
+    if (!isa<DistributeOp>(innermostWrapper))
+      return OMP_TGT_EXEC_MODE_GENERIC;
+
+    Operation *parallelOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return OMP_TGT_EXEC_MODE_GENERIC;
+
+    Operation *teamsOp = parallelOp->getParentOp();
+    if (!isa_and_present<TeamsOp>(teamsOp))
+      return OMP_TGT_EXEC_MODE_GENERIC;
+
+    if (teamsOp->getParentOp() == *this)
+      return OMP_TGT_EXEC_MODE_SPMD;
+  }
 
-  return teamsOp->getParentOp() == (*this);
+  return OMP_TGT_EXEC_MODE_GENERIC;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 161f5e6e1915a31..77bd59976882924 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2191,8 +2191,8 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
 
 // -----
 
-func.func @omp_target_host_eval_loop(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD}}
+func.func @omp_target_host_eval_loop1(%x : i32) {
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.wsloop {
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2206,6 +2206,30 @@ func.func @omp_target_host_eval_loop(%x : i32) {
 
 // -----
 
+func.func @omp_target_host_eval_loop2(%x : i32) {
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams {
+    ^bb0:
+      %0 = arith.constant 0 : i1
+      llvm.cond_br %0, ^bb1, ^bb2
+    ^bb1:
+      omp.distribute {
+        omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+          omp.yield
+        }
+      }
+      llvm.br ^bb2
+    ^bb2:
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index cde8e3b46237e83..296535c867d59d8 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2786,7 +2786,7 @@ func.func @omp_target_host_eval(%x : i32) {
   }
 
   // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
-  // CHECK: omp.teams
+  // CHECK: omp.teams {
   // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
   // CHECK: omp.distribute {
   // CHECK: omp.wsloop {
@@ -2807,6 +2807,22 @@ func.func @omp_target_host_eval(%x : i32) {
     }
     omp.terminator
   }
+
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.teams {
+  // CHECK: omp.distribute {
+  // CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[HOST_ARG]]) to (%[[HOST_ARG]]) step (%[[HOST_ARG]]) {
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams {
+      omp.distribute {
+        omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
   return
 }
 


