[Mlir-commits] [flang] [mlir] [Flang][MLIR] Add `!$omp unroll` and `omp.unroll_heuristic` (PR #144785)

Kareem Ergawy llvmlistbot at llvm.org
Mon Jun 30 05:08:50 PDT 2025


================
@@ -3012,6 +3013,366 @@ void LoopNestOp::gatherWrappers(
   }
 }
 
+//===----------------------------------------------------------------------===//
+// OpenMP canonical loop handling
+//===----------------------------------------------------------------------===//
+
+/// Decodes how a canonical loop info (CLI) value is used.
+///
+/// Returns a tuple of:
+///  * the NewCliOp that defines \p cli,
+///  * the use through which a loop transformation generates the loop this
+///    CLI refers to (nullptr if there is none), and
+///  * the use through which a loop transformation is applied to that loop
+///    (nullptr if there is none).
+///
+/// If \p cli is null, all three elements are null.
+std::tuple<NewCliOp, OpOperand *, OpOperand *>
+mlir::omp::decodeCli(Value cli) {
+  // Defining a CLI for a generated loop is optional; if there is none then
+  // there is no followup-transformation.
+  if (!cli)
+    return {{}, nullptr, nullptr};
+
+  MLIRContext *ctx = cli.getContext();
+  assert(cli.getType() == CanonicalLoopInfoType::get(ctx) &&
+         "Unexpected type of cli");
+
+  // `cast` asserts that the CLI is produced by a NewCliOp.
+  NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
+  OpOperand *gen = nullptr;
+  OpOperand *cons = nullptr;
+
+  // True if operand number `opnum` falls into the ODS operand segment given
+  // as a (start index, length) pair.
+  auto inSegment = [](const auto &segment, unsigned opnum) {
+    return segment.first <= opnum && opnum < segment.first + segment.second;
+  };
+
+  // Each user must be a loop transformation (`cast` asserts this); classify
+  // the use as generating the loop or being applied to it.
+  for (OpOperand &use : cli.getUses()) {
+    auto op = cast<LoopTransformationInterface>(use.getOwner());
+    auto applyees = op.getApplyeesODSOperandIndexAndLength();
+    auto generatees = op.getGenerateesODSOperandIndexAndLength();
+
+    unsigned opnum = use.getOperandNumber();
+    if (inSegment(generatees, opnum)) {
+      assert(!gen && "Each CLI may have at most one generator");
+      gen = &use;
+    } else if (inSegment(applyees, opnum)) {
+      assert(!cons && "Each CLI may have at most one consumer");
+      cons = &use;
+    } else {
+      llvm_unreachable("Unexpected operand for a CLI");
+    }
+  }
+
+  return {create, gen, cons};
+}
+
+/// Builds a NewCliOp whose only result type is CanonicalLoopInfoType. This
+/// builder adds no operands or attributes.
+void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
+                     ::mlir::OperationState &odsState) {
+  ::mlir::MLIRContext *context = odsBuilder.getContext();
+  auto cliType = CanonicalLoopInfoType::get(context);
+  odsState.addTypes(cliType);
+}
+
+void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
+  Value result = getResult();
+  // Only the generator determines the name; the defining op and the consumer
+  // returned by decodeCli are not needed here.
+  OpOperand *gen = std::get<1>(decodeCli(result));
+
+  // Derive the CLI variable name from its generator:
+  //  * "canonloop" for omp.canonical_loop
+  //  * custom name for loop transformation generatees
+  //  * "cli" as fallback if no generator
+  //  * "_s<idx>" suffix for each nesting level, where <idx> is the sequential
+  //    order of the (ancestor) operation among the region-holding operations
+  //    of its enclosing region
+  //  * "_r<idx>" suffix when the enclosing operation has multiple regions,
+  //    where <idx> is the index of the region containing the nested operation
+  // Suffixes are listed from the outermost level to the innermost one.
+  std::string cliName{"cli"};
+  if (gen) {
+    cliName =
+        TypeSwitch<Operation *, std::string>(gen->getOwner())
+            .Case([&](CanonicalLoopOp op) -> std::string {
+              // Index of `op` among the region-holding operations of
+              // `region`, visiting blocks in reverse post-order from the
+              // entry block. Operations without regions are not counted.
+              // NOTE(review): blocks unreachable from the entry block are not
+              // visited, so an op in such a block reaches llvm_unreachable.
+              auto getSequentialIndex = [](Region *region,
+                                           Operation *target) -> size_t {
+                llvm::ReversePostOrderTraversal<Block *> traversal(
+                    &region->getBlocks().front());
+                size_t idx = 0;
+                for (Block *block : traversal) {
+                  for (Operation &candidate : *block) {
+                    if (&candidate == target)
+                      return idx;
+                    if (!candidate.getRegions().empty())
+                      ++idx;
+                  }
+                }
+                llvm_unreachable("Operation not part of the region");
+              };
+
+              // Walk up from the loop until an isolated-from-above ancestor
+              // is reached, collecting "s<idx>" / "r<idx>" components
+              // innermost first.
+              SmallVector<std::string> components;
+              Operation *o = op.getOperation();
+              while (o) {
+                if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
+                  break;
+
+                Region *r = o->getParentRegion();
+                if (!r)
+                  break;
+
+                components.push_back(
+                    ("s" + Twine(getSequentialIndex(r, o))).str());
+
+                Operation *parent = r->getParentOp();
+                if (!parent)
+                  break;
+
+                // If the parent has more than one region, also record which
+                // of its regions contains `o`.
+                if (parent->getRegions().size() > 1)
+                  components.push_back(
+                      ("r" + Twine(r->getRegionNumber())).str());
+
+                o = parent;
+              }
+
+              // Emit the components outermost first.
+              std::string name{"canonloop"};
+              for (const std::string &component : llvm::reverse(components)) {
+                name += '_';
+                name += component;
+              }
+              return name;
+            })
+            .Case([&](UnrollHeuristicOp) -> std::string {
+              llvm_unreachable("heuristic unrolling does not generate a loop");
+            })
+            .Default([&](Operation *) -> std::string {
+              assert(!"TODO: Custom name for this operation");
+              return "transformed";
+            });
+  }
+
+  setNameFn(result, cliName);
+}
+
+LogicalResult NewCliOp::verify() {
+  Value cli = getResult();
+
+  MLIRContext *ctx = cli.getContext();
+  assert(cli.getType() == CanonicalLoopInfoType::get(ctx) &&
+         "Unexpected type of cli");
+
+  // Check that the CLI is used in at most generator and one consumer
+  OpOperand *gen = nullptr;
+  OpOperand *cons = nullptr;
+  for (mlir::OpOperand &use : cli.getUses()) {
+    auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
+    auto applyees = op.getApplyeesODSOperandIndexAndLength();
+    auto generatees = op.getGenerateesODSOperandIndexAndLength();
+
+    unsigned opnum = use.getOperandNumber();
+    if (generatees.first <= opnum &&
+        opnum < generatees.first + generatees.second) {
+      if (gen) {
+        InFlightDiagnostic error =
+            emitOpError("CLI must have at most one generator");
+        error.attachNote(gen->getOwner()->getLoc())
+            .append("first generator here:");
+        error.attachNote(use.getOwner()->getLoc())
+            .append("second generator here:");
+        return error;
+      }
+
+      gen = &use;
+    } else if (applyees.first <= opnum &&
+               opnum < applyees.first + applyees.second) {
----------------
ergawy wrote:

The suggested `isApplyee` and `isGeneratee` API would be helpful here as well.

https://github.com/llvm/llvm-project/pull/144785


More information about the Mlir-commits mailing list