[llvm] Adding change to IndVarSimplify pass to optimize IVs stuck in trivial vector operations (PR #122248)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 02:46:58 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: None (anilavakundu007)

<details>
<summary>Changes</summary>

This patch helps LLVM to unroll loop patterns like: 

``` 
typedef int ivec2 __attribute__((ext_vector_type(2)));
int data0;
void fn()
{
  int u_xlati59 = 1;
  while (true)
    {
      bool u_xlatb12 = u_xlati59 >= 3;
      if (u_xlatb12)
	{
	  break;
	}
      ivec2 u_xlati12 = (ivec2){u_xlati59, u_xlati59} + (ivec2){-1, 1};
      data0 += u_xlati12.x;
      u_xlati59 = u_xlati12.y;
    }
}
```
Currently, induction-variable detection fails for this loop because SCEV cannot analyze it, so the loop cannot be unrolled. With this patch the loop gets unrolled and the generated code is a lot smaller.

linked issue: https://github.com/llvm/llvm-project/issues/121742

---
Full diff: https://github.com/llvm/llvm-project/pull/122248.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/Scalar/IndVarSimplify.cpp (+226) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 8a3e0bc3eb9712..37361080a1d766 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -138,6 +138,7 @@ class IndVarSimplify {
   bool RunUnswitching = false;
 
   bool handleFloatingPointIV(Loop *L, PHINode *PH);
+  bool breakVectorOpsOnIVs(Loop *L);
   bool rewriteNonIntegerIVs(Loop *L);
 
   bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI);
@@ -1891,6 +1892,212 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
   return Changed;
 }
 
