[llvm] [llvm][CodeExtractor] fix bug in parameter naming (PR #114237)

Tom Eccles via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 30 07:19:20 PDT 2024


https://github.com/tblah created https://github.com/llvm/llvm-project/pull/114237

The code extractor tries to apply the names of source input and output values to function arguments. Not all input and output values get added as arguments: some are instead placed inside of a struct passed to the function. The existing renaming code skipped trying to set these struct-packed arguments' names (as there is no corresponding function argument to rename), but it still incremented the iterator over the function arguments. This could result in dereferencing an end iterator if struct-packed inputs/outputs preceded non-struct-packed inputs/outputs.

This patch rewrites this loop to avoid the end iterator dereference.

---

I haven't contributed to llvm-project/llvm before and wasn't sure how best to test this. I came across the bug while working on the OpenMP MLIR dialect's lowering to LLVM IR.

From b02c12970ded39a4bc931906ba6de662ae810dc4 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 28 Oct 2024 19:22:19 +0000
Subject: [PATCH] [llvm][CodeExtractor] fix bug in parameter naming

The code extractor tries to apply the names of source input and output
values to function arguments. Not all input and output values get added
as arguments: some are instead placed inside of a struct passed to the
function. The existing renaming code skipped trying to set these
struct-packed arguments' names (as there is no corresponding function
argument to rename), but it still incremented the iterator over the
function arguments. This could result in dereferencing an end iterator
if struct-packed inputs/outputs preceded non-struct-packed inputs/outputs.

This patch rewrites this loop to avoid the end iterator dereference.
---
 llvm/lib/Transforms/Utils/CodeExtractor.cpp | 24 +++++++++++----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ed4ad15e5ab695..fa467cc72bd020 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -823,17 +823,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
 
   std::vector<Type *> ParamTy;
   std::vector<Type *> AggParamTy;
+  std::vector<std::tuple<unsigned, Value *>> NumberedInputs;
+  std::vector<std::tuple<unsigned, Value *>> NumberedOutputs;
   ValueSet StructValues;
   const DataLayout &DL = M->getDataLayout();
 
   // Add the types of the input values to the function's argument list
+  unsigned ArgNum = 0;
   for (Value *value : inputs) {
     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
       AggParamTy.push_back(value->getType());
       StructValues.insert(value);
-    } else
+    } else {
       ParamTy.push_back(value->getType());
+      NumberedInputs.emplace_back(ArgNum++, value);
+    }
   }
 
   // Add the types of the output values to the function's argument list.
@@ -842,9 +847,11 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
       AggParamTy.push_back(output->getType());
       StructValues.insert(output);
-    } else
+    } else {
       ParamTy.push_back(
           PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
+      NumberedOutputs.emplace_back(ArgNum++, output);
+    }
   }
 
   assert(
@@ -1053,15 +1060,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   }
 
   // Set names for input and output arguments.
-  if (NumScalarParams) {
-    ScalarAI = newFunction->arg_begin();
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
-      if (!StructValues.contains(inputs[i]))
-        ScalarAI->setName(inputs[i]->getName());
-    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
-      if (!StructValues.contains(outputs[i]))
-        ScalarAI->setName(outputs[i]->getName() + ".out");
-  }
+  for (auto [i, argVal] : NumberedInputs)
+    newFunction->getArg(i)->setName(argVal->getName());
+  for (auto [i, argVal] : NumberedOutputs)
+    newFunction->getArg(i)->setName(argVal->getName() + ".out");
 
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
   // within the new function. This must be done before we lose track of which
