[llvm] Adding change to IndVarSimplify pass to optimize IVs stuck in trivial vector operations (PR #122248)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 9 02:46:58 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: None (anilavakundu007)
<details>
<summary>Changes</summary>
This patch enables LLVM to unroll loops with patterns like the following:
```
// Example loop (see the linked issue). Keep its exact shape: the loop
// counter's round-trip through a vector is the pattern this patch targets.
// ivec2 is a clang ext_vector_type with 2 int lanes.
typedef int ivec2 __attribute__((ext_vector_type(2)));
int data0;
void fn()
{
// Loop counter: starts at 1; the loop exits once it is >= 3.
int u_xlati59 = 1;
while (true)
{
bool u_xlatb12 = u_xlati59 >= 3;
if (u_xlatb12)
{
break;
}
// Splat the counter into both lanes and add {-1, 1}:
// lane x = counter - 1, lane y = counter + 1.
ivec2 u_xlati12 = (ivec2){u_xlati59, u_xlati59} + (ivec2){-1, 1};
data0 += u_xlati12.x;
// The next counter value is read back out of the vector (lane y), so the
// counter's update goes through vector operations instead of a scalar add.
u_xlati59 = u_xlati12.y;
}
}
```
Currently, SCEV cannot analyze the induction variable because its update round-trips through vector operations (a splat, a vector add, and an extractelement), so the loop cannot be unrolled. With this patch the loop is unrolled and the generated code is much smaller.
Linked issue: https://github.com/llvm/llvm-project/issues/121742
---
Full diff: https://github.com/llvm/llvm-project/pull/122248.diff
1 Files Affected:
- (modified) llvm/lib/Transforms/Scalar/IndVarSimplify.cpp (+226)
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
index 8a3e0bc3eb9712..37361080a1d766 100644
--- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -138,6 +138,7 @@ class IndVarSimplify {
bool RunUnswitching = false;
bool handleFloatingPointIV(Loop *L, PHINode *PH);
+ bool breakVectorOpsOnIVs(Loop *L);
bool rewriteNonIntegerIVs(Loop *L);
bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI);
@@ -1891,6 +1892,212 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
return Changed;
}
+ // Returns the icmp that controls the conditional branch terminating L's
+ // latch, or nullptr when L has no unique latch, the latch does not end in a
+ // conditional branch, or the branch condition is not an icmp.
+ // Loop::getLatchCmpInst() already implements exactly this check, so reuse it
+ // rather than keeping a private copy that can drift out of sync.
+ static ICmpInst *getLatchCmpInst(const Loop &L) {
+   return L.getLatchCmpInst();
+ }
+
+ // Returns an insertelement user of IV that builds a vector in which IV is
+ // the only defined element, or nullptr if there is none. Accepted shapes:
+ //  * a 1-element vector with IV inserted at (constant) lane 0, or
+ //  * IV inserted at a constant, in-range lane of an undef or poison vector
+ //    that is not itself extended by a later insertelement.
+ // The first qualifying user, in use-list order, is returned.
+ static Value *getVectorContaingIV(PHINode *IV) {
+   for (User *U : IV->users()) {
+     auto *VecInsert = dyn_cast<InsertElementInst>(U);
+     // IV must be the inserted scalar (operand 1). A user that only takes IV
+     // as the lane index (operand 2) does not put IV into the vector.
+     if (!VecInsert || VecInsert->getOperand(1) != IV)
+       continue;
+ 
+     auto *VecTy = dyn_cast<FixedVectorType>(VecInsert->getType());
+     if (!VecTy)
+       continue;
+ 
+     // Require a constant, in-range lane so IV really lands in the vector;
+     // an out-of-range lane makes the insertelement result poison.
+     auto *Lane = dyn_cast<ConstantInt>(VecInsert->getOperand(2));
+     if (!Lane || !Lane->getValue().ult(VecTy->getNumElements()))
+       continue;
+ 
+     // A single-element vector is fully overwritten by IV.
+     if (VecTy->getNumElements() == 1)
+       return VecInsert;
+ 
+     // The remaining lanes must be undefined. PoisonValue derives from
+     // UndefValue, so this single test accepts both undef and poison.
+     if (!isa<UndefValue>(VecInsert->getOperand(0)))
+       continue;
+ 
+     // Conservatively reject vectors that are extended by a further
+     // insertelement.
+     if (llvm::any_of(VecInsert->users(),
+                      [](User *VU) { return isa<InsertElementInst>(VU); }))
+       continue;
+ 
+     return VecInsert;
+   }
+   return nullptr;
+ }
+
+
+ // Looks for a header phi of L that acts as an induction variable but is
+ // missed by the SCEV-based logic because its value is routed through vector
+ // operations. Returns nullptr unless L's latch ends in a conditional branch
+ // on an icmp with at least one extractelement operand. Among header phis of
+ // integer or `float` type, the first one that is itself an icmp operand, or
+ // whose incoming value from the latch is, decides the result: that phi if
+ // getVectorContaingIV finds an IV-only vector for it, otherwise nullptr.
+ // NOTE(review): despite the `is` prefix this returns the phi, and `double`
+ // phis are not considered; confirm both are intended.
+ static PHINode *isPossibleIDVar(Loop *L) {
+   BasicBlock *Header = L->getHeader();
+   assert(Header && "Expected a valid loop header");
+ 
+   ICmpInst *LatchCmp = getLatchCmpInst(*L);
+   if (!LatchCmp)
+     return nullptr;
+ 
+   Value *CmpLHS = LatchCmp->getOperand(0);
+   Value *CmpRHS = LatchCmp->getOperand(1);
+   // The IV is expected to reach the compare through an extractelement.
+   if (!isa<ExtractElementInst>(CmpLHS) && !isa<ExtractElementInst>(CmpRHS))
+     return nullptr;
+ 
+   // A latch exists here because getLatchCmpInst found a compare in it.
+   BasicBlock *Latch = L->getLoopLatch();
+   auto FeedsCmp = [&](Value *V) { return V == CmpLHS || V == CmpRHS; };
+ 
+   for (PHINode &Phi : Header->phis()) {
+     Type *PhiTy = Phi.getType();
+     if (!PhiTy->isIntegerTy() && !PhiTy->isFloatTy())
+       continue;
+     // Either the compare tests the post-increment value
+     //   Phi = phi [Init, preheader], [Step, latch];  cmp Step, Bound
+     // or it tests the phi itself
+     //   cmp Phi, Bound
+     if (FeedsCmp(Phi.getIncomingValueForBlock(Latch)) || FeedsCmp(&Phi))
+       return getVectorContaingIV(&Phi) ? &Phi : nullptr;
+   }
+   return nullptr;
+ }
+
+ // Returns true when BinOperand is known to carry the IV-only vector VecIV.
+ // Two shapes are recognized: BinOperand is VecIV itself, or it is a
+ // shufflevector whose first source is VecIV and whose mask splats element 0
+ // of that source (ShuffleVectorInst::isZeroEltSplatMask).
+ // NOTE(review): Mask.size() is passed as NumSrcElts, i.e. the source is
+ // treated as being as wide as the result; confirm this is intended for
+ // length-changing shuffles.
+ // TODO: Add more patterns?
+ static bool checkVecOp(Value *BinOperand, Value *VecIV) {
+   if (BinOperand == VecIV)
+     return true;
+ 
+   auto *Shuffle = dyn_cast<ShuffleVectorInst>(BinOperand);
+   if (!Shuffle || Shuffle->getOperand(0) != VecIV)
+     return false;
+ 
+   ArrayRef<int> Mask = Shuffle->getShuffleMask();
+   return ShuffleVectorInst::isZeroEltSplatMask(Mask, Mask.size());
+ }
+
+
+ // Replaces elemInst, an extractelement at constant lane Idx, with the scalar
+ // `IV <Opcode> values[Idx]`, or `values[Idx] <Opcode> IV` when ConstOnLHS is
+ // set. Precondition (not checked here): lane Idx of the extracted vector is
+ // exactly that operation applied to IV. Vector binary operators act
+ // lane-wise, so any binary opcode is handled. Wrap/exact/fast-math flags of
+ // the original vector op are not carried over; dropping them is always safe.
+ // On success, the new instruction is inserted right before elemInst, which
+ // is RAUW'd and erased, and true is returned. Returns false with the IR
+ // unchanged if any of these holds:
+ //  * Opcode is not a binary opcode;
+ //  * the lane is not a constant inside `values`;
+ //  * IV's type differs from the element type of `values`.
+ static bool ReplaceExtractInst(ConstantDataVector *values, unsigned Opcode,
+                                ExtractElementInst *elemInst, PHINode *IV,
+                                bool ConstOnLHS = false) {
+   if (!Instruction::isBinaryOp(Opcode) ||
+       IV->getType() != values->getElementType())
+     return false;
+ 
+   // Out-of-range lanes would read past the end of `values`.
+   auto *Lane = dyn_cast<ConstantInt>(elemInst->getIndexOperand());
+   if (!Lane || !Lane->getValue().ult(values->getNumElements()))
+     return false;
+ 
+   Value *ExtractedVal = values->getElementAsConstant(Lane->getZExtValue());
+   Value *LHS = ConstOnLHS ? ExtractedVal : IV;
+   Value *RHS = ConstOnLHS ? IV : ExtractedVal;
+ 
+   IRBuilder<> B(elemInst);
+   Value *NewInst =
+       B.CreateBinOp(static_cast<Instruction::BinaryOps>(Opcode), LHS, RHS);
+ 
+   LLVM_DEBUG(dbgs() << "INDVARS: Rewriting Extract Element:\n" << *elemInst <<"\n"
+                     << " With :" << *NewInst <<"\n");
+   elemInst->replaceAllUsesWith(NewInst);
+   elemInst->eraseFromParent();
+   return true;
+ }
+
+
+ // Rewrites the latch-compare operand `extractelement (V op C), Idx` into the
+ // scalar `IV op C[Idx]`. Here V carries only the IV found by isPossibleIDVar
+ // (directly or through a zero-element splat) and C is a constant data
+ // vector. Afterwards the compare no longer depends on vector operations.
+ // Returns true if the IR was changed.
+ bool IndVarSimplify::breakVectorOpsOnIVs(Loop *L) {
+   PHINode *IV = isPossibleIDVar(L);
+   if (!IV)
+     return false;
+ 
+   // isPossibleIDVar succeeded, so the latch compare exists and at least one
+   // of its operands is an extractelement; operand 0 is preferred.
+   ICmpInst *LatchCmp = getLatchCmpInst(*L);
+   unsigned ExtOpIdx = isa<ExtractElementInst>(LatchCmp->getOperand(0)) ? 0 : 1;
+   auto *Extract = cast<ExtractElementInst>(LatchCmp->getOperand(ExtOpIdx));
+ 
+   // The extracted vector must come from a binary operation.
+   auto *VecBinOp = dyn_cast<BinaryOperator>(Extract->getVectorOperand());
+   if (!VecBinOp)
+     return false;
+ 
+   // One operand of that operation must be a constant data vector.
+   Value *BinOp0 = VecBinOp->getOperand(0);
+   Value *BinOp1 = VecBinOp->getOperand(1);
+   if (!isa<ConstantDataVector>(BinOp0) && !isa<ConstantDataVector>(BinOp1))
+     return false;
+   unsigned ConstVecIdx = isa<ConstantDataVector>(BinOp0) ? 0 : 1;
+   auto *DataVec = cast<ConstantDataVector>(VecBinOp->getOperand(ConstVecIdx));
+ 
+   // The lane must be a constant inside the vector. Out-of-range lanes yield
+   // poison and would index past the end of DataVec.
+   auto *Lane = dyn_cast<ConstantInt>(Extract->getIndexOperand());
+   if (!Lane || !Lane->getValue().ult(DataVec->getNumElements()))
+     return false;
+ 
+   // ReplaceExtractInst is called below with the IV as the left operand
+   // (`IV op C[Idx]`). For a non-commutative op with the constant on the
+   // left (e.g. `C - V`, `C / V`) that would swap the operands and
+   // miscompile, so bail out.
+   if (ConstVecIdx == 0 && !VecBinOp->isCommutative())
+     return false;
+ 
+   if (!checkVecOp(VecBinOp->getOperand(1 - ConstVecIdx),
+                   getVectorContaingIV(IV)))
+     return false;
+ 
+   if (!ReplaceExtractInst(DataVec, VecBinOp->getOpcode(), Extract, IV))
+     return false;
+ 
+   // SCEV may have cached results for this loop that the rewrite
+   // invalidates, e.g. the header phi as an opaque value or an uncomputable
+   // trip count. Drop them so later queries see the new scalar recurrence.
+   SE->forgetLoop(L);
+   return true;
+ }
+
+
+
//===----------------------------------------------------------------------===//
// IndVarSimplify driver. Manage several subpasses of IV simplification.
//===----------------------------------------------------------------------===//
@@ -1913,6 +2120,25 @@ bool IndVarSimplify::run(Loop *L) {
return false;
bool Changed = false;
+ // Break trivial vector operations applied to a vector that holds only the
+ // induction variable. This looks for the following pattern in the IR:
+ // header:
+ //   %phiVar = phi i32/float [%r1, %preheader], [%extVal, %latch]
+ //   ...
+ //   %initVec = insertelement <k x i32/float> undef/poison, %phiVar, 0
+ //   %extendVec = shufflevector <k x i32/float> %initVec, <k x i32/float> undef/poison, <m x i32> zeroinitializer
+ //   ...
+ //   %Op = binOp %extendVec, %constDataVec
+ //   ...
+ //   %extVal = extractelement %Op, constIdx
+ //   %branchVal = icmp pred %extVal, %loopBound
+ //   br %branchVal, B1, B2
+ // i.e. a possible induction variable that is splatted into a vector, combined
+ // with a constant vector, and extracted again. The extractelement is replaced
+ // with the equivalent scalar operation:
+ //   %extVal = binOp %phiVar, %constDataVec[constIdx]
+ //   %branchVal = icmp pred %extVal, %loopBound
+ Changed |= breakVectorOpsOnIVs(L);
// If there are any floating-point recurrences, attempt to
// transform them to use integer recurrences.
Changed |= rewriteNonIntegerIVs(L);
``````````
</details>
https://github.com/llvm/llvm-project/pull/122248
More information about the llvm-commits
mailing list