[llvm] 0dd8401 - [AggressiveInstCombine] Add `phi` nodes support to `TruncInstCombine`

Anton Afanasyev via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 24 20:57:44 PST 2022


Author: Anton Afanasyev
Date: 2022-02-25T07:57:35+03:00
New Revision: 0dd84013710555599253f9de58996e98206d1383

URL: https://github.com/llvm/llvm-project/commit/0dd84013710555599253f9de58996e98206d1383
DIFF: https://github.com/llvm/llvm-project/commit/0dd84013710555599253f9de58996e98206d1383.diff

LOG: [AggressiveInstCombine] Add `phi` nodes support to `TruncInstCombine`

Expand `TruncInstCombine` to handle loops by adding `phi` nodes
to the expression graph.

Reviewed by: RKSimon, lebedev.ri

(recommit of fixed f84d732f, reverted by 8ad6d5e after sanitizer breakage)

Differential Revision: https://reviews.llvm.org/D109817

Added: 
    

Modified: 
    llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
    llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
    llvm/test/Transforms/AggressiveInstCombine/trunc_phi.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
index 6c73645b20f20..9fc103d45d985 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombineInternal.h
@@ -23,14 +23,14 @@
 using namespace llvm;
 
 //===----------------------------------------------------------------------===//
-// TruncInstCombine - looks for expression dags dominated by trunc instructions
-// and for each eligible dag, it will create a reduced bit-width expression and
-// replace the old expression with this new one and remove the old one.
-// Eligible expression dag is such that:
+// TruncInstCombine - looks for expression graphs dominated by trunc
+// instructions and for each eligible graph, it will create a reduced bit-width
+// expression and replace the old expression with this new one and remove the
+// old one. Eligible expression graph is such that:
 //   1. Contains only supported instructions.
 //   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
 //   3. Can be evaluated into type with reduced legal bit-width (or Trunc type).
-//   4. All instructions in the dag must not have users outside the dag.
+//   4. All instructions in the graph must not have users outside the graph.
 //      Only exception is for {ZExt, SExt}Inst with operand type equal to the
 //      new reduced type chosen in (3).
 //
