[llvm-branch-commits] [clang] [flang] [llvm] [openmp] [Clang][OpenMP] Add reverse and interchange directives (PR #92030)

Alexey Bataev via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue May 14 03:28:09 PDT 2024


================
@@ -15745,6 +15760,388 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
                                     buildPreInits(Context, PreInits));
 }
 
+/// Act on an OpenMP 'reverse' loop transformation directive.
+///
+/// Validates the associated loop via checkTransformableLoopNest and, outside of
+/// dependent contexts, builds a replacement for-loop. The replacement counts a
+/// new forward induction variable from 0 up to the number of iterations and
+/// derives the reversed logical iteration number from it at the top of every
+/// iteration (see the worked example in the body).
+///
+/// \param Clauses  Clauses attached to the directive; asserted to be empty.
+/// \param AStmt    The associated statement; a null statement produces
+///                 StmtError().
+/// \param StartLoc Start location of the directive; also used as the location
+///                 of the synthesized reversal expressions.
+/// \param EndLoc   End location of the directive.
+/// \returns The created OMPReverseDirective, or StmtError() if the loop check
+///          or the construction of any synthesized expression fails.
+StmtResult
+SemaOpenMP::ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses,
+                                        Stmt *AStmt, SourceLocation StartLoc,
+                                        SourceLocation EndLoc) {
+  ASTContext &Context = getASTContext();
+  Scope *CurScope = SemaRef.getCurScope();
+  assert(Clauses.empty() && "reverse directive does not accept any clauses; "
+                            "must have beed checked before");
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  // The directive applies to exactly one loop.
+  constexpr unsigned NumLoops = 1;
+  Stmt *Body = nullptr;
+  SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers(
+      NumLoops);
+  SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_reverse, AStmt, NumLoops, LoopHelpers,
+                                  Body, OriginalInits))
+    return StmtError();
+
+  // Delay applying the transformation to when template is completely
+  // instantiated.
+  if (SemaRef.CurContext->isDependentContext())
+    return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                       AStmt, nullptr, nullptr);
+
+  assert(LoopHelpers.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  assert(OriginalInits.size() == NumLoops &&
+         "Expecting a single-dimensional loop iteration space");
+  OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front();
+
+  // Find the loop statement.
+  Stmt *LoopStmt = nullptr;
+  collectLoopStmts(AStmt, {LoopStmt});
+
+  // Determine the PreInit declarations.
+  SmallVector<Stmt *> PreInits;
+  addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits);
+
+  auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef);
+  QualType IVTy = IterationVarRef->getType();
+  uint64_t IVWidth = Context.getTypeSize(IVTy);
+  auto *OrigVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+  // Iteration variable SourceLocations.
+  SourceLocation OrigVarLoc = OrigVar->getExprLoc();
+  SourceLocation OrigVarLocBegin = OrigVar->getBeginLoc();
+  SourceLocation OrigVarLocEnd = OrigVar->getEndLoc();
+
+  // Locations pointing to the transformation.
+  SourceLocation TransformLoc = StartLoc;
+  SourceLocation TransformLocBegin = StartLoc;
+  SourceLocation TransformLocEnd = EndLoc;
+
+  // Internal variable names, derived from the user's loop counter name.
+  std::string OrigVarName = OrigVar->getNameInfo().getAsString();
+  std::string ForwardIVName = (Twine(".forward.iv.") + OrigVarName).str();
+  std::string ReversedIVName = (Twine(".reversed.iv.") + OrigVarName).str();
+
+  // LoopHelper.Updates will read the logical iteration number from
+  // LoopHelper.IterationVarRef, compute the value of the user loop counter of
+  // that logical iteration from it, then assign it to the user loop counter
+  // variable. We cannot directly use LoopHelper.IterationVarRef as the
+  // induction variable of the generated loop because it may cause an underflow:
+  // \code
+  //   for (unsigned i = 0; i < n; ++i)
+  //     body(i);
+  // \endcode
+  //
+  // Naive reversal (for an unsigned counter 'i >= 0' is always true, so
+  // decrementing past zero wraps around instead of ending the loop):
+  // \code
+  //   for (unsigned i = n-1; i >= 0; --i)
+  //     body(i);
+  // \endcode
+  //
+  // Instead, we introduce a new iteration variable representing the logical
+  // iteration counter of the original loop, convert it to the logical iteration
+  // number of the reversed loop, then let LoopHelper.Updates compute the user's
+  // loop iteration variable from it.
+  // \code
+  //   for (auto .forward.iv = 0; .forward.iv < n; ++.forward.iv) {
+  //     auto .reversed.iv = n - .forward.iv - 1;
+  //     i = (.reversed.iv + 0) * 1                 // LoopHelper.Updates
+  //     body(i);                                   // Body
+  //   }
+  // \endcode
+
+  // Subexpressions with more than one use. One of the constraints of an AST is
+  // that every node object must appear at most once, hence we define a lambda
+  // that creates a new AST node at every use.
+  CaptureVars CopyTransformer(SemaRef);
+  auto MakeNumIterations = [&CopyTransformer, &LoopHelper]() -> Expr * {
+    return AssertSuccess(
+        CopyTransformer.TransformExpr(LoopHelper.NumIterations));
+  };
+
+  // Create the iteration variable for the forward loop (from 0 to n-1).
+  VarDecl *ForwardIVDecl =
+      buildVarDecl(SemaRef, {}, IVTy, ForwardIVName, nullptr, OrigVar);
+  auto MakeForwardRef = [&SemaRef = this->SemaRef, ForwardIVDecl, IVTy,
+                         OrigVarLoc]() {
+    return buildDeclRefExpr(SemaRef, ForwardIVDecl, IVTy, OrigVarLoc);
+  };
+
+  // Iteration variable for the reversed induction variable (from n-1 downto 0):
+  // Reuse the iteration variable created by checkOpenMPLoop. It is the
+  // declaration behind LoopHelper.IterationVarRef, so renaming and initializing
+  // it here feeds the reversed value to LoopHelper.Updates.
+  auto *ReversedIVDecl = cast<VarDecl>(IterationVarRef->getDecl());
+  ReversedIVDecl->setDeclName(
+      &SemaRef.PP.getIdentifierTable().get(ReversedIVName));
+
+  // For init-statement:
+  // \code
+  //   auto .forward.iv = 0
+  // \endcode
+  IntegerLiteral *Zero =
+      IntegerLiteral::Create(Context, llvm::APInt::getZero(IVWidth),
+                             ForwardIVDecl->getType(), OrigVarLoc);
+  SemaRef.AddInitializerToDecl(ForwardIVDecl, Zero, /*DirectInit=*/false);
+  StmtResult Init = new (Context)
+      DeclStmt(DeclGroupRef(ForwardIVDecl), OrigVarLocBegin, OrigVarLocEnd);
+  if (!Init.isUsable())
+    return StmtError();
+
+  // Forward iv cond-expression:
+  // \code
+  //   .forward.iv < NumIterations
+  // \endcode
+  ExprResult Cond =
+      SemaRef.BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), BO_LT,
+                         MakeForwardRef(), MakeNumIterations());
+  if (!Cond.isUsable())
+    return StmtError();
+
+  // Forward incr-statement: ++.forward.iv
+  ExprResult Incr = SemaRef.BuildUnaryOp(CurScope, LoopHelper.Inc->getExprLoc(),
+                                         UO_PreInc, MakeForwardRef());
+  if (!Incr.isUsable())
+    return StmtError();
+
+  // Reverse the forward-iv: auto .reversed.iv = MakeNumIterations() - 1 -
+  // .forward.iv
+  // Built as two subtractions: (NumIterations - 1) first, then minus the
+  // forward counter.
+  IntegerLiteral *One = IntegerLiteral::Create(Context, llvm::APInt(IVWidth, 1),
+                                               IVTy, TransformLoc);
+  ExprResult Minus = SemaRef.BuildBinOp(CurScope, TransformLoc, BO_Sub,
+                                        MakeNumIterations(), One);
+  if (!Minus.isUsable())
+    return StmtError();
+  Minus = SemaRef.BuildBinOp(CurScope, TransformLoc, BO_Sub, Minus.get(),
+                             MakeForwardRef());
+  if (!Minus.isUsable())
+    return StmtError();
+  StmtResult InitReversed = new (Context) DeclStmt(
+      DeclGroupRef(ReversedIVDecl), TransformLocBegin, TransformLocEnd);
+  if (!InitReversed.isUsable())
+    return StmtError();
+  SemaRef.AddInitializerToDecl(ReversedIVDecl, Minus.get(),
+                               /*DirectInit=*/false);
+
+  // The new loop body: declare .reversed.iv, run the counter updates, re-emit
+  // the range-for loop variable declaration if applicable, then the original
+  // body.
+  SmallVector<Stmt *> BodyStmts;
+  BodyStmts.push_back(InitReversed.get());
+  llvm::append_range(BodyStmts, LoopHelper.Updates);
+  if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt))
+    BodyStmts.push_back(CXXRangeFor->getLoopVarStmt());
+  BodyStmts.push_back(Body);
+  auto *ReversedBody =
+      CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(),
+                           Body->getBeginLoc(), Body->getEndLoc());
+
+  // Finally create the reversed For-statement.
+  auto *ReversedFor = new (Context)
+      ForStmt(Context, Init.get(), Cond.get(), nullptr, Incr.get(),
+              ReversedBody, LoopHelper.Init->getBeginLoc(),
+              LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
+  return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                     ReversedFor,
+                                     buildPreInits(Context, PreInits));
+}
+
+StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective(
+    ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc,
+    SourceLocation EndLoc) {
+  ASTContext &Context = getASTContext();
+  DeclContext *CurContext = SemaRef.CurContext;
+  Scope *CurScope = SemaRef.getCurScope();
+
+  // Empty statement should only be possible if there already was an error.
+  if (!AStmt)
+    return StmtError();
+
+  // interchange without permutation clause swaps two loops.
+  const OMPPermutationClause *PermutationClause =
+      OMPExecutableDirective::getSingleClause<OMPPermutationClause>(Clauses);
+  size_t NumLoops = PermutationClause ? PermutationClause->getNumLoops() : 2;
+
+  // Verify and diagnose loop nest.
+  SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops);
+  Stmt *Body = nullptr;
+  SmallVector<SmallVector<Stmt *, 0>, 2> OriginalInits;
+  if (!checkTransformableLoopNest(OMPD_interchange, AStmt, NumLoops,
+                                  LoopHelpers, Body, OriginalInits))
+    return StmtError();
+
+  // Delay interchange to when template is completely instantiated.
+  if (CurContext->isDependentContext())
+    return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                           NumLoops, AStmt, nullptr, nullptr);
+
+  // An invalid expression in the permutation clause is set to nullptr in
+  // ActOnOpenMPPermutationClause.
+  if (PermutationClause && llvm::any_of(PermutationClause->getArgsRefs(),
+                                        [](Expr *E) { return !E; }))
+    return StmtError();
+
+  assert(LoopHelpers.size() == NumLoops &&
+         "Expecting loop iteration space dimensionaly to match number of "
+         "affected loops");
+  assert(OriginalInits.size() == NumLoops &&
+         "Expecting loop iteration space dimensionaly to match number of "
+         "affected loops");
+
+  // Decode the permutation clause.
+  SmallVector<uint64_t, 2> Permutation;
+  if (!PermutationClause) {
+    Permutation = {1, 0};
+  } else {
+    ArrayRef<Expr *> PermArgs = PermutationClause->getArgsRefs();
+    llvm::BitVector Flags(PermArgs.size());
+    for (Expr *PermArg : PermArgs) {
+      std::optional<llvm::APSInt> PermCstExpr =
+          PermArg->getIntegerConstantExpr(Context);
+      if (!PermCstExpr)
+        continue;
+      uint64_t PermInt = PermCstExpr->getZExtValue();
+      assert(1 <= PermInt && PermInt <= NumLoops &&
+             "Must be a permutation; diagnostic emitted in "
+             "ActOnOpenMPPermutationClause");
+      if (Flags[PermInt - 1]) {
+        SourceRange ExprRange(PermArg->getBeginLoc(), PermArg->getEndLoc());
+        Diag(PermArg->getExprLoc(),
+             diag::err_omp_interchange_permutation_value_repeated)
+            << PermInt << ExprRange;
+        continue;
+      }
+      Flags[PermInt - 1] = true;
+
+      Permutation.push_back(PermInt - 1);
+    }
+
+    if (Permutation.size() != NumLoops)
+      return StmtError();
+  }
+
+  // Nothing to transform with trivial permutation.
+  if (NumLoops <= 1 || llvm::all_of(llvm::enumerate(Permutation), [](auto p) {
+        auto [Idx, Arg] = p;
+        return Idx == Arg;
+      }))
+    return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses,
+                                           NumLoops, AStmt, AStmt, nullptr);
+
+  // Find the affected loops.
+  SmallVector<Stmt *> LoopStmts(NumLoops, nullptr);
+  collectLoopStmts(AStmt, LoopStmts);
+
+  // Collect pre-init statements in the order before the permutation.
+  SmallVector<Stmt *> PreInits;
+  for (auto I : llvm::seq<int>(NumLoops)) {
+    OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I];
+
+    assert(LoopHelper.Counters.size() == 1 &&
+           "Single-dimensional loop iteration space expected");
+    auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front());
+
+    std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+    addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I],
+                    PreInits);
+  }
+
+  SmallVector<VarDecl *> PermutedIndVars;
+  PermutedIndVars.resize(NumLoops);
+  CaptureVars CopyTransformer(SemaRef);
+
+  // Create the permuted loops from the inside to the outside of the
+  // interchanged loop nest. Body of the innermost new loop is the original
+  // innermost body.
+  Stmt *Inner = Body;
+  for (auto TargetIdx : llvm::reverse(llvm::seq<int>(NumLoops))) {
+    // Get the original loop that belongs to this new position.
+    uint64_t SourceIdx = Permutation[TargetIdx];
+    OMPLoopBasedDirective::HelperExprs &SourceHelper = LoopHelpers[SourceIdx];
+    Stmt *SourceLoopStmt = LoopStmts[SourceIdx];
+    assert(SourceHelper.Counters.size() == 1 &&
+           "Single-dimensional loop iteration space expected");
+    auto *OrigCntVar = cast<DeclRefExpr>(SourceHelper.Counters.front());
+
+    // Normalized loop counter variable: From 0 to n-1, always an integer type.
+    DeclRefExpr *IterVarRef = cast<DeclRefExpr>(SourceHelper.IterationVarRef);
+    QualType IVTy = IterVarRef->getType();
+    assert(IVTy->isIntegerType() &&
+           "Expected the logical iteration counter to be an integer");
+
+    std::string OrigVarName = OrigCntVar->getNameInfo().getAsString();
+    SourceLocation OrigVarLoc = IterVarRef->getExprLoc();
+
+    // Make a copy of the NumIterations expression for each use: By the AST
+    // constraints, every expression object in a DeclContext must be unique.
+    auto MakeNumIterations = [&CopyTransformer, &SourceHelper]() -> Expr * {
+      return AssertSuccess(
+          CopyTransformer.TransformExpr(SourceHelper.NumIterations));
+    };
+
+    // Iteration variable for the permuted loop. Reuse the one from
+    // checkOpenMPLoop which will also be used to update the original loop
+    // variable.
+    std::string PermutedCntName =
+        (Twine(".permuted_") + llvm::utostr(TargetIdx) + ".iv." + OrigVarName)
+            .str();
+    auto *PermutedCntDecl = cast<VarDecl>(IterVarRef->getDecl());
+    PermutedCntDecl->setDeclName(
+        &SemaRef.PP.getIdentifierTable().get(PermutedCntName));
+    PermutedIndVars[TargetIdx] = PermutedCntDecl;
+    auto MakePermutedRef = [this, PermutedCntDecl, IVTy, OrigVarLoc]() {
+      return buildDeclRefExpr(SemaRef, PermutedCntDecl, IVTy, OrigVarLoc);
+    };
+
+    // For init-statement:
+    // \code{c}
+    //   auto .permuted_{target}.iv = 0
+    // \endcode
+    ExprResult Zero = SemaRef.ActOnIntegerConstant(OrigVarLoc, 0);
+    if (!Zero.isUsable())
+      return StmtError();
+    SemaRef.AddInitializerToDecl(PermutedCntDecl, Zero.get(),
+                                 /*DirectInit=*/false);
+    StmtResult InitStmt = new (Context)
+        DeclStmt(DeclGroupRef(PermutedCntDecl), OrigCntVar->getBeginLoc(),
+                 OrigCntVar->getEndLoc());
+    if (!InitStmt.isUsable())
+      return StmtError();
+
+    // For cond-expression:
+    // \code{c}
+    //   .permuted_{target}.iv < NumIterations
+    // \endcode
+    ExprResult CondExpr =
+        SemaRef.BuildBinOp(CurScope, SourceHelper.Cond->getExprLoc(), BO_LT,
+                           MakePermutedRef(), MakeNumIterations());
+    if (!CondExpr.isUsable())
+      return StmtError();
+
+    // For incr-statement:
+    // \code{c}
+    //   ++.permuted_{target}.iv
+    // \endcode
+    ExprResult IncrStmt = SemaRef.BuildUnaryOp(
+        CurScope, SourceHelper.Inc->getExprLoc(), UO_PreInc, MakePermutedRef());
+    if (!IncrStmt.isUsable())
+      return StmtError();
+
+    SmallVector<Stmt *, 4> BodyParts;
+    llvm::append_range(BodyParts, SourceHelper.Updates);
----------------
alexey-bataev wrote:

```suggestion
    SmallVector<Stmt *, 4> BodyParts(SourceHelper.Updates.begin(), SourceHelper.Updates.end());
```

https://github.com/llvm/llvm-project/pull/92030


More information about the llvm-branch-commits mailing list