+// Return the icmp that feeds the latch's conditional branch, else nullptr.
+static ICmpInst *getLatchCmpInst(const Loop &L) {
+  BasicBlock *LatchBB = L.getLoopLatch();
+  BranchInst *Br =
+      LatchBB ? dyn_cast_or_null<BranchInst>(LatchBB->getTerminator()) : nullptr;
+  if (!Br || !Br->isConditional())
+    return nullptr;
+  return dyn_cast<ICmpInst>(Br->getCondition());
+}
+
+// Find a vector that holds nothing but the given induction variable.
+//
+// Returns the first insertelement user of \p IV that inserts \p IV as the
+// element value into a fixed-width vector which either
+//  * has a single element, or
+//  * is built on an undef/poison base and has no further insertelement
+//    chained onto it.
+// Returns nullptr when no user qualifies.
+static Value *getVectorContaingIV(PHINode *IV) {
+  for (User *U : IV->users()) {
+    // The IV must be the inserted element, not e.g. the insertion index.
+    auto *Insert = dyn_cast<InsertElementInst>(U);
+    if (!Insert || Insert->getOperand(1) != IV)
+      continue;
+    auto *VecTy = dyn_cast<FixedVectorType>(Insert->getType());
+    if (!VecTy)
+      continue;
+
+    // A one-element vector trivially contains only the IV.
+    if (VecTy->getNumElements() == 1)
+      return Insert;
+
+    // For wider vectors the IV must be inserted into an empty (undef or
+    // poison) vector. PoisonValue derives from UndefValue, so one isa<>
+    // accepts both.
+    if (!isa<UndefValue>(Insert->getOperand(0)))
+      continue;
+
+    // Reject the vector if further elements are inserted on top of it.
+    bool HasChainedInsert = false;
+    for (User *VecUser : Insert->users())
+      if (isa<InsertElementInst>(VecUser)) {
+        HasChainedInsert = true;
+        break;
+      }
+    if (HasChainedInsert)
+      continue;
+
+    return Insert;
+  }
+  return nullptr;
+}
+
+
+// Look for a header PHI that may be an induction variable hidden behind
+// trivial vector operations, where SCEV-based detection does not apply.
+//
+// The latch compare must have at least one extractelement operand. The first
+// integer or float header PHI that is either compared directly, or whose
+// latch incoming value is compared, decides the result: it is returned if
+// getVectorContaingIV() finds a vector for that PHI, otherwise nullptr is
+// returned. nullptr is also returned if no PHI matches.
+static PHINode *isPossibleIDVar(Loop *L) {
+  BasicBlock *Header = L->getHeader();
+  assert(Header && "Expected a valid loop header");
+
+  ICmpInst *LatchCmp = getLatchCmpInst(*L);
+  if (!LatchCmp)
+    return nullptr;
+  Value *CmpLHS = LatchCmp->getOperand(0);
+  Value *CmpRHS = LatchCmp->getOperand(1);
+
+  // The compared value is expected to be read back out of the vector that
+  // carries the IV, so give up unless one operand is an extractelement.
+  if (!(isa<ExtractElementInst>(CmpLHS) || isa<ExtractElementInst>(CmpRHS)))
+    return nullptr;
+
+  auto IsCmpOperand = [&](const Value *V) {
+    return V == CmpLHS || V == CmpRHS;
+  };
+
+  // getLatchCmpInst() succeeded, so the loop has a latch block.
+  BasicBlock *Latch = L->getLoopLatch();
+  for (PHINode &Phi : Header->phis()) {
+    Type *PhiTy = Phi.getType();
+    if (!(PhiTy->isIntegerTy() || PhiTy->isFloatTy()))
+      continue;
+
+    // Accept either form of exit test:
+    //   cmp = StepInst < FinalValue  (the PHI's latch incoming value), or
+    //   cmp = IndVar   < FinalValue  (the PHI itself).
+    Value *Step = Phi.getIncomingValueForBlock(Latch);
+    if (!IsCmpOperand(Step) && !IsCmpOperand(&Phi))
+      continue;
+
+    return getVectorContaingIV(&Phi) ? &Phi : nullptr;
+  }
+
+  return nullptr;
+}
+
+// Check whether a binary-operator operand comes from the vector that holds
+// the IV in one of the supported forms, in which case an extractelement of
+// the operation's result can be replaced by a scalar operation on the IV.
+//
+// Accepted forms:
+//  * \p BinOperand is \p VecIV itself;
+//  * \p BinOperand is a shufflevector whose first source is \p VecIV and
+//    whose mask is a zero-element splat.
+// TODO: Add more patterns?
+static bool checkVecOp(Value *BinOperand, Value *VecIV) {
+  if (BinOperand == VecIV)
+    return true;
+
+  if (auto *Shuffle = dyn_cast<ShuffleVectorInst>(BinOperand)) {
+    if (Shuffle->getOperand(0) != VecIV)
+      return false;
+    auto Mask = Shuffle->getShuffleMask();
+    return Shuffle->isZeroEltSplatMask(Mask, Mask.size());
+  }
+  return false;
+}
+
+
+// Rewrite an extractelement of `binop(<vector holding the IV>, <constant>)`
+// as the scalar operation `binop(IV, constant[index])`.
+//
+// \p values   constant vector operand of the vector binary operation.
+// \p Opcode   its opcode; only Add, Sub, Mul, FMul, UDiv, SDiv are handled.
+// \p elemInst extractelement to replace; its index must be a ConstantInt.
+// \p IV       scalar induction variable, used as the left-hand operand.
+//
+// On success all uses of \p elemInst are redirected to the new value,
+// \p elemInst is erased and true is returned. Returns false without touching
+// the IR for unsupported opcodes or an out-of-range extract index.
+static bool ReplaceExtractInst(ConstantDataVector *values, unsigned Opcode,
+                               ExtractElementInst *elemInst, PHINode *IV) {
+  switch (Opcode) {
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul:
+  case Instruction::FMul:
+  case Instruction::UDiv:
+  case Instruction::SDiv:
+    break;
+  default:
+    return false;
+  }
+
+  // An extractelement with an out-of-range constant index is valid IR (it
+  // yields poison), but getElementAsConstant() asserts on such an index.
+  const APInt &Idx = cast<ConstantInt>(elemInst->getIndexOperand())->getValue();
+  if (Idx.uge(values->getNumElements()))
+    return false;
+
+  IRBuilder<> B(elemInst);
+  Value *Scalar = values->getElementAsConstant(Idx.getZExtValue());
+  Value *NewInst =
+      B.CreateBinOp(static_cast<Instruction::BinaryOps>(Opcode), IV, Scalar);
+  LLVM_DEBUG(dbgs() << "INDVARS: Rewriting Extract Element:\n" << *elemInst <<"\n"
+                    << " With :" << *NewInst <<"\n");
+  elemInst->replaceAllUsesWith(NewInst);
+  elemInst->eraseFromParent();
+  return true;
+}
+
+
+// Scalarize the latch compare's extractelement when it reads a lane of
+// `binop(<vector holding the IV>, <constant data vector>)`, so that the
+// value is expressed directly in terms of the scalar IV. Returns true if the
+// IR was changed.
+bool IndVarSimplify::breakVectorOpsOnIVs(Loop *L) {
+  PHINode *IV = isPossibleIDVar(L);
+  if (!IV)
+    return false;
+
+  // isPossibleIDVar() only succeeds when the latch compare exists and at
+  // least one of its operands is an extractelement; pick the first such one.
+  ICmpInst *CmpInst = getLatchCmpInst(*L);
+  unsigned idx = isa<ExtractElementInst>(CmpInst->getOperand(0)) ? 0 : 1;
+  auto *exElem = cast<ExtractElementInst>(CmpInst->getOperand(idx));
+  if (!isa<ConstantInt>(exElem->getIndexOperand()))
+    return false;
+
+  // The extracted vector must be a binary operation with a constant data
+  // vector as one operand; the other operand is checked below.
+  auto *SrcVec = dyn_cast<BinaryOperator>(exElem->getVectorOperand());
+  if (!SrcVec)
+    return false;
+  unsigned ConstVecIdx = isa<ConstantDataVector>(SrcVec->getOperand(0)) ? 0 : 1;
+  auto *DataVec = dyn_cast<ConstantDataVector>(SrcVec->getOperand(ConstVecIdx));
+  if (!DataVec)
+    return false;
+
+  // ReplaceExtractInst() always emits `IV op constant`. With the constant on
+  // the left that is only correct for commutative operations (e.g. not Sub).
+  if (ConstVecIdx == 0 && !SrcVec->isCommutative())
+    return false;
+
+  if (!checkVecOp(SrcVec->getOperand(1 - ConstVecIdx), getVectorContaingIV(IV)))
+    return false;
+  return ReplaceExtractInst(DataVec, SrcVec->getOpcode(), exElem, IV);
+}
+
+
+
 //===----------------------------------------------------------------------===//
 //  IndVarSimplify driver. Manage several subpasses of IV simplification.
 //===----------------------------------------------------------------------===//