@@ -63,7 +63,7 @@ class TruncInstCombine {
   /// Current processed TruncInst instruction.
   TruncInst *CurrentTruncInst = nullptr;
 
-  /// Information per each instruction in the expression dag.
+  /// Information per each instruction in the expression graph.
   struct Info {
     /// Number of LSBs that are needed to generate a valid expression.
     unsigned ValidBitWidth = 0;
@@ -72,10 +72,10 @@ class TruncInstCombine {
     /// The reduced value generated to replace the old instruction.
     Value *NewValue = nullptr;
   };
-  /// An ordered map representing expression dag post-dominated by current
-  /// processed TruncInst. It maps each instruction in the dag to its Info
+  /// An ordered map representing expression graph post-dominated by current
+  /// processed TruncInst. It maps each instruction in the graph to its Info
   /// structure. The map is ordered such that each instruction appears before
-  /// all other instructions in the dag that uses it.
+  /// all other instructions in the graph that use it.
   MapVector<Instruction *, Info> InstInfoMap;
 
 public:
@@ -87,11 +87,11 @@ class TruncInstCombine {
   bool run(Function &F);
 
 private:
-  /// Build expression dag dominated by the /p CurrentTruncInst and append it to
-  /// the InstInfoMap container.
+  /// Build expression graph dominated by the \p CurrentTruncInst and append it
+  /// to the InstInfoMap container.
   ///
-  /// \return true only if succeed to generate an eligible sub expression dag.
-  bool buildTruncExpressionDag();
+  /// \return true only if an eligible sub expression graph was generated.
+  bool buildTruncExpressionGraph();
 
   /// Calculate the minimal allowed bit-width of the chain ending with the
   /// currently visited truncate's operand.
@@ -100,12 +100,12 @@ class TruncInstCombine {
   /// truncate's operand can be shrunk to.
   unsigned getMinBitWidth();
 
-  /// Build an expression dag dominated by the current processed TruncInst and
+  /// Build an expression graph dominated by the current processed TruncInst and
   /// Check if it is eligible to be reduced to a smaller type.
   ///
   /// \return the scalar version of the new type to be used for the reduced
-  ///         expression dag, or nullptr if the expression dag is not eligible
-  ///         to be reduced.
+  ///         expression graph, or nullptr if the expression graph is not
+  ///         eligible to be reduced.
   Type *getBestTruncatedType();
 
   KnownBits computeKnownBits(const Value *V) const {
@@ -128,12 +128,12 @@ class TruncInstCombine {
   /// \return the new reduced value.
   Value *getReducedOperand(Value *V, Type *SclTy);
 
-  /// Create a new expression dag using the reduced /p SclTy type and replace
-  /// the old expression dag with it. Also erase all instructions in the old
-  /// dag, except those that are still needed outside the dag.
+  /// Create a new expression graph using the reduced \p SclTy type and replace
+  /// the old expression graph with it. Also erase all instructions in the old
+  /// graph, except those that are still needed outside the graph.
   ///
-  /// \param SclTy scalar version of new type to reduce expression dag into.
-  void ReduceExpressionDag(Type *SclTy);
+  /// \param SclTy scalar version of new type to reduce expression graph into.
+  void ReduceExpressionGraph(Type *SclTy);
 };
 } // end namespace llvm.
 

diff  --git a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
index 4624b735bef8c..71f3d76c0ba78 100644
--- a/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
-// for each eligible dag, it will create a reduced bit-width expression, replace
-// the old expression with this new one and remove the old expression.
-// Eligible expression dag is such that:
+// TruncInstCombine - looks for expression graphs post-dominated by TruncInst
+// and for each eligible graph, it will create a reduced bit-width expression,
+// replace the old expression with this new one and remove the old expression.
+// Eligible expression graph is such that:
 //   1. Contains only supported instructions.
 //   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
 //   3. Can be evaluated into type with reduced legal bit-width.
-//   4. All instructions in the dag must not have users outside the dag.
+//   4. All instructions in the graph must not have users outside the graph.
 //      The only exception is for {ZExt, SExt}Inst with operand type equal to
 //      the new reduced type evaluated in (3).
 //
@@ -39,14 +39,13 @@ using namespace llvm;
 
 #define DEBUG_TYPE "aggressive-instcombine"
 
-STATISTIC(
-    NumDAGsReduced,
-    "Number of truncations eliminated by reducing bit width of expression DAG");
+STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "
+                           "width of expression graph");
 STATISTIC(NumInstrsReduced,
           "Number of instructions whose bit width was reduced");
 
 /// Given an instruction and a container, it fills all the relevant operands of
-/// that instruction, with respect to the Trunc expression dag optimizaton.
+/// that instruction, with respect to the Trunc expression graph optimization.
 static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
   unsigned Opc = I->getOpcode();
   switch (Opc) {
@@ -78,15 +77,19 @@ static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
     Ops.push_back(I->getOperand(1));
     Ops.push_back(I->getOperand(2));
     break;
+  case Instruction::PHI:
+    for (Value *V : cast<PHINode>(I)->incoming_values())
+      Ops.push_back(V);
+    break;
   default:
     llvm_unreachable("Unreachable!");
   }
 }
 
-bool TruncInstCombine::buildTruncExpressionDag() {
+bool TruncInstCombine::buildTruncExpressionGraph() {
   SmallVector<Value *, 8> Worklist;
   SmallVector<Instruction *, 8> Stack;
-  // Clear old expression dag.
+  // Clear old instructions info.
   InstInfoMap.clear();
 
   Worklist.push_back(CurrentTruncInst->getOperand(0));
@@ -150,11 +153,19 @@ bool TruncInstCombine::buildTruncExpressionDag() {
       append_range(Worklist, Operands);
       break;
     }
+    case Instruction::PHI: {
+      SmallVector<Value *, 2> Operands;
+      getRelevantOperands(I, Operands);
+      // Add only operands not in Stack to prevent cycles.
+      for (auto *Op : Operands)
+        if (all_of(Stack, [Op](Value *V) { return Op != V; }))
+          Worklist.push_back(Op);
+      break;
+    }
     default:
       // TODO: Can handle more cases here:
       // 1. shufflevector
       // 2. sdiv, srem
-      // 3. phi node(and loop handling)
       // ...
       return false;
     }
@@ -254,7 +265,7 @@ unsigned TruncInstCombine::getMinBitWidth() {
 }
 
 Type *TruncInstCombine::getBestTruncatedType() {
-  if (!buildTruncExpressionDag())
+  if (!buildTruncExpressionGraph())
     return nullptr;
 
   // We don't want to duplicate instructions, which isn't profitable. Thus, we
@@ -367,8 +378,10 @@ Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
   return Entry.NewValue;
 }
 
-void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
+void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {
   NumInstrsReduced += InstInfoMap.size();
+  // Pairs of old and new phi-nodes
+  SmallVector<std::pair<PHINode *, PHINode *>, 2> OldNewPHINodes;
   for (auto &Itr : InstInfoMap) { // Forward
     Instruction *I = Itr.first;
     TruncInstCombine::Info &NodeInfo = Itr.second;
@@ -451,6 +464,12 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
       Res = Builder.CreateSelect(Op0, LHS, RHS);
       break;
     }
+    case Instruction::PHI: {
+      Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());
+      OldNewPHINodes.push_back(
+          std::make_pair(cast<PHINode>(I), cast<PHINode>(Res)));
+      break;
+    }
     default:
       llvm_unreachable("Unhandled instruction");
     }
@@ -460,6 +479,14 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
       ResI->takeName(I);
   }
 
+  for (auto &Node : OldNewPHINodes) {
+    PHINode *OldPN = Node.first;
+    PHINode *NewPN = Node.second;
+    for (auto Incoming : zip(OldPN->incoming_values(), OldPN->blocks()))
+      NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),
+                         std::get<1>(Incoming));
+  }
+
   Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
   Type *DstTy = CurrentTruncInst->getType();
   if (Res->getType() != DstTy) {
@@ -470,17 +497,31 @@ void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
   }
   CurrentTruncInst->replaceAllUsesWith(Res);
 
