[Mlir-commits] [mlir] [MLIR][OpenMP] Skip host omp ops when compiling for the target device (PR #85239)

Sergio Afonso llvmlistbot at llvm.org
Mon Mar 18 04:37:06 PDT 2024


================
@@ -2922,6 +2922,189 @@ convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
   return success();
 }
 
+/// Returns true when \p op is nested in code that is lowered for the target
+/// device: either it has an enclosing `omp::TargetOp`, or its enclosing
+/// `LLVM::LLVMFuncOp` is a declare-target function whose device type is not
+/// `host`. Returns false otherwise, including when \p op has no enclosing
+/// `LLVM::LLVMFuncOp`.
+static bool isInternalTargetDeviceOp(Operation *op) {
+  // Assumes no reverse offloading
+  if (op->getParentOfType<omp::TargetOp>())
+    return true;
+
+  // `op` is not necessarily nested in an `llvm.func`; in that case the lookup
+  // yields a null op, and `dyn_cast` must not be invoked on a null operation.
+  auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>();
+  if (!parentFn)
+    return false;
+
+  if (auto declareTargetIface =
+          llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
+              parentFn.getOperation()))
+    if (declareTargetIface.isDeclareTarget() &&
+        declareTargetIface.getDeclareTargetDeviceType() !=
+            mlir::omp::DeclareTargetDeviceType::host)
+      return true;
+
+  return false;
+}
+
+/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
+/// OpenMP runtime calls). Op types without a case below produce an error.
+static LogicalResult
+convertCommonOperation(Operation *op, llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation) {
+
+  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+      .Case([&](omp::BarrierOp) {
+        ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
+        return success();
+      })
+      .Case([&](omp::TaskwaitOp) {
+        ompBuilder->createTaskwait(builder.saveIP());
+        return success();
+      })
+      .Case([&](omp::TaskyieldOp) {
+        ompBuilder->createTaskyield(builder.saveIP());
+        return success();
+      })
+      .Case([&](omp::FlushOp) {
+        // The OpenMP runtime function (__kmpc_flush) does not accept an
+        // argument list.
+        // The OpenMP standard states the following:
+        //  "An implementation may implement a flush with a list by ignoring
+        //   the list, and treating it the same as a flush without a list."
+        //
+        // The argument list is therefore discarded, so a flush with a list is
+        // treated the same as a flush without a list.
+        ompBuilder->createFlush(builder.saveIP());
+        return success();
+      })
+      .Case([&](omp::ParallelOp op) {
+        return convertOmpParallel(op, builder, moduleTranslation);
+      })
+      .Case([&](omp::ReductionOp reductionOp) {
+        return convertOmpReductionOp(reductionOp, builder, moduleTranslation);
+      })
+      .Case([&](omp::MasterOp) {
+        return convertOmpMaster(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::CriticalOp) {
+        return convertOmpCritical(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::OrderedRegionOp) {
+        return convertOmpOrderedRegion(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::OrderedOp) {
+        return convertOmpOrdered(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::WsLoopOp) {
+        return convertOmpWsLoop(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::SimdLoopOp) {
+        return convertOmpSimdLoop(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::AtomicReadOp) {
+        return convertOmpAtomicRead(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::AtomicWriteOp) {
+        return convertOmpAtomicWrite(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::AtomicUpdateOp op) {
+        return convertOmpAtomicUpdate(op, builder, moduleTranslation);
+      })
+      .Case([&](omp::AtomicCaptureOp op) {
+        return convertOmpAtomicCapture(op, builder, moduleTranslation);
+      })
+      .Case([&](omp::SectionsOp) {
+        return convertOmpSections(*op, builder, moduleTranslation);
+      })
+      .Case([&](omp::SingleOp op) {
+        return convertOmpSingle(op, builder, moduleTranslation);
+      })
+      .Case([&](omp::TeamsOp op) {
+        return convertOmpTeams(op, builder, moduleTranslation);
+      })
+      .Case([&](omp::TaskOp op) {
+        return convertOmpTaskOp(op, builder, moduleTranslation);
+      })
+      .Case([&](omp::TaskGroupOp op) {
+        return convertOmpTaskgroupOp(op, builder, moduleTranslation);
+      })
+      .Case<omp::YieldOp, omp::TerminatorOp, omp::ReductionDeclareOp,
+            omp::CriticalDeclareOp>([](auto op) {
+        // `yield` and `terminator` can simply be omitted: the block structure
+        // is created while handling the region of their parent operation.
+        // `reduction.declare` is used by reductions and is not converted
+        // directly, so it is skipped.
+        // `critical.declare` only declares names of critical sections that
+        // are used by `critical` ops, and hence can be ignored for lowering.
+        // The OpenMP IRBuilder will create unique names for critical
+        // sections.
+        return success();
+      })
+      .Case([&](omp::ThreadprivateOp) {
+        return convertOmpThreadprivate(*op, builder, moduleTranslation);
+      })
+      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp, omp::UpdateDataOp>(
+          [&](auto op) {
+            return convertOmpTargetData(op, builder, moduleTranslation);
+          })
+      .Case([&](omp::TargetOp) {
+        return convertOmpTarget(*op, builder, moduleTranslation);
+      })
+      .Case<omp::MapInfoOp, omp::DataBoundsOp, omp::PrivateClauseOp>(
+          [&](auto op) {
+            // No-op: these ops should be handled by their owning operations
+            // (e.g. TargetOp, EnterDataOp, ExitDataOp, DataOp) and then
+            // discarded.
+            return success();
+          })
+      .Default([&](Operation *inst) {
+        return inst->emitError("unsupported OpenMP operation: ")
+               << inst->getName();
+      });
+}
+
+/// Lowers \p op by forwarding it, unchanged, to convertCommonOperation and
+/// returning that function's result.
+/// NOTE(review): judging by the name, this is the entry point for ops nested
+/// in target-device code -- confirm against the caller.
+static LogicalResult
+convertInternalTargetOp(Operation *op, llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  LogicalResult result =
+      convertCommonOperation(op, builder, moduleTranslation);
+  return result;
+}
+
+static LogicalResult
+convertTopLevelTargetOp(Operation *op, llvm::IRBuilderBase &builder,
+                        LLVM::ModuleTranslation &moduleTranslation) {
+  return llvm::TypeSwitch<Operation *, LogicalResult>(op)
+      .Case<omp::DataOp, omp::EnterDataOp, omp::ExitDataOp, omp::UpdateDataOp>(
+          [&](auto op) {
+            return convertOmpTargetData(op, builder, moduleTranslation);
+          })
+      .Case([&](omp::TargetOp) {
+        return convertOmpTarget(*op, builder, moduleTranslation);
+      })
+      // Skip omp ops that are not legal top level ops for the target device
+      .Case<omp::BarrierOp, omp::TaskwaitOp, omp::TaskyieldOp, omp::FlushOp,
+            omp::ReductionOp, omp::OrderedOp, omp::SimdLoopOp,
+            omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
+            omp::AtomicCaptureOp, omp::YieldOp, omp::TerminatorOp,
+            omp::ReductionDeclareOp, omp::ThreadprivateOp, omp::MapInfoOp,
+            omp::DataBoundsOp, omp::CriticalDeclareOp>(
+          [&](auto op) { return success(); })
+      // Recursively traverse inner regions
+      .Case<omp::ParallelOp, omp::MasterOp, omp::CriticalOp,
----------------
skatrak wrote:

+1 on processing the op's contents based on whether it has regions attached, rather than listing the op types explicitly. Rather than the current type switch, we could probably do something like this instead:
```c++
if (isa<TargetOp>(op))
  return convertOmpTarget(*op, builder, moduleTranslation);

bool interrupted = op->walk<WalkOrder::PreOrder>([&](omp::TargetOp targetOp) {
  if (failed(convertOmpTarget(*targetOp, builder, moduleTranslation)))
    return WalkResult::interrupt();
  return WalkResult::skip();
}).wasInterrupted();
return failure(interrupted);
```
That should do the trick of recursively looking for target regions, I think.

https://github.com/llvm/llvm-project/pull/85239


More information about the Mlir-commits mailing list