[llvm] [ConstantFold] Remove recursion from ConstantFoldInsertValueInstruction (PR #88541)

Carlos Seo via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 12 10:35:01 PDT 2024


https://github.com/ceseo created https://github.com/llvm/llvm-project/pull/88541

Make the algorithm for ConstantFoldInsertValueInstruction non-recursive so that its stack usage stays at O(1) regardless of how many indices are passed.

Fixes #77877

From 9754909b55ff12c70194ae928c3237d4f8a501d9 Mon Sep 17 00:00:00 2001
From: Carlos Eduardo Seo <carlos.seo at linaro.org>
Date: Wed, 10 Apr 2024 21:20:56 +0000
Subject: [PATCH] [ConstantFold] Remove recursion from
 ConstantFoldInsertValueInstruction

Make the algorithm for ConstantFoldInsertValueInstruction non-recursive
so that its stack usage stays at O(1) regardless of the index depth.

Fixes #77877
---
 llvm/lib/IR/ConstantFold.cpp | 58 ++++++++++++++++++++++++++----------
 1 file changed, 42 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index a766b1fe601823..3cf3ed6ebe010d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -590,30 +590,56 @@ Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg,
 Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
                                                    Constant *Val,
                                                    ArrayRef<unsigned> Idxs) {
+
+  // FIXME: Although this non-recursive version of the algorithm still
+  // executes in O(n), it's very slow when compared to gfortran because
+  // the IR generated at constant folding is huge. See issue #77877.
+  // We should consider reviewing the entire constant folding code to
+  // improve performance.
+
   // Base case: no indices, so replace the entire value.
   if (Idxs.empty())
     return Val;
 
-  unsigned NumElts;
-  if (StructType *ST = dyn_cast<StructType>(Agg->getType()))
-    NumElts = ST->getNumElements();
-  else
-    NumElts = cast<ArrayType>(Agg->getType())->getNumElements();
-
-  SmallVector<Constant*, 32> Result;
-  for (unsigned i = 0; i != NumElts; ++i) {
-    Constant *C = Agg->getAggregateElement(i);
-    if (!C) return nullptr;
+  // At each level, we will keep track of: Agg, idx, NumElts and Elements.
+  // This avoids storing the entire aggregate and keeps memory usage at O(1).
+  std::vector<
+      std::tuple<Constant *, unsigned, unsigned, SmallVector<Constant *, 32>>>
+      vector;
+
+  for (unsigned idx : Idxs) {
+    unsigned NumElts;
+    if (StructType *ST = dyn_cast<StructType>(Agg->getType()))
+      NumElts = ST->getNumElements();
+    else
+      NumElts = cast<ArrayType>(Agg->getType())->getNumElements();
+
+    SmallVector<Constant *, 32> Elements(NumElts);
+    for (unsigned i = 0; i != NumElts; ++i) {
+      Constant *C = Agg->getAggregateElement(i);
+      if (!C)
+        return nullptr;
+      Elements[i] = C;
+    }
+    // Store the data we need.
+    vector.push_back({Agg, idx, NumElts, Elements});
+    Agg = Agg->getAggregateElement(idx);
+    if (!Agg)
+      return nullptr;
+  }
 
-    if (Idxs[0] == i)
-      C = ConstantFoldInsertValueInstruction(C, Val, Idxs.slice(1));
+  // Build the result from the data we stored in the vector.
+  for (auto it = vector.rbegin(); it != vector.rend(); ++it) {
+    std::get<3>(*it)[std::get<1>(*it)] = (it == vector.rbegin()) ? Val : Agg;
 
-    Result.push_back(C);
+    if (StructType *ST = dyn_cast<StructType>(std::get<0>(*it)->getType()))
+      Agg = ConstantStruct::get(ST, std::get<3>(*it));
+    else
+      Agg = ConstantArray::get(cast<ArrayType>(std::get<0>(*it)->getType()),
+                               std::get<3>(*it));
   }
 
-  if (StructType *ST = dyn_cast<StructType>(Agg->getType()))
-    return ConstantStruct::get(ST, Result);
-  return ConstantArray::get(cast<ArrayType>(Agg->getType()), Result);
+  return Agg;
 }
 
 Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {



More information about the llvm-commits mailing list