[llvm] [ConstantFold] Remove recursion from ConstantFoldInsertValueInstruction (PR #88541)
Carlos Seo via llvm-commits
llvm-commits at lists.llvm.org
Fri Apr 12 10:35:01 PDT 2024
https://github.com/ceseo created https://github.com/llvm/llvm-project/pull/88541
Make ConstantFoldInsertValueInstruction non-recursive so that its stack usage no longer grows with the number of insertvalue indices.
Fixes #77877
>From 9754909b55ff12c70194ae928c3237d4f8a501d9 Mon Sep 17 00:00:00 2001
From: Carlos Eduardo Seo <carlos.seo at linaro.org>
Date: Wed, 10 Apr 2024 21:20:56 +0000
Subject: [PATCH] [ConstantFold] Remove recursion from
ConstantFoldInsertValueInstruction
Make ConstantFoldInsertValueInstruction non-recursive so that its stack
usage no longer grows with the number of insertvalue indices.
Fixes #77877
---
llvm/lib/IR/ConstantFold.cpp | 58 ++++++++++++++++++++++++++----------
1 file changed, 42 insertions(+), 16 deletions(-)
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index a766b1fe601823..3cf3ed6ebe010d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -590,30 +590,56 @@ Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg,
/// Constant-fold `insertvalue Agg, Val, Idxs`.
///
/// The fold walks down from \p Agg along \p Idxs to the element being
/// replaced. It then rebuilds every aggregate on that path bottom-up. Each
/// aggregate gets its indexed element swapped for the value rebuilt one level
/// below, starting with \p Val. The walk is iterative, so stack usage does not
/// depend on the number of indices. Apart from the resulting constants, it
/// keeps one aggregate pointer per index plus one operand buffer that is
/// reused for every level.
///
/// \param Agg  Aggregate to insert into. Each aggregate on the indexed path
///             must have struct or array type; other types are not handled
///             (the cast<ArrayType> below would fail).
/// \param Val  Value placed at the position named by \p Idxs.
/// \param Idxs insertvalue indices; an empty list replaces the whole value.
/// \returns The folded constant, \p Val when \p Idxs is empty, or nullptr if
///          getAggregateElement() yields null for any element that is needed.
Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
                                                   Constant *Val,
                                                   ArrayRef<unsigned> Idxs) {
  // Base case: no indices, so replace the entire value.
  if (Idxs.empty())
    return Val;

  // Descend along Idxs. Path[K] is the aggregate that Idxs[K] indexes into.
  // The last element fetched is the one being replaced. It is looked up only
  // so that the fold fails when that element is unavailable.
  SmallVector<Constant *, 8> Path;
  Path.reserve(Idxs.size());
  Constant *Cur = Agg;
  for (unsigned Idx : Idxs) {
    Path.push_back(Cur);
    Cur = Cur->getAggregateElement(Idx);
    if (!Cur)
      return nullptr;
  }

  // Rebuild bottom-up. Elements is reused across levels, so peak scratch
  // memory is bounded by the widest aggregate on the path rather than by the
  // sum over all levels.
  SmallVector<Constant *, 32> Elements;
  Constant *NewElt = Val;
  for (size_t Level = Idxs.size(); Level != 0; --Level) {
    Constant *LevelAgg = Path[Level - 1];
    unsigned LevelIdx = Idxs[Level - 1];
    auto *Ty = LevelAgg->getType();
    auto *ST = dyn_cast<StructType>(Ty);
    unsigned NumElts =
        ST ? ST->getNumElements() : cast<ArrayType>(Ty)->getNumElements();

    Elements.clear();
    Elements.reserve(NumElts);
    for (unsigned I = 0; I != NumElts; ++I) {
      Constant *C =
          I == LevelIdx ? NewElt : LevelAgg->getAggregateElement(I);
      if (!C)
        return nullptr;
      Elements.push_back(C);
    }

    NewElt = ST ? ConstantStruct::get(ST, Elements)
                : ConstantArray::get(cast<ArrayType>(Ty), Elements);
  }
  return NewElt;
}
Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
More information about the llvm-commits
mailing list