-  // Erase old expression dag, which was replaced by the reduced expression dag.
-  // We iterate backward, which means we visit the instruction before we visit
-  // any of its operands, this way, when we get to the operand, we already
-  // removed the instructions (from the expression dag) that uses it.
+  // Erase old expression graph, which was replaced by the reduced expression
+  // graph.
   CurrentTruncInst->eraseFromParent();
+  // First, erase old phi-nodes and their uses.
+  for (auto &Node : OldNewPHINodes) {
+    PHINode *OldPN = Node.first;
+    OldPN->replaceAllUsesWith(PoisonValue::get(OldPN->getType()));
+    OldPN->eraseFromParent();
+  }
+  // Now the expression graph has been turned into a dag.
+  // We iterate backward, which means we visit the instruction before we
+  // visit any of its operands, this way, when we get to the operand, we already
+  // removed the instructions (from the expression dag) that uses it.
   for (auto &I : llvm::reverse(InstInfoMap)) {
+    // Skip phi-nodes since they were erased before
+    if (isa<PHINode>(I.first))
+      continue;
     // We still need to check that the instruction has no users before we erase
     // it, because {SExt, ZExt}Inst Instruction might have other users that was
     // not reduced, in such case, we need to keep that instruction.
     if (I.first->use_empty())
       I.first->eraseFromParent();
+    else
+      assert((isa<SExtInst>(I.first) || isa<ZExtInst>(I.first)) &&
+             "Only {SExt, ZExt}Inst might have unreduced users");
   }
 }
 
@@ -498,18 +539,18 @@ bool TruncInstCombine::run(Function &F) {
   }
 
   // Process all TruncInst in the Worklist, for each instruction:
-  //   1. Check if it dominates an eligible expression dag to be reduced.
-  //   2. Create a reduced expression dag and replace the old one with it.
+  //   1. Check if it dominates an eligible expression graph to be reduced.
+  //   2. Create a reduced expression graph and replace the old one with it.
   while (!Worklist.empty()) {
     CurrentTruncInst = Worklist.pop_back_val();
 
     if (Type *NewDstSclTy = getBestTruncatedType()) {
       LLVM_DEBUG(
-          dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
+          dbgs() << "ICE: TruncInstCombine reducing type of expression graph "
                     "dominated by: "
                  << CurrentTruncInst << '\n');
-      ReduceExpressionDag(NewDstSclTy);
-      ++NumDAGsReduced;
+      ReduceExpressionGraph(NewDstSclTy);
+      ++NumExprsReduced;
       MadeIRChange = true;
     }
   }

