[llvm] [ConstantFold] Remove recursion from ConstantFoldInsertValueInstruction (PR #88541)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 12 10:35:33 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-ir
Author: Carlos Seo (ceseo)
<details>
<summary>Changes</summary>
Make the algorithm for ConstantFoldInsertValueInstruction non-recursive, replacing the one-recursive-call-per-index structure with an explicit loop over the index path.
Fixes #77877
---
Full diff: https://github.com/llvm/llvm-project/pull/88541.diff
1 Files Affected:
- (modified) llvm/lib/IR/ConstantFold.cpp (+42-16)
``````````diff
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index a766b1fe601823..3cf3ed6ebe010d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -590,30 +590,56 @@ Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg,
Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
Constant *Val,
ArrayRef<unsigned> Idxs) {
// Base case: no indices, so replace the entire value.
if (Idxs.empty())
return Val;
- unsigned NumElts;
- if (StructType *ST = dyn_cast<StructType>(Agg->getType()))
- NumElts = ST->getNumElements();
- else
- NumElts = cast<ArrayType>(Agg->getType())->getNumElements();
-
- SmallVector<Constant*, 32> Result;
- for (unsigned i = 0; i != NumElts; ++i) {
- Constant *C = Agg->getAggregateElement(i);
- if (!C) return nullptr;
+ // Walk down the index path iteratively rather than recursing once per
+ // index. For each aggregate on the path, record its type, the index taken
+ // and a copy of its elements; memory use is therefore proportional to the
+ // total number of elements along the path.
+ struct PathLevel {
+ Type *AggTy;
+ unsigned Idx;
+ SmallVector<Constant *, 32> Elts;
+ };
+ SmallVector<PathLevel, 4> Path;
- if (Idxs[0] == i)
- C = ConstantFoldInsertValueInstruction(C, Val, Idxs.slice(1));
+ for (unsigned Idx : Idxs) {
+ Type *AggTy = Agg->getType();
+ unsigned NumElts;
+ if (auto *ST = dyn_cast<StructType>(AggTy))
+ NumElts = ST->getNumElements();
+ else
+ NumElts = cast<ArrayType>(AggTy)->getNumElements();
+ // An out-of-range index cannot be folded.
+ if (Idx >= NumElts)
+ return nullptr;
+ SmallVector<Constant *, 32> Elts;
+ Elts.reserve(NumElts);
+ for (unsigned I = 0; I != NumElts; ++I) {
+ Constant *C = Agg->getAggregateElement(I);
+ if (!C)
+ return nullptr;
+ Elts.push_back(C);
+ }
+ // Descend into the replaced element before moving Elts into Path.
+ Agg = Elts[Idx];
+ Path.push_back({AggTy, Idx, std::move(Elts)});
+ }
- Result.push_back(C);
+ // Rebuild the aggregates innermost-first: at each level, splice the
+ // value built so far (initially Val) into the recorded index.
+ Constant *Result = Val;
+ for (PathLevel &Level : llvm::reverse(Path)) {
+ Level.Elts[Level.Idx] = Result;
+ if (auto *ST = dyn_cast<StructType>(Level.AggTy))
+ Result = ConstantStruct::get(ST, Level.Elts);
+ else
+ Result = ConstantArray::get(cast<ArrayType>(Level.AggTy), Level.Elts);
}
- if (StructType *ST = dyn_cast<StructType>(Agg->getType()))
- return ConstantStruct::get(ST, Result);
- return ConstantArray::get(cast<ArrayType>(Agg->getType()), Result);
+ return Result;
}
Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/88541
More information about the llvm-commits
mailing list