[polly] [Polly] Use separate DT/LI/SE for outlined subfn. NFC. (PR #102460)

Michael Kruse via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 8 08:13:38 PDT 2024


https://github.com/Meinersbur updated https://github.com/llvm/llvm-project/pull/102460

>From 9375abaeff0e34d570556c43a37377ead2e2fd78 Mon Sep 17 00:00:00 2001
From: Michael Kruse <llvm-project at meinersbur.de>
Date: Thu, 8 Aug 2024 14:46:53 +0200
Subject: [PATCH 1/2] [Polly] Use separate DT/LI/SE for outlined subfn

---
 polly/include/polly/CodeGen/BlockGenerators.h |  17 +-
 polly/include/polly/CodeGen/IslExprBuilder.h  |  18 +-
 polly/include/polly/CodeGen/IslNodeBuilder.h  |  21 +-
 polly/include/polly/CodeGen/LoopGenerators.h  |  24 +-
 .../polly/CodeGen/LoopGeneratorsGOMP.h        |   5 +-
 .../include/polly/CodeGen/LoopGeneratorsKMP.h |   5 +-
 polly/include/polly/Support/ScopHelper.h      |  11 +-
 polly/lib/CodeGen/BlockGenerators.cpp         |  27 +--
 polly/lib/CodeGen/IslExprBuilder.cpp          |  23 +-
 polly/lib/CodeGen/IslNodeBuilder.cpp          | 214 ++++++++++--------
 polly/lib/CodeGen/LoopGeneratorsGOMP.cpp      |  22 +-
 polly/lib/CodeGen/LoopGeneratorsKMP.cpp       |  22 +-
 polly/lib/Support/ScopHelper.cpp              | 148 ++++++++----
 13 files changed, 336 insertions(+), 221 deletions(-)

diff --git a/polly/include/polly/CodeGen/BlockGenerators.h b/polly/include/polly/CodeGen/BlockGenerators.h
index 13c27328d8c7e1..074426c8ccbda5 100644
--- a/polly/include/polly/CodeGen/BlockGenerators.h
+++ b/polly/include/polly/CodeGen/BlockGenerators.h
@@ -162,8 +162,21 @@ class BlockGenerator {
   /// The dominator tree of this function.
   DominatorTree &DT;
 
-  /// The entry block of the current function.
-  BasicBlock *EntryBB;
+  /// Relates to the region where the code is emitted into.
+  /// @{
+  DominatorTree *GenDT;
+  LoopInfo *GenLI;
+  ScalarEvolution *GenSE;
+  /// @}
+
+public:
+  /// Change the function that code is emitted into.
+  void switchGeneratedFunc(Function *GenFn, DominatorTree *GenDT,
+                           LoopInfo *GenLI, ScalarEvolution *GenSE) {
+    this->GenDT = GenDT;
+    this->GenLI = GenLI;
+    this->GenSE = GenSE;
+  }
 
   /// Map to resolve scalar dependences for PHI operands and scalars.
   ///
diff --git a/polly/include/polly/CodeGen/IslExprBuilder.h b/polly/include/polly/CodeGen/IslExprBuilder.h
index 6842aaa456ac27..6a6d644ee2439a 100644
--- a/polly/include/polly/CodeGen/IslExprBuilder.h
+++ b/polly/include/polly/CodeGen/IslExprBuilder.h
@@ -124,6 +124,15 @@ class IslExprBuilder final {
                  llvm::ScalarEvolution &SE, llvm::DominatorTree &DT,
                  llvm::LoopInfo &LI, llvm::BasicBlock *StartBlock);
 
+  /// Change the function that code is emitted into.
+  void switchGeneratedFunc(llvm::Function *GenFn, llvm::DominatorTree *GenDT,
+                           llvm::LoopInfo *GenLI,
+                           llvm::ScalarEvolution *GenSE) {
+    this->GenDT = GenDT;
+    this->GenLI = GenLI;
+    this->GenSE = GenSE;
+  }
+
   /// Create LLVM-IR for an isl_ast_expr[ession].
   ///
   /// @param Expr The ast expression for which we generate LLVM-IR.
@@ -205,10 +214,15 @@ class IslExprBuilder final {
 
   const llvm::DataLayout &DL;
   llvm::ScalarEvolution &SE;
-  llvm::DominatorTree &DT;
-  llvm::LoopInfo &LI;
   llvm::BasicBlock *StartBlock;
 
+  /// Relates to the region where the code is emitted into.
+  /// @{
+  llvm::DominatorTree *GenDT;
+  llvm::LoopInfo *GenLI;
+  llvm::ScalarEvolution *GenSE;
+  /// @}
+
   llvm::Value *createOp(__isl_take isl_ast_expr *Expr);
   llvm::Value *createOpUnary(__isl_take isl_ast_expr *Expr);
   llvm::Value *createOpAccess(__isl_take isl_ast_expr *Expr);
diff --git a/polly/include/polly/CodeGen/IslNodeBuilder.h b/polly/include/polly/CodeGen/IslNodeBuilder.h
index 05f53d79d74a4f..81343af94b636e 100644
--- a/polly/include/polly/CodeGen/IslNodeBuilder.h
+++ b/polly/include/polly/CodeGen/IslNodeBuilder.h
@@ -72,7 +72,7 @@ class IslNodeBuilder {
         BlockGen(Builder, LI, SE, DT, ScalarMap, EscapeMap, ValueMap,
                  &ExprBuilder, StartBlock),
         RegionGen(BlockGen), DL(DL), LI(LI), SE(SE), DT(DT),
-        StartBlock(StartBlock) {}
+        StartBlock(StartBlock), GenDT(&DT), GenLI(&LI), GenSE(&SE) {}
 
   virtual ~IslNodeBuilder() = default;
 
@@ -147,6 +147,13 @@ class IslNodeBuilder {
   DominatorTree &DT;
   BasicBlock *StartBlock;
 
+  /// Relates to the region where the code is emitted into.
+  /// @{
+  DominatorTree *GenDT;
+  LoopInfo *GenLI;
+  ScalarEvolution *GenSE;
+  /// @}
+
   /// The current iteration of out-of-scop loops
   ///
   /// This map provides for a given loop a llvm::Value that contains the current
@@ -246,18 +253,6 @@ class IslNodeBuilder {
                               SetVector<Value *> &Values,
                               SetVector<const Loop *> &Loops);
 
-  /// Change the llvm::Value(s) used for code generation.
-  ///
-  /// When generating code certain values (e.g., references to induction
-  /// variables or array base pointers) in the original code may be replaced by
-  /// new values. This function allows to (partially) update the set of values
-  /// used. A typical use case for this function is the case when we continue
-  /// code generation in a subfunction/kernel function and need to explicitly
-  /// pass down certain values.
-  ///
-  /// @param NewValues A map that maps certain llvm::Values to new llvm::Values.
-  void updateValues(ValueMapT &NewValues);
-
   /// Return the most up-to-date version of the llvm::Value for code generation.
   /// @param Original The Value to check for an up to date version.
   /// @returns A remapped `Value` from ValueMap, or `Original` if no mapping
diff --git a/polly/include/polly/CodeGen/LoopGenerators.h b/polly/include/polly/CodeGen/LoopGenerators.h
index 8ec75e69890c95..6076e5951fb0a5 100644
--- a/polly/include/polly/CodeGen/LoopGenerators.h
+++ b/polly/include/polly/CodeGen/LoopGenerators.h
@@ -55,7 +55,7 @@ extern int PollyChunkSize;
 /// @param Builder            The builder used to create the loop.
 /// @param P                  A pointer to the pass that uses this function.
 ///                           It is used to update analysis information.
-/// @param LI                 The loop info for the current function
+/// @param LI                 The loop info we need to update
 /// @param DT                 The dominator tree we need to update
 /// @param ExitBlock          The block the loop will exit to.
 /// @param Predicate          The predicate used to generate the upper loop
@@ -128,11 +128,9 @@ llvm::DebugLoc createDebugLocForGeneratedCode(Function *F);
 class ParallelLoopGenerator {
 public:
   /// Create a parallel loop generator for the current function.
-  ParallelLoopGenerator(PollyIRBuilder &Builder, LoopInfo &LI,
-                        DominatorTree &DT, const DataLayout &DL)
-      : Builder(Builder), LI(LI), DT(DT),
-        LongType(
-            Type::getIntNTy(Builder.getContext(), DL.getPointerSizeInBits())),
+  ParallelLoopGenerator(PollyIRBuilder &Builder, const DataLayout &DL)
+      : Builder(Builder), LongType(Type::getIntNTy(Builder.getContext(),
+                                                   DL.getPointerSizeInBits())),
         M(Builder.GetInsertBlock()->getParent()->getParent()),
         DLGenerated(createDebugLocForGeneratedCode(
             Builder.GetInsertBlock()->getParent())) {}
@@ -164,11 +162,11 @@ class ParallelLoopGenerator {
   /// The IR builder we use to create instructions.
   PollyIRBuilder &Builder;
 
-  /// The loop info of the current function we need to update.
-  LoopInfo &LI;
+  /// The loop info for the generated subfunction.
+  std::unique_ptr<LoopInfo> SubFnLI;
 
-  /// The dominance tree of the current function we need to update.
-  DominatorTree &DT;
+  /// The dominance tree for the generated subfunction.
+  std::unique_ptr<DominatorTree> SubFnDT;
 
   /// The type of a "long" on this hardware used for backend calls.
   Type *LongType;
@@ -184,6 +182,12 @@ class ParallelLoopGenerator {
   llvm::DebugLoc DLGenerated;
 
 public:
+  /// Returns the DominatorTree for the generated subfunction.
+  DominatorTree *getCalleeDominatorTree() const { return SubFnDT.get(); }
+
+  /// Returns the LoopInfo for the generated subfunction.
+  LoopInfo *getCalleeLoopInfo() const { return SubFnLI.get(); }
+
   /// Create a struct for all @p Values and store them in there.
   ///
   /// @param Values The values which should be stored in the struct.
diff --git a/polly/include/polly/CodeGen/LoopGeneratorsGOMP.h b/polly/include/polly/CodeGen/LoopGeneratorsGOMP.h
index 4cc4f394f36319..1bf6e6e2cbbb78 100644
--- a/polly/include/polly/CodeGen/LoopGeneratorsGOMP.h
+++ b/polly/include/polly/CodeGen/LoopGeneratorsGOMP.h
@@ -25,9 +25,8 @@ namespace polly {
 class ParallelLoopGeneratorGOMP final : public ParallelLoopGenerator {
 public:
   /// Create a parallel loop generator for the current function.
-  ParallelLoopGeneratorGOMP(PollyIRBuilder &Builder, LoopInfo &LI,
-                            DominatorTree &DT, const DataLayout &DL)
-      : ParallelLoopGenerator(Builder, LI, DT, DL) {}
+  ParallelLoopGeneratorGOMP(PollyIRBuilder &Builder, const DataLayout &DL)
+      : ParallelLoopGenerator(Builder, DL) {}
 
   // The functions below may be used if one does not want to generate a
   // specific OpenMP parallel loop, but generate individual parts of it
diff --git a/polly/include/polly/CodeGen/LoopGeneratorsKMP.h b/polly/include/polly/CodeGen/LoopGeneratorsKMP.h
index 245a63c7bae50f..f134857449c2df 100644
--- a/polly/include/polly/CodeGen/LoopGeneratorsKMP.h
+++ b/polly/include/polly/CodeGen/LoopGeneratorsKMP.h
@@ -27,9 +27,8 @@ using llvm::GlobalVariable;
 class ParallelLoopGeneratorKMP final : public ParallelLoopGenerator {
 public:
   /// Create a parallel loop generator for the current function.
-  ParallelLoopGeneratorKMP(PollyIRBuilder &Builder, LoopInfo &LI,
-                           DominatorTree &DT, const DataLayout &DL)
-      : ParallelLoopGenerator(Builder, LI, DT, DL) {
+  ParallelLoopGeneratorKMP(PollyIRBuilder &Builder, const DataLayout &DL)
+      : ParallelLoopGenerator(Builder, DL) {
     SourceLocationInfo = createSourceLocation();
   }
 
diff --git a/polly/include/polly/Support/ScopHelper.h b/polly/include/polly/Support/ScopHelper.h
index 17480c5381c516..13852ecb18ee7c 100644
--- a/polly/include/polly/Support/ScopHelper.h
+++ b/polly/include/polly/Support/ScopHelper.h
@@ -36,6 +36,9 @@ namespace polly {
 class Scop;
 class ScopStmt;
 
+/// Same as llvm/Analysis/ScalarEvolutionExpressions.h
+using LoopToScevMapT = llvm::DenseMap<const llvm::Loop *, const llvm::SCEV *>;
+
 /// Enumeration of assumptions Polly can take.
 enum AssumptionKind {
   ALIASING,
@@ -383,20 +386,24 @@ void splitEntryBlockForAlloca(llvm::BasicBlock *EntryBlock,
 /// as the call to SCEVExpander::expandCodeFor:
 ///
 /// @param S     The current Scop.
-/// @param SE    The Scalar Evolution pass.
+/// @param SE    The Scalar Evolution pass used by @p S.
+/// @param GenFn The function to generate code in. Can be the function of @p SE.
+/// @param GenSE The Scalar Evolution pass for @p GenFn.
 /// @param DL    The module data layout.
 /// @param Name  The suffix added to the new instruction names.
 /// @param E     The expression for which code is actually generated.
 /// @param Ty    The type of the resulting code.
 /// @param IP    The insertion point for the new code.
 /// @param VMap  A remapping of values used in @p E.
+/// @param LoopMap A remapping of loops used in @p E.
 /// @param RTCBB The last block of the RTC. Used to insert loop-invariant
 ///              instructions in rare cases.
 llvm::Value *expandCodeFor(Scop &S, llvm::ScalarEvolution &SE,
+                           llvm::Function *GenFn, llvm::ScalarEvolution &GenSE,
                            const llvm::DataLayout &DL, const char *Name,
                            const llvm::SCEV *E, llvm::Type *Ty,
                            llvm::Instruction *IP, ValueMapT *VMap,
-                           llvm::BasicBlock *RTCBB);
+                           LoopToScevMapT *LoopMap, llvm::BasicBlock *RTCBB);
 
 /// Return the condition for the terminator @p TI.
 ///
diff --git a/polly/lib/CodeGen/BlockGenerators.cpp b/polly/lib/CodeGen/BlockGenerators.cpp
index f7c777b4e80118..c7e1b21286b443 100644
--- a/polly/lib/CodeGen/BlockGenerators.cpp
+++ b/polly/lib/CodeGen/BlockGenerators.cpp
@@ -57,8 +57,8 @@ BlockGenerator::BlockGenerator(
     PollyIRBuilder &B, LoopInfo &LI, ScalarEvolution &SE, DominatorTree &DT,
     AllocaMapTy &ScalarMap, EscapeUsersAllocaMapTy &EscapeMap,
     ValueMapT &GlobalMap, IslExprBuilder *ExprBuilder, BasicBlock *StartBlock)
-    : Builder(B), LI(LI), SE(SE), ExprBuilder(ExprBuilder), DT(DT),
-      EntryBB(nullptr), ScalarMap(ScalarMap), EscapeMap(EscapeMap),
+    : Builder(B), LI(LI), SE(SE), ExprBuilder(ExprBuilder), DT(DT), GenDT(&DT),
+      GenLI(&LI), GenSE(&SE), ScalarMap(ScalarMap), EscapeMap(EscapeMap),
       GlobalMap(GlobalMap), StartBlock(StartBlock) {}
 
 Value *BlockGenerator::trySynthesizeNewValue(ScopStmt &Stmt, Value *Old,
@@ -75,7 +75,6 @@ Value *BlockGenerator::trySynthesizeNewValue(ScopStmt &Stmt, Value *Old,
   if (isa<SCEVCouldNotCompute>(Scev))
     return nullptr;
 
-  const SCEV *NewScev = SCEVLoopAddRecRewriter::rewrite(Scev, LTS, SE);
   ValueMapT VTV;
   VTV.insert(BBMap.begin(), BBMap.end());
   VTV.insert(GlobalMap.begin(), GlobalMap.end());
@@ -86,9 +85,9 @@ Value *BlockGenerator::trySynthesizeNewValue(ScopStmt &Stmt, Value *Old,
 
   assert(IP != Builder.GetInsertBlock()->end() &&
          "Only instructions can be insert points for SCEVExpander");
-  Value *Expanded =
-      expandCodeFor(S, SE, DL, "polly", NewScev, Old->getType(), &*IP, &VTV,
-                    StartBlock->getSinglePredecessor());
+  Value *Expanded = expandCodeFor(
+      S, SE, Builder.GetInsertBlock()->getParent(), *GenSE, DL, "polly", Scev,
+      Old->getType(), &*IP, &VTV, &LTS, StartBlock->getSinglePredecessor());
 
   BBMap[Old] = Expanded;
   return Expanded;
@@ -233,6 +232,8 @@ void BlockGenerator::copyInstScalar(ScopStmt &Stmt, Instruction *Inst,
       return;
     }
 
+    // FIXME: We will encounter "NewOperand" again if used twice. getNewValue()
+    // is meant to be called on old values only.
     NewInst->replaceUsesOfWith(OldOperand, NewOperand);
   }
 
@@ -410,7 +411,7 @@ void BlockGenerator::copyStmt(ScopStmt &Stmt, LoopToScevMapT &LTS,
 
 BasicBlock *BlockGenerator::splitBB(BasicBlock *BB) {
   BasicBlock *CopyBB = SplitBlock(Builder.GetInsertBlock(),
-                                  &*Builder.GetInsertPoint(), &DT, &LI);
+                                  &*Builder.GetInsertPoint(), GenDT, GenLI);
   CopyBB->setName("polly.stmt." + BB->getName());
   return CopyBB;
 }
@@ -434,8 +435,6 @@ BasicBlock *BlockGenerator::copyBB(ScopStmt &Stmt, BasicBlock *BB,
 void BlockGenerator::copyBB(ScopStmt &Stmt, BasicBlock *BB, BasicBlock *CopyBB,
                             ValueMapT &BBMap, LoopToScevMapT &LTS,
                             isl_id_to_ast_expr *NewAccesses) {
-  EntryBB = &CopyBB->getParent()->getEntryBlock();
-
   // Block statements and the entry blocks of region statement are code
   // generated from instruction lists. This allow us to optimize the
   // instructions that belong to a certain scop statement. As the code
@@ -497,7 +496,7 @@ Value *BlockGenerator::getOrCreateAlloca(const ScopArrayInfo *Array) {
   Addr =
       new AllocaInst(Ty, DL.getAllocaAddrSpace(), nullptr,
                      DL.getPrefTypeAlign(Ty), ScalarBase->getName() + NameExt);
-  EntryBB = &Builder.GetInsertBlock()->getParent()->getEntryBlock();
+  BasicBlock *EntryBB = &Builder.GetInsertBlock()->getParent()->getEntryBlock();
   Addr->insertBefore(&*EntryBB->getFirstInsertionPt());
 
   return Addr;
@@ -554,10 +553,6 @@ void BlockGenerator::generateScalarLoads(
 
     auto *Address =
         getImplicitAddress(*MA, getLoopForStmt(Stmt), LTS, BBMap, NewAccesses);
-    assert((!isa<Instruction>(Address) ||
-            DT.dominates(cast<Instruction>(Address)->getParent(),
-                         Builder.GetInsertBlock())) &&
-           "Domination violation");
     BBMap[MA->getAccessValue()] = Builder.CreateLoad(
         MA->getElementType(), Address, Address->getName() + ".reload");
   }
@@ -615,9 +610,9 @@ void BlockGenerator::generateConditionalExecution(
   StringRef BlockName = HeadBlock->getName();
 
   // Generate the conditional block.
-  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+  DomTreeUpdater DTU(GenDT, DomTreeUpdater::UpdateStrategy::Eager);
   SplitBlockAndInsertIfThen(Cond, &*Builder.GetInsertPoint(), false, nullptr,
-                            &DTU, &LI);
+                            &DTU, GenLI);
   BranchInst *Branch = cast<BranchInst>(HeadBlock->getTerminator());
   BasicBlock *ThenBlock = Branch->getSuccessor(0);
   BasicBlock *TailBlock = Branch->getSuccessor(1);
diff --git a/polly/lib/CodeGen/IslExprBuilder.cpp b/polly/lib/CodeGen/IslExprBuilder.cpp
index f40511e0273a26..d573daee87a153 100644
--- a/polly/lib/CodeGen/IslExprBuilder.cpp
+++ b/polly/lib/CodeGen/IslExprBuilder.cpp
@@ -42,7 +42,8 @@ IslExprBuilder::IslExprBuilder(Scop &S, PollyIRBuilder &Builder,
                                DominatorTree &DT, LoopInfo &LI,
                                BasicBlock *StartBlock)
     : S(S), Builder(Builder), IDToValue(IDToValue), GlobalMap(GlobalMap),
-      DL(DL), SE(SE), DT(DT), LI(LI), StartBlock(StartBlock) {
+      DL(DL), SE(SE), StartBlock(StartBlock), GenDT(&DT), GenLI(&LI),
+      GenSE(&SE) {
   OverflowState = (OTMode == OT_ALWAYS) ? Builder.getFalse() : nullptr;
 }
 
@@ -307,14 +308,12 @@ IslExprBuilder::createAccessAddress(__isl_take isl_ast_expr *Expr) {
 
     const SCEV *DimSCEV = SAI->getDimensionSize(u);
 
-    llvm::ValueToSCEVMapTy Map;
-    for (auto &KV : GlobalMap)
-      Map[KV.first] = SE.getSCEV(KV.second);
-    DimSCEV = SCEVParameterRewriter::rewrite(DimSCEV, SE, Map);
-    Value *DimSize =
-        expandCodeFor(S, SE, DL, "polly", DimSCEV, DimSCEV->getType(),
-                      &*Builder.GetInsertPoint(), nullptr,
-                      StartBlock->getSinglePredecessor());
+    // DimSize should be invariant to the SCoP, so no BBMap nor LoopToScev
+    // needed. But GlobalMap may contain SCoP-invariant vars.
+    Value *DimSize = expandCodeFor(
+        S, SE, Builder.GetInsertBlock()->getParent(), *GenSE, DL, "polly",
+        DimSCEV, DimSCEV->getType(), &*Builder.GetInsertPoint(), &GlobalMap,
+        /*LoopMap*/ nullptr, StartBlock->getSinglePredecessor());
 
     Type *Ty = getWidestType(DimSize->getType(), IndexOp->getType());
 
@@ -602,10 +601,10 @@ IslExprBuilder::createOpBooleanConditional(__isl_take isl_ast_expr *Expr) {
 
   auto InsertBB = Builder.GetInsertBlock();
   auto InsertPoint = Builder.GetInsertPoint();
-  auto NextBB = SplitBlock(InsertBB, &*InsertPoint, &DT, &LI);
+  auto NextBB = SplitBlock(InsertBB, &*InsertPoint, GenDT, GenLI);
   BasicBlock *CondBB = BasicBlock::Create(Context, "polly.cond", F);
-  LI.changeLoopFor(CondBB, LI.getLoopFor(InsertBB));
-  DT.addNewBlock(CondBB, InsertBB);
+  GenLI->changeLoopFor(CondBB, GenLI->getLoopFor(InsertBB));
+  GenDT->addNewBlock(CondBB, InsertBB);
 
   InsertBB->getTerminator()->eraseFromParent();
   Builder.SetInsertPoint(InsertBB);
diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index 8b2207ecbf362c..1ee550a612f422 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -30,10 +30,12 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/RegionInfo.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -44,11 +46,13 @@
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "isl/aff.h"
 #include "isl/aff_type.h"
@@ -366,22 +370,6 @@ void IslNodeBuilder::getReferencesInSubtree(const isl::ast_node &For,
   Values = ReplacedValues;
 }
 
-void IslNodeBuilder::updateValues(ValueMapT &NewValues) {
-  SmallPtrSet<Value *, 5> Inserted;
-
-  for (const auto &I : IDToValue) {
-    IDToValue[I.first] = NewValues[I.second];
-    Inserted.insert(I.second);
-  }
-
-  for (const auto &I : NewValues) {
-    if (Inserted.count(I.first))
-      continue;
-
-    ValueMap[I.first] = I.second;
-  }
-}
-
 Value *IslNodeBuilder::getLatestValue(Value *Original) const {
   auto It = ValueMap.find(Original);
   if (It == ValueMap.end())
@@ -488,10 +476,10 @@ void IslNodeBuilder::createForSequential(isl::ast_node_for For,
 
   // If we can show that LB <Predicate> UB holds at least once, we can
   // omit the GuardBB in front of the loop.
-  bool UseGuardBB =
-      !SE.isKnownPredicate(Predicate, SE.getSCEV(ValueLB), SE.getSCEV(ValueUB));
-  IV = createLoop(ValueLB, ValueUB, ValueInc, Builder, LI, DT, ExitBlock,
-                  Predicate, &Annotator, MarkParallel, UseGuardBB,
+  bool UseGuardBB = !GenSE->isKnownPredicate(Predicate, GenSE->getSCEV(ValueLB),
+                                             GenSE->getSCEV(ValueUB));
+  IV = createLoop(ValueLB, ValueUB, ValueInc, Builder, *GenLI, *GenDT,
+                  ExitBlock, Predicate, &Annotator, MarkParallel, UseGuardBB,
                   LoopVectorizerDisabled);
   IDToValue[IteratorID.get()] = IV;
 
@@ -506,50 +494,6 @@ void IslNodeBuilder::createForSequential(isl::ast_node_for For,
   SequentialLoops++;
 }
 
-/// Remove the BBs contained in a (sub)function from the dominator tree.
-///
-/// This function removes the basic blocks that are part of a subfunction from
-/// the dominator tree. Specifically, when generating code it may happen that at
-/// some point the code generation continues in a new sub-function (e.g., when
-/// generating OpenMP code). The basic blocks that are created in this
-/// sub-function are then still part of the dominator tree of the original
-/// function, such that the dominator tree reaches over function boundaries.
-/// This is not only incorrect, but also causes crashes. This function now
-/// removes from the dominator tree all basic blocks that are dominated (and
-/// consequently reachable) from the entry block of this (sub)function.
-///
-/// FIXME: A LLVM (function or region) pass should not touch anything outside of
-/// the function/region it runs on. Hence, the pure need for this function shows
-/// that we do not comply to this rule. At the moment, this does not cause any
-/// issues, but we should be aware that such issues may appear. Unfortunately
-/// the current LLVM pass infrastructure does not allow to make Polly a module
-/// or call-graph pass to solve this issue, as such a pass would not have access
-/// to the per-function analyses passes needed by Polly. A future pass manager
-/// infrastructure is supposed to enable such kind of access possibly allowing
-/// us to create a cleaner solution here.
-///
-/// FIXME: Instead of adding the dominance information and then dropping it
-/// later on, we should try to just not add it in the first place. This requires
-/// some careful testing to make sure this does not break in interaction with
-/// the SCEVBuilder and SplitBlock which may rely on the dominator tree or
-/// which may try to update it.
-///
-/// @param F The function which contains the BBs to removed.
-/// @param DT The dominator tree from which to remove the BBs.
-static void removeSubFuncFromDomTree(Function *F, DominatorTree &DT) {
-  DomTreeNode *N = DT.getNode(&F->getEntryBlock());
-  std::vector<BasicBlock *> Nodes;
-
-  // We can only remove an element from the dominator tree, if all its children
-  // have been removed. To ensure this we obtain the list of nodes to remove
-  // using a post-order tree traversal.
-  for (po_iterator<DomTreeNode *> I = po_begin(N), E = po_end(N); I != E; ++I)
-    Nodes.push_back(I->getBlock());
-
-  for (BasicBlock *BB : Nodes)
-    DT.eraseNode(BB);
-}
-
 void IslNodeBuilder::createForParallel(__isl_take isl_ast_node *For) {
   isl_ast_node *Body;
   isl_ast_expr *Init, *Inc, *Iterator, *UB;
@@ -619,31 +563,107 @@ void IslNodeBuilder::createForParallel(__isl_take isl_ast_node *For) {
 
   switch (PollyOmpBackend) {
   case OpenMPBackend::GNU:
-    ParallelLoopGenPtr.reset(
-        new ParallelLoopGeneratorGOMP(Builder, LI, DT, DL));
+    ParallelLoopGenPtr.reset(new ParallelLoopGeneratorGOMP(Builder, DL));
     break;
   case OpenMPBackend::LLVM:
-    ParallelLoopGenPtr.reset(new ParallelLoopGeneratorKMP(Builder, LI, DT, DL));
+    ParallelLoopGenPtr.reset(new ParallelLoopGeneratorKMP(Builder, DL));
     break;
   }
 
   IV = ParallelLoopGenPtr->createParallelLoop(
       ValueLB, ValueUB, ValueInc, SubtreeValues, NewValues, &LoopBody);
   BasicBlock::iterator AfterLoop = Builder.GetInsertPoint();
-  Builder.SetInsertPoint(&*LoopBody);
 
   // Remember the parallel subfunction
-  ParallelSubfunctions.push_back(LoopBody->getFunction());
-
-  // Save the current values.
-  auto ValueMapCopy = ValueMap;
+  Function *SubFn = LoopBody->getFunction();
+  ParallelSubfunctions.push_back(SubFn);
+
+  // We start working on the outlined function. Since DominatorTree/LoopInfo are
+  // not inter-procedural passes, we temporarily switch them out. Save the
+  // old ones first.
+  Function *CallerFn = Builder.GetInsertBlock()->getParent();
+  DominatorTree *CallerDT = GenDT;
+  LoopInfo *CallerLI = GenLI;
+  ScalarEvolution *CallerSE = GenSE;
+  ValueMapT CallerGlobals = ValueMap;
   IslExprBuilder::IDToValueTy IDToValueCopy = IDToValue;
 
-  updateValues(NewValues);
+  // Get the analyses for the subfunction. ParallelLoopGenerator already
+  // created DominatorTree and LoopInfo for us.
+  DominatorTree *SubDT = ParallelLoopGenPtr->getCalleeDominatorTree();
+  LoopInfo *SubLI = ParallelLoopGenPtr->getCalleeLoopInfo();
+
+  // Create TargetLibraryInfo, AssumptionCache, and ScalarEvolution ourselves.
+  // TODO: Ideally, we would use the pass manager's TargetLibraryInfoPass and
+  // AssumptionAnalysis instead of our own. They contain more target-specific
+  // information than we have available here: TargetLibraryInfoImpl can be a
+  // derived class determined by TargetMachine; AssumptionCache can be
+  // configured using a TargetTransformInfo also derived from TargetMachine.
+  TargetLibraryInfoImpl BaselineInfoImpl(
+      Triple(SubFn->getParent()->getTargetTriple()));
+  TargetLibraryInfo CalleeTLI(BaselineInfoImpl, SubFn);
+  AssumptionCache CalleeAC(*SubFn);
+  std::unique_ptr<ScalarEvolution> SubSE = std::make_unique<ScalarEvolution>(
+      *SubFn, CalleeTLI, CalleeAC, *SubDT, *SubLI);
+
+  // Switch to the subfunction
+  GenDT = SubDT;
+  GenLI = SubLI;
+  GenSE = SubSE.get();
+  BlockGen.switchGeneratedFunc(SubFn, GenDT, GenLI, GenSE);
+  ExprBuilder.switchGeneratedFunc(SubFn, GenDT, GenLI, GenSE);
+  Builder.SetInsertPoint(&*LoopBody);
+
+  // Update the ValueMap to use instructions in the subfunction. Note that
+  // "GlobalMap" used in BlockGenerator/IslExprBuilder is a reference to this
+  // ValueMap.
+  for (auto &[OldVal, NewVal] : ValueMap) {
+    NewVal = NewValues.lookup(NewVal);
+
+    // Clean-up any value that getReferencesInSubtree thinks we do not need.
+    // DenseMap::erase only writes a tombstone (and destroys OldVal/NewVal), so
+    // does not invalidate our iterator.
+    if (!NewVal)
+      ValueMap.erase(OldVal);
+  }
+
+  // This is for NewVals that do not appear in ValueMap (such as SCoP-invariant
+  // values whose original value can be reused as long as we are in the same
+  // function). No need to map the others.
+  for (auto &[NewVal, NewNewVal] : NewValues) {
+    if (Instruction *NewValInst = dyn_cast<Instruction>((Value *)NewVal)) {
+      if (S.contains(NewValInst))
+        continue;
+      assert(NewValInst->getFunction() == &S.getFunction());
+    }
+    assert(!ValueMap.contains(NewVal));
+    ValueMap[NewVal] = NewNewVal;
+  }
+
+  // Also update the IDToValue map to use instructions from the subfunction.
+  for (auto &[OldVal, NewVal] : IDToValue) {
+    NewVal = NewValues.lookup(NewVal);
+    assert(NewVal);
+  }
   IDToValue[IteratorID] = IV;
 
-  ValueMapT NewValuesReverse;
+#ifndef NDEBUG
+  // Check whether the maps now exclusively refer to SubFn values.
+  for (auto &[OldVal, SubVal] : ValueMap) {
+    Instruction *SubInst = dyn_cast<Instruction>((Value *)SubVal);
+    assert(SubInst->getFunction() == SubFn &&
+           "Instructions from outside the subfn cannot be accessed within the "
+           "subfn");
+  }
+  for (auto &[Id, SubVal] : IDToValue) {
+    Instruction *SubInst = dyn_cast<Instruction>((Value *)SubVal);
+    assert(SubInst->getFunction() == SubFn &&
+           "Instructions from outside the subfn cannot be accessed within the "
+           "subfn");
+  }
+#endif
 
+  ValueMapT NewValuesReverse;
   for (auto P : NewValues)
     NewValuesReverse[P.second] = P.first;
 
@@ -652,12 +672,16 @@ void IslNodeBuilder::createForParallel(__isl_take isl_ast_node *For) {
   create(Body);
 
   Annotator.resetAlternativeAliasBases();
-  // Restore the original values.
-  ValueMap = ValueMapCopy;
-  IDToValue = IDToValueCopy;
 
+  // Resume working on the caller function.
+  GenDT = CallerDT;
+  GenLI = CallerLI;
+  GenSE = CallerSE;
+  IDToValue = std::move(IDToValueCopy);
+  ValueMap = std::move(CallerGlobals);
+  ExprBuilder.switchGeneratedFunc(CallerFn, CallerDT, CallerLI, CallerSE);
+  BlockGen.switchGeneratedFunc(CallerFn, CallerDT, CallerLI, CallerSE);
   Builder.SetInsertPoint(&*AfterLoop);
-  removeSubFuncFromDomTree((*LoopBody).getParent()->getParent(), DT);
 
   for (const Loop *L : Loops)
     OutsideLoopIterations.erase(L);
@@ -686,21 +710,21 @@ void IslNodeBuilder::createIf(__isl_take isl_ast_node *If) {
   LLVMContext &Context = F->getContext();
 
   BasicBlock *CondBB = SplitBlock(Builder.GetInsertBlock(),
-                                  &*Builder.GetInsertPoint(), &DT, &LI);
+                                  &*Builder.GetInsertPoint(), GenDT, GenLI);
   CondBB->setName("polly.cond");
-  BasicBlock *MergeBB = SplitBlock(CondBB, &CondBB->front(), &DT, &LI);
+  BasicBlock *MergeBB = SplitBlock(CondBB, &CondBB->front(), GenDT, GenLI);
   MergeBB->setName("polly.merge");
   BasicBlock *ThenBB = BasicBlock::Create(Context, "polly.then", F);
   BasicBlock *ElseBB = BasicBlock::Create(Context, "polly.else", F);
 
-  DT.addNewBlock(ThenBB, CondBB);
-  DT.addNewBlock(ElseBB, CondBB);
-  DT.changeImmediateDominator(MergeBB, CondBB);
+  GenDT->addNewBlock(ThenBB, CondBB);
+  GenDT->addNewBlock(ElseBB, CondBB);
+  GenDT->changeImmediateDominator(MergeBB, CondBB);
 
-  Loop *L = LI.getLoopFor(CondBB);
+  Loop *L = GenLI->getLoopFor(CondBB);
   if (L) {
-    L->addBasicBlockToLoop(ThenBB, LI);
-    L->addBasicBlockToLoop(ElseBB, LI);
+    L->addBasicBlockToLoop(ThenBB, *GenLI);
+    L->addBasicBlockToLoop(ElseBB, *GenLI);
   }
 
   CondBB->getTerminator()->eraseFromParent();
@@ -1088,19 +1112,19 @@ Value *IslNodeBuilder::preloadInvariantLoad(const MemoryAccess &MA,
     Cond = Builder.CreateIsNotNull(Cond);
 
   BasicBlock *CondBB = SplitBlock(Builder.GetInsertBlock(),
-                                  &*Builder.GetInsertPoint(), &DT, &LI);
+                                  &*Builder.GetInsertPoint(), GenDT, GenLI);
   CondBB->setName("polly.preload.cond");
 
-  BasicBlock *MergeBB = SplitBlock(CondBB, &CondBB->front(), &DT, &LI);
+  BasicBlock *MergeBB = SplitBlock(CondBB, &CondBB->front(), GenDT, GenLI);
   MergeBB->setName("polly.preload.merge");
 
   Function *F = Builder.GetInsertBlock()->getParent();
   LLVMContext &Context = F->getContext();
   BasicBlock *ExecBB = BasicBlock::Create(Context, "polly.preload.exec", F);
 
-  DT.addNewBlock(ExecBB, CondBB);
-  if (Loop *L = LI.getLoopFor(CondBB))
-    L->addBasicBlockToLoop(ExecBB, LI);
+  GenDT->addNewBlock(ExecBB, CondBB);
+  if (Loop *L = GenLI->getLoopFor(CondBB))
+    L->addBasicBlockToLoop(ExecBB, *GenLI);
 
   auto *CondBBTerminator = CondBB->getTerminator();
   Builder.SetInsertPoint(CondBBTerminator);
@@ -1326,7 +1350,7 @@ bool IslNodeBuilder::preloadInvariantLoads() {
     return true;
 
   BasicBlock *PreLoadBB = SplitBlock(Builder.GetInsertBlock(),
-                                     &*Builder.GetInsertPoint(), &DT, &LI);
+                                     &*Builder.GetInsertPoint(), GenDT, GenLI);
   PreLoadBB->setName("polly.preload.begin");
   Builder.SetInsertPoint(&PreLoadBB->front());
 
@@ -1375,8 +1399,10 @@ Value *IslNodeBuilder::generateSCEV(const SCEV *Expr) {
   assert(Builder.GetInsertBlock()->end() != Builder.GetInsertPoint() &&
          "Insert location points after last valid instruction");
   Instruction *InsertLocation = &*Builder.GetInsertPoint();
-  return expandCodeFor(S, SE, DL, "polly", Expr, Expr->getType(),
-                       InsertLocation, &ValueMap,
+
+  return expandCodeFor(S, SE, Builder.GetInsertBlock()->getParent(), *GenSE, DL,
+                       "polly", Expr, Expr->getType(), InsertLocation,
+                       &ValueMap, /*LoopToScevMap*/ nullptr,
                        StartBlock->getSinglePredecessor());
 }
 
diff --git a/polly/lib/CodeGen/LoopGeneratorsGOMP.cpp b/polly/lib/CodeGen/LoopGeneratorsGOMP.cpp
index e7512c1f33f610..cd440b28202e6d 100644
--- a/polly/lib/CodeGen/LoopGeneratorsGOMP.cpp
+++ b/polly/lib/CodeGen/LoopGeneratorsGOMP.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "polly/CodeGen/LoopGeneratorsGOMP.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Module.h"
 
@@ -108,21 +109,20 @@ ParallelLoopGeneratorGOMP::createSubFn(Value *Stride, AllocaInst *StructData,
   Function *SubFn = createSubFnDefinition();
   LLVMContext &Context = SubFn->getContext();
 
-  // Store the previous basic block.
-  BasicBlock *PrevBB = Builder.GetInsertBlock();
-
   // Create basic blocks.
   BasicBlock *HeaderBB = BasicBlock::Create(Context, "polly.par.setup", SubFn);
+  SubFnDT = std::make_unique<DominatorTree>(*SubFn);
+  SubFnLI = std::make_unique<LoopInfo>(*SubFnDT);
+
   BasicBlock *ExitBB = BasicBlock::Create(Context, "polly.par.exit", SubFn);
   BasicBlock *CheckNextBB =
       BasicBlock::Create(Context, "polly.par.checkNext", SubFn);
   BasicBlock *PreHeaderBB =
       BasicBlock::Create(Context, "polly.par.loadIVBounds", SubFn);
 
-  DT.addNewBlock(HeaderBB, PrevBB);
-  DT.addNewBlock(ExitBB, HeaderBB);
-  DT.addNewBlock(CheckNextBB, HeaderBB);
-  DT.addNewBlock(PreHeaderBB, HeaderBB);
+  SubFnDT->addNewBlock(ExitBB, HeaderBB);
+  SubFnDT->addNewBlock(CheckNextBB, HeaderBB);
+  SubFnDT->addNewBlock(PreHeaderBB, HeaderBB);
 
   // Fill up basic block HeaderBB.
   Builder.SetInsertPoint(HeaderBB);
@@ -155,8 +155,8 @@ ParallelLoopGeneratorGOMP::createSubFn(Value *Stride, AllocaInst *StructData,
   Builder.SetInsertPoint(&*--Builder.GetInsertPoint());
   BasicBlock *AfterBB;
   Value *IV =
-      createLoop(LB, UB, Stride, Builder, LI, DT, AfterBB, ICmpInst::ICMP_SLE,
-                 nullptr, true, /* UseGuard */ false);
+      createLoop(LB, UB, Stride, Builder, *SubFnLI, *SubFnDT, AfterBB,
+                 ICmpInst::ICMP_SLE, nullptr, true, /* UseGuard */ false);
 
   BasicBlock::iterator LoopBody = Builder.GetInsertPoint();
 
@@ -167,6 +167,10 @@ ParallelLoopGeneratorGOMP::createSubFn(Value *Stride, AllocaInst *StructData,
 
   Builder.SetInsertPoint(&*LoopBody);
 
+  // FIXME: Call SubFnDT->verify() and SubFnLI->verify() to check that the
+  // DominatorTree/LoopInfo has been created correctly. Alternatively, recreate
+  // from scratch since it is not needed here directly.
+
   return std::make_tuple(IV, SubFn);
 }
 
diff --git a/polly/lib/CodeGen/LoopGeneratorsKMP.cpp b/polly/lib/CodeGen/LoopGeneratorsKMP.cpp
index b3af7b14f47808..4ec5afe6aa6317 100644
--- a/polly/lib/CodeGen/LoopGeneratorsKMP.cpp
+++ b/polly/lib/CodeGen/LoopGeneratorsKMP.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "polly/CodeGen/LoopGeneratorsKMP.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Module.h"
 
@@ -135,21 +136,20 @@ ParallelLoopGeneratorKMP::createSubFn(Value *SequentialLoopStride,
   Function *SubFn = createSubFnDefinition();
   LLVMContext &Context = SubFn->getContext();
 
-  // Store the previous basic block.
-  BasicBlock *PrevBB = Builder.GetInsertBlock();
-
   // Create basic blocks.
   BasicBlock *HeaderBB = BasicBlock::Create(Context, "polly.par.setup", SubFn);
+  SubFnDT = std::make_unique<DominatorTree>(*SubFn);
+  SubFnLI = std::make_unique<LoopInfo>(*SubFnDT);
+
   BasicBlock *ExitBB = BasicBlock::Create(Context, "polly.par.exit", SubFn);
   BasicBlock *CheckNextBB =
       BasicBlock::Create(Context, "polly.par.checkNext", SubFn);
   BasicBlock *PreHeaderBB =
       BasicBlock::Create(Context, "polly.par.loadIVBounds", SubFn);
 
-  DT.addNewBlock(HeaderBB, PrevBB);
-  DT.addNewBlock(ExitBB, HeaderBB);
-  DT.addNewBlock(CheckNextBB, HeaderBB);
-  DT.addNewBlock(PreHeaderBB, HeaderBB);
+  SubFnDT->addNewBlock(ExitBB, HeaderBB);
+  SubFnDT->addNewBlock(CheckNextBB, HeaderBB);
+  SubFnDT->addNewBlock(PreHeaderBB, HeaderBB);
 
   // Fill up basic block HeaderBB.
   Builder.SetInsertPoint(HeaderBB);
@@ -291,8 +291,8 @@ ParallelLoopGeneratorKMP::createSubFn(Value *SequentialLoopStride,
   Builder.CreateBr(CheckNextBB);
   Builder.SetInsertPoint(&*--Builder.GetInsertPoint());
   BasicBlock *AfterBB;
-  Value *IV = createLoop(LB, UB, SequentialLoopStride, Builder, LI, DT, AfterBB,
-                         ICmpInst::ICMP_SLE, nullptr, true,
+  Value *IV = createLoop(LB, UB, SequentialLoopStride, Builder, *SubFnLI,
+                         *SubFnDT, AfterBB, ICmpInst::ICMP_SLE, nullptr, true,
                          /* UseGuard */ false);
 
   BasicBlock::iterator LoopBody = Builder.GetInsertPoint();
@@ -307,6 +307,10 @@ ParallelLoopGeneratorKMP::createSubFn(Value *SequentialLoopStride,
   Builder.CreateRetVoid();
   Builder.SetInsertPoint(&*LoopBody);
 
+  // FIXME: Call SubFnDT->verify() and SubFnLI->verify() to check that the
+  // DominatorTree/LoopInfo has been created correctly. Alternatively, recreate
+  // from scratch since it is not needed here directly.
+
   return std::make_tuple(IV, SubFn);
 }
 
diff --git a/polly/lib/Support/ScopHelper.cpp b/polly/lib/Support/ScopHelper.cpp
index 24c7011b06de93..754bf50e2911f1 100644
--- a/polly/lib/Support/ScopHelper.cpp
+++ b/polly/lib/Support/ScopHelper.cpp
@@ -228,6 +228,22 @@ void polly::recordAssumption(polly::RecordedAssumptionsTy *RecordedAssumptions,
     RecordedAssumptions->push_back({Kind, Sign, Set, Loc, BB, RTC});
 }
 
+/// ScopExpander generates IR for the value of a SCEV that represents a value
+/// from a SCoP.
+///
+/// IMPORTANT: There are two ScalarEvolutions at play here. First, the SE that
+/// was used to analyze the original SCoP (not actually referenced anywhere
+/// here, but passed as argument to make the distinction clear). Second, GenSE
+/// which is the SE for the function that the code is emitted into. SE and GenSE
+/// may be different when the generated code is to be emitted into an outlined
+/// function, e.g. for a parallel loop. That is, each SCEV is to be used only by
+/// the SE that "owns" it and ScopExpander handles the translation between them.
+/// The SCEVVisitor methods are only to be called on SCEVs of the original SE.
+/// Their job is to create a new SCEV for GenSE. The nested SCEVExpander is to
+/// be used only with SCEVs belonging to GenSE. Currently SCEVs do not store a
+/// reference to the ScalarEvolution they belong to, so a mixup does not
+/// immediately cause a crash but certainly is a violation of its interface.
+///
 /// The SCEVExpander will __not__ generate any code for an existing SDiv/SRem
 /// instruction but just use it, if it is referenced as a SCEVUnknown. We want
 /// however to generate new code if the instruction is in the analyzed region
@@ -237,19 +253,19 @@ void polly::recordAssumption(polly::RecordedAssumptionsTy *RecordedAssumptions,
 struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
   friend struct SCEVVisitor<ScopExpander, const SCEV *>;
 
-  explicit ScopExpander(const Region &R, ScalarEvolution &SE,
-                        const DataLayout &DL, const char *Name, ValueMapT *VMap,
-                        BasicBlock *RTCBB)
-      : Expander(SE, DL, Name, /*PreserveLCSSA=*/false), SE(SE), Name(Name),
-        R(R), VMap(VMap), RTCBB(RTCBB) {}
+  explicit ScopExpander(const Region &R, ScalarEvolution &SE, Function *GenFn,
+                        ScalarEvolution &GenSE, const DataLayout &DL,
+                        const char *Name, ValueMapT *VMap,
+                        LoopToScevMapT *LoopMap, BasicBlock *RTCBB)
+      : Expander(GenSE, DL, Name, /*PreserveLCSSA=*/false), Name(Name), R(R),
+        VMap(VMap), LoopMap(LoopMap), RTCBB(RTCBB), GenSE(GenSE), GenFn(GenFn) {
+  }
 
-  Value *expandCodeFor(const SCEV *E, Type *Ty, Instruction *I) {
-    // If we generate code in the region we will immediately fall back to the
-    // SCEVExpander, otherwise we will stop at all unknowns in the SCEV and if
-    // needed replace them by copies computed in the entering block.
-    if (!R.contains(I))
-      E = visit(E);
-    return Expander.expandCodeFor(E, Ty, I);
+  Value *expandCodeFor(const SCEV *E, Type *Ty, Instruction *IP) {
+    assert(isInGenRegion(IP) &&
+           "ScopExpander assumes to be applied to generated code region");
+    const SCEV *GenE = visit(E);
+    return Expander.expandCodeFor(GenE, Ty, IP);
   }
 
   const SCEV *visit(const SCEV *E) {
@@ -265,16 +281,32 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
 
 private:
   SCEVExpander Expander;
-  ScalarEvolution &SE;
   const char *Name;
   const Region &R;
   ValueMapT *VMap;
+  LoopToScevMapT *LoopMap;
   BasicBlock *RTCBB;
   DenseMap<const SCEV *, const SCEV *> SCEVCache;
 
+  ScalarEvolution &GenSE;
+  Function *GenFn;
+
+  /// Is the instruction part of the original SCoP (in contrast to being located
+  /// in the code-generated region)?
+  bool isInOrigRegion(Instruction *Inst) {
+    Function *Fn = R.getEntry()->getParent();
+    bool isInOrigRegion = Inst->getFunction() == Fn && R.contains(Inst);
+    assert((isInOrigRegion || GenFn == Inst->getFunction()) &&
+           "Instruction expected to be either in the SCoP or the translated "
+           "region");
+    return isInOrigRegion;
+  }
+
+  bool isInGenRegion(Instruction *Inst) { return !isInOrigRegion(Inst); }
+
   const SCEV *visitGenericInst(const SCEVUnknown *E, Instruction *Inst,
                                Instruction *IP) {
-    if (!Inst || !R.contains(Inst))
+    if (!Inst || isInGenRegion(Inst))
       return E;
 
     assert(!Inst->mayThrow() && !Inst->mayReadOrWriteMemory() &&
@@ -282,15 +314,15 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
 
     auto *InstClone = Inst->clone();
     for (auto &Op : Inst->operands()) {
-      assert(SE.isSCEVable(Op->getType()));
-      auto *OpSCEV = SE.getSCEV(Op);
+      assert(GenSE.isSCEVable(Op->getType()));
+      auto *OpSCEV = GenSE.getSCEV(Op);
       auto *OpClone = expandCodeFor(OpSCEV, Op->getType(), IP);
       InstClone->replaceUsesOfWith(Op, OpClone);
     }
 
     InstClone->setName(Name + Inst->getName());
     InstClone->insertBefore(IP);
-    return SE.getSCEV(InstClone);
+    return GenSE.getSCEV(InstClone);
   }
 
   const SCEV *visitUnknown(const SCEVUnknown *E) {
@@ -298,19 +330,27 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
     // If a value mapping was given try if the underlying value is remapped.
     Value *NewVal = VMap ? VMap->lookup(E->getValue()) : nullptr;
     if (NewVal) {
-      auto *NewE = SE.getSCEV(NewVal);
+      auto *NewE = GenSE.getSCEV(NewVal);
 
       // While the mapped value might be different the SCEV representation might
       // not be. To this end we will check before we go into recursion here.
+      // FIXME: SCEVVisitor must only visit SCEVs that belong to the original
+      // SE. This calls it on SCEVs that belong to GenSE.
       if (E != NewE)
         return visit(NewE);
     }
 
     Instruction *Inst = dyn_cast<Instruction>(E->getValue());
     Instruction *IP;
-    if (Inst && !R.contains(Inst))
+    if (Inst && isInGenRegion(Inst))
       IP = Inst;
-    else if (Inst && RTCBB->getParent() == Inst->getFunction())
+    else if (R.getEntry()->getParent() != GenFn) {
+      // RTCBB is in the original function, but we are generating for a
+      // subfunction so we cannot emit to RTCBB. Usually, we land here only
+      // because E->getValue() is not an instruction but a global or constant
+      // for which nothing needs to be emitted.
+      IP = GenFn->getEntryBlock().getTerminator();
+    } else if (Inst && RTCBB->getParent() == Inst->getFunction())
       IP = RTCBB->getTerminator();
     else
       IP = RTCBB->getParent()->getEntryBlock().getTerminator();
@@ -319,11 +359,11 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
                   Inst->getOpcode() != Instruction::SDiv))
       return visitGenericInst(E, Inst, IP);
 
-    const SCEV *LHSScev = SE.getSCEV(Inst->getOperand(0));
-    const SCEV *RHSScev = SE.getSCEV(Inst->getOperand(1));
+    const SCEV *LHSScev = GenSE.getSCEV(Inst->getOperand(0));
+    const SCEV *RHSScev = GenSE.getSCEV(Inst->getOperand(1));
 
-    if (!SE.isKnownNonZero(RHSScev))
-      RHSScev = SE.getUMaxExpr(RHSScev, SE.getConstant(E->getType(), 1));
+    if (!GenSE.isKnownNonZero(RHSScev))
+      RHSScev = GenSE.getUMaxExpr(RHSScev, GenSE.getConstant(E->getType(), 1));
 
     Value *LHS = expandCodeFor(LHSScev, E->getType(), IP);
     Value *RHS = expandCodeFor(RHSScev, E->getType(), IP);
@@ -331,89 +371,105 @@ struct ScopExpander final : SCEVVisitor<ScopExpander, const SCEV *> {
     Inst =
         BinaryOperator::Create((Instruction::BinaryOps)Inst->getOpcode(), LHS,
                                RHS, Inst->getName() + Name, IP->getIterator());
-    return SE.getSCEV(Inst);
+    return GenSE.getSCEV(Inst);
   }
 
-  /// The following functions will just traverse the SCEV and rebuild it with
-  /// the new operands returned by the traversal.
+  /// The following functions will just traverse the SCEV and rebuild it using
+  /// GenSE and the new operands returned by the traversal.
   ///
   ///{
   const SCEV *visitConstant(const SCEVConstant *E) { return E; }
   const SCEV *visitVScale(const SCEVVScale *E) { return E; }
   const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
-    return SE.getPtrToIntExpr(visit(E->getOperand()), E->getType());
+    return GenSE.getPtrToIntExpr(visit(E->getOperand()), E->getType());
   }
   const SCEV *visitTruncateExpr(const SCEVTruncateExpr *E) {
-    return SE.getTruncateExpr(visit(E->getOperand()), E->getType());
+    return GenSE.getTruncateExpr(visit(E->getOperand()), E->getType());
   }
   const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *E) {
-    return SE.getZeroExtendExpr(visit(E->getOperand()), E->getType());
+    return GenSE.getZeroExtendExpr(visit(E->getOperand()), E->getType());
   }
   const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *E) {
-    return SE.getSignExtendExpr(visit(E->getOperand()), E->getType());
+    return GenSE.getSignExtendExpr(visit(E->getOperand()), E->getType());
   }
   const SCEV *visitUDivExpr(const SCEVUDivExpr *E) {
     auto *RHSScev = visit(E->getRHS());
-    if (!SE.isKnownNonZero(RHSScev))
-      RHSScev = SE.getUMaxExpr(RHSScev, SE.getConstant(E->getType(), 1));
-    return SE.getUDivExpr(visit(E->getLHS()), RHSScev);
+    if (!GenSE.isKnownNonZero(RHSScev))
+      RHSScev = GenSE.getUMaxExpr(RHSScev, GenSE.getConstant(E->getType(), 1));
+    return GenSE.getUDivExpr(visit(E->getLHS()), RHSScev);
   }
   const SCEV *visitAddExpr(const SCEVAddExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getAddExpr(NewOps);
+    return GenSE.getAddExpr(NewOps);
   }
   const SCEV *visitMulExpr(const SCEVMulExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getMulExpr(NewOps);
+    return GenSE.getMulExpr(NewOps);
   }
   const SCEV *visitUMaxExpr(const SCEVUMaxExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getUMaxExpr(NewOps);
+    return GenSE.getUMaxExpr(NewOps);
   }
   const SCEV *visitSMaxExpr(const SCEVSMaxExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getSMaxExpr(NewOps);
+    return GenSE.getSMaxExpr(NewOps);
   }
   const SCEV *visitUMinExpr(const SCEVUMinExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getUMinExpr(NewOps);
+    return GenSE.getUMinExpr(NewOps);
   }
   const SCEV *visitSMinExpr(const SCEVSMinExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getSMinExpr(NewOps);
+    return GenSE.getSMinExpr(NewOps);
   }
   const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getUMinExpr(NewOps, /*Sequential=*/true);
+    return GenSE.getUMinExpr(NewOps, /*Sequential=*/true);
   }
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *E) {
     SmallVector<const SCEV *, 4> NewOps;
     for (const SCEV *Op : E->operands())
       NewOps.push_back(visit(Op));
-    return SE.getAddRecExpr(NewOps, E->getLoop(), E->getNoWrapFlags());
+
+    const Loop *L = E->getLoop();
+    const SCEV *GenLRepl = LoopMap ? LoopMap->lookup(L) : nullptr;
+    if (!GenLRepl)
+      return GenSE.getAddRecExpr(NewOps, L, E->getNoWrapFlags());
+
+    // evaluateAtIteration replaces the SCEVAddRecExpr by a direct calculation.
+    const SCEV *Evaluated =
+        SCEVAddRecExpr::evaluateAtIteration(NewOps, GenLRepl, GenSE);
+
+    // FIXME: This emits a SCEV for GenSE (since GenLRepl will refer to the
+    // induction variable of a generated loop), so we should not use SCEVVisitor
+    // with it. However, it still contains references to the SCoP region.
+    return visit(Evaluated);
   }
   ///}
 };
 
-Value *polly::expandCodeFor(Scop &S, ScalarEvolution &SE, const DataLayout &DL,
-                            const char *Name, const SCEV *E, Type *Ty,
-                            Instruction *IP, ValueMapT *VMap,
+Value *polly::expandCodeFor(Scop &S, llvm::ScalarEvolution &SE,
+                            llvm::Function *GenFn, ScalarEvolution &GenSE,
+                            const DataLayout &DL, const char *Name,
+                            const SCEV *E, Type *Ty, Instruction *IP,
+                            ValueMapT *VMap, LoopToScevMapT *LoopMap,
                             BasicBlock *RTCBB) {
-  ScopExpander Expander(S.getRegion(), SE, DL, Name, VMap, RTCBB);
+  ScopExpander Expander(S.getRegion(), SE, GenFn, GenSE, DL, Name, VMap,
+                        LoopMap, RTCBB);
   return Expander.expandCodeFor(E, Ty, IP);
 }
 

>From aeb0a0763e03bd8783f93ac27c1b1e156a4653b2 Mon Sep 17 00:00:00 2001
From: Michael Kruse <llvm-project at meinersbur.de>
Date: Thu, 8 Aug 2024 17:13:12 +0200
Subject: [PATCH 2/2] add assertion for correct codegen Fn

---
 polly/include/polly/CodeGen/BlockGenerators.h | 12 ++++--------
 polly/include/polly/CodeGen/IslExprBuilder.h  |  7 +------
 polly/lib/CodeGen/BlockGenerators.cpp         | 11 +++++++++++
 polly/lib/CodeGen/IslExprBuilder.cpp          | 12 ++++++++++++
 4 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/polly/include/polly/CodeGen/BlockGenerators.h b/polly/include/polly/CodeGen/BlockGenerators.h
index 074426c8ccbda5..4e2645468a7434 100644
--- a/polly/include/polly/CodeGen/BlockGenerators.h
+++ b/polly/include/polly/CodeGen/BlockGenerators.h
@@ -170,14 +170,6 @@ class BlockGenerator {
   /// @}
 
 public:
-  /// Change the function that code is emitted into.
-  void switchGeneratedFunc(Function *GenFn, DominatorTree *GenDT,
-                           LoopInfo *GenLI, ScalarEvolution *GenSE) {
-    this->GenDT = GenDT;
-    this->GenLI = GenLI;
-    this->GenSE = GenSE;
-  }
-
   /// Map to resolve scalar dependences for PHI operands and scalars.
   ///
   /// When translating code that contains scalar dependences as they result from
@@ -311,6 +303,10 @@ class BlockGenerator {
   /// Split @p BB to create a new one we can use to clone @p BB in.
   BasicBlock *splitBB(BasicBlock *BB);
 
+  /// Change the function that code is emitted into.
+  void switchGeneratedFunc(Function *GenFn, DominatorTree *GenDT,
+                           LoopInfo *GenLI, ScalarEvolution *GenSE);
+
   /// Copy the given basic block.
   ///
   /// @param Stmt      The statement to code generate.
diff --git a/polly/include/polly/CodeGen/IslExprBuilder.h b/polly/include/polly/CodeGen/IslExprBuilder.h
index 6a6d644ee2439a..25f61be5787c13 100644
--- a/polly/include/polly/CodeGen/IslExprBuilder.h
+++ b/polly/include/polly/CodeGen/IslExprBuilder.h
@@ -126,12 +126,7 @@ class IslExprBuilder final {
 
   /// Change the function that code is emitted into.
   void switchGeneratedFunc(llvm::Function *GenFn, llvm::DominatorTree *GenDT,
-                           llvm::LoopInfo *GenLI,
-                           llvm::ScalarEvolution *GenSE) {
-    this->GenDT = GenDT;
-    this->GenLI = GenLI;
-    this->GenSE = GenSE;
-  }
+                           llvm::LoopInfo *GenLI, llvm::ScalarEvolution *GenSE);
 
   /// Create LLVM-IR for an isl_ast_expr[ession].
   ///
diff --git a/polly/lib/CodeGen/BlockGenerators.cpp b/polly/lib/CodeGen/BlockGenerators.cpp
index c7e1b21286b443..004fa64c82db2b 100644
--- a/polly/lib/CodeGen/BlockGenerators.cpp
+++ b/polly/lib/CodeGen/BlockGenerators.cpp
@@ -432,6 +432,17 @@ BasicBlock *BlockGenerator::copyBB(ScopStmt &Stmt, BasicBlock *BB,
   return CopyBB;
 }
 
+void BlockGenerator::switchGeneratedFunc(Function *GenFn, DominatorTree *GenDT,
+                                         LoopInfo *GenLI,
+                                         ScalarEvolution *GenSE) {
+  assert(GenFn == GenDT->getRoot()->getParent());
+  assert(GenLI->getTopLevelLoops().empty() ||
+         GenFn == GenLI->getTopLevelLoops().front()->getHeader()->getParent());
+  this->GenDT = GenDT;
+  this->GenLI = GenLI;
+  this->GenSE = GenSE;
+}
+
 void BlockGenerator::copyBB(ScopStmt &Stmt, BasicBlock *BB, BasicBlock *CopyBB,
                             ValueMapT &BBMap, LoopToScevMapT &LTS,
                             isl_id_to_ast_expr *NewAccesses) {
diff --git a/polly/lib/CodeGen/IslExprBuilder.cpp b/polly/lib/CodeGen/IslExprBuilder.cpp
index d573daee87a153..aaafac14bf8065 100644
--- a/polly/lib/CodeGen/IslExprBuilder.cpp
+++ b/polly/lib/CodeGen/IslExprBuilder.cpp
@@ -47,6 +47,18 @@ IslExprBuilder::IslExprBuilder(Scop &S, PollyIRBuilder &Builder,
   OverflowState = (OTMode == OT_ALWAYS) ? Builder.getFalse() : nullptr;
 }
 
+void IslExprBuilder::switchGeneratedFunc(llvm::Function *GenFn,
+                                         llvm::DominatorTree *GenDT,
+                                         llvm::LoopInfo *GenLI,
+                                         llvm::ScalarEvolution *GenSE) {
+  assert(GenFn == GenDT->getRoot()->getParent());
+  assert(GenLI->getTopLevelLoops().empty() ||
+         GenFn == GenLI->getTopLevelLoops().front()->getHeader()->getParent());
+  this->GenDT = GenDT;
+  this->GenLI = GenLI;
+  this->GenSE = GenSE;
+}
+
 void IslExprBuilder::setTrackOverflow(bool Enable) {
   // If potential overflows are tracked always or never we ignore requests
   // to change the behavior.



More information about the llvm-commits mailing list