diff  --git a/llvm/test/Transforms/AggressiveInstCombine/trunc_phi.ll b/llvm/test/Transforms/AggressiveInstCombine/trunc_phi.ll
index 46bdb60fada6c..01103a1a5afbf 100644
--- a/llvm/test/Transforms/AggressiveInstCombine/trunc_phi.ll
+++ b/llvm/test/Transforms/AggressiveInstCombine/trunc_phi.ll
@@ -4,18 +4,17 @@
 define i16 @trunc_phi(i8 %x) {
 ; CHECK-LABEL: @trunc_phi(
 ; CHECK-NEXT:  LoopHeader:
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       Loop:
-; CHECK-NEXT:    [[ZEXT2:%.*]] = phi i32 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[SHL:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[ZEXT2:%.*]] = phi i16 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[SHL:%.*]], [[LOOP]] ]
 ; CHECK-NEXT:    [[J:%.*]] = phi i32 [ 0, [[LOOPHEADER]] ], [ [[I:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[SHL]] = shl i32 [[ZEXT2]], 1
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[SHL]] to i16
+; CHECK-NEXT:    [[SHL]] = shl i16 [[ZEXT2]], 1
 ; CHECK-NEXT:    [[I]] = add i32 [[J]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[I]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOPEND:%.*]], label [[LOOP]]
 ; CHECK:       LoopEnd:
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    ret i16 [[SHL]]
 ;
 LoopHeader:
   %zext = zext i8 %x to i32
@@ -37,22 +36,21 @@ LoopEnd:
 define i16 @trunc_phi2(i8 %x, i32 %sw) {
 ; CHECK-LABEL: @trunc_phi2(
 ; CHECK-NEXT:  LoopHeader:
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i32
+; CHECK-NEXT:    [[ZEXT:%.*]] = zext i8 [[X:%.*]] to i16
 ; CHECK-NEXT:    switch i32 [[SW:%.*]], label [[LOOPEND:%.*]] [
 ; CHECK-NEXT:    i32 0, label [[LOOP:%.*]]
 ; CHECK-NEXT:    i32 1, label [[LOOP]]
 ; CHECK-NEXT:    ]
 ; CHECK:       Loop:
-; CHECK-NEXT:    [[ZEXT2:%.*]] = phi i32 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[ZEXT]], [[LOOPHEADER]] ], [ [[SHL:%.*]], [[LOOP]] ]
+; CHECK-NEXT:    [[ZEXT2:%.*]] = phi i16 [ [[ZEXT]], [[LOOPHEADER:%.*]] ], [ [[ZEXT]], [[LOOPHEADER]] ], [ [[SHL:%.*]], [[LOOP]] ]
 ; CHECK-NEXT:    [[J:%.*]] = phi i32 [ 0, [[LOOPHEADER]] ], [ 0, [[LOOPHEADER]] ], [ [[I:%.*]], [[LOOP]] ]
-; CHECK-NEXT:    [[SHL]] = shl i32 [[ZEXT2]], 1
+; CHECK-NEXT:    [[SHL]] = shl i16 [[ZEXT2]], 1
 ; CHECK-NEXT:    [[I]] = add i32 [[J]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[I]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOPEND]], label [[LOOP]]
 ; CHECK:       LoopEnd:
-; CHECK-NEXT:    [[ZEXT3:%.*]] = phi i32 [ [[ZEXT]], [[LOOPHEADER]] ], [ [[ZEXT2]], [[LOOP]] ]
-; CHECK-NEXT:    [[TRUNC:%.*]] = trunc i32 [[ZEXT3]] to i16
-; CHECK-NEXT:    ret i16 [[TRUNC]]
+; CHECK-NEXT:    [[ZEXT3:%.*]] = phi i16 [ [[ZEXT]], [[LOOPHEADER]] ], [ [[ZEXT2]], [[LOOP]] ]
+; CHECK-NEXT:    ret i16 [[ZEXT3]]
 ;
 LoopHeader:
   %zext = zext i8 %x to i32


        


More information about the llvm-commits mailing list