@@ -1913,6 +2120,25 @@ bool IndVarSimplify::run(Loop *L) {
     return false;
 
   bool Changed = false;
+  // Breaks trivial operation on vector which contain just the Induction variable
+  // This pass looks for the following pattern in the IR
+      // %header:
+        // %phiVar = i32/float [%r1, %pre_header], [%extVal, %latch_block]
+        // more instructions
+        // %initVec = insertelement <k x i32/float> undef/poison, %phiVar
+        // %extendVec = shufflevector <k x i32/float> %initVec, <k x i32/float> undef/poison, <m x i32/float> zeroInitializer 
+        // more instructions
+        // %Op = binOp %extendVec, %constDataVec
+        // more instructions
+        // %extVal = extractElement %Op, constIdx
+        // %branchVal = icmp unaryOp %extVal, %loopBound
+        // branch %branchVal B1, B2
+  // Essentially the pass looks for a possible induction variable being extracted from a vector.
+  // The vector should be a splat whose value is equal to the IV.
+  // It replaces the extractelement on the IV (%extVal = extractElement %Op, constIdx) with a scalar operation:
+        // %extVal = binOp %phiVar, %constDataVec[constIdx]
+        // %branchVal = icmp unaryOp %extVal, %loopBound
+  Changed |= breakVectorOpsOnIVs(L);
   // If there are any floating-point recurrences, attempt to
   // transform them to use integer recurrences.
   Changed |= rewriteNonIntegerIVs(L);

``````````

</details>


https://github.com/llvm/llvm-project/pull/122248


More information about the llvm-commits mailing list