[llvm] [llvm][CodeExtractor][NFC] fix bug in parameter naming (PR #114237)

Tom Eccles via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 31 09:56:21 PDT 2024


https://github.com/tblah updated https://github.com/llvm/llvm-project/pull/114237

>From b02c12970ded39a4bc931906ba6de662ae810dc4 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Mon, 28 Oct 2024 19:22:19 +0000
Subject: [PATCH 1/2] [llvm][CodeExtractor] fix bug in parameter naming

The code extractor tries to apply the names of source input and output
values to function arguments. Not all input and output values get added
as arguments: some are instead placed inside of a struct passed to the
function. The existing renaming code skipped setting names for these
struct-packed values (as there is no corresponding function argument to
rename), but it still advanced the iterator over the function arguments.
This could result in dereferencing the end iterator if struct-packed
inputs/outputs preceded non-struct-packed inputs/outputs.

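This patch records the argument index of each non-struct-packed input and
output while the parameter list is being built, and then names arguments by
that index, so the end iterator is never dereferenced.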
---
 llvm/lib/Transforms/Utils/CodeExtractor.cpp | 24 +++++++++++----------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ed4ad15e5ab695..fa467cc72bd020 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -823,17 +823,22 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
 
   std::vector<Type *> ParamTy;
   std::vector<Type *> AggParamTy;
+  std::vector<std::tuple<unsigned, Value *>> NumberedInputs;
+  std::vector<std::tuple<unsigned, Value *>> NumberedOutputs;
   ValueSet StructValues;
   const DataLayout &DL = M->getDataLayout();
 
   // Add the types of the input values to the function's argument list
+  unsigned ArgNum = 0;
   for (Value *value : inputs) {
     LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
       AggParamTy.push_back(value->getType());
       StructValues.insert(value);
-    } else
+    } else {
       ParamTy.push_back(value->getType());
+      NumberedInputs.emplace_back(ArgNum++, value);
+    }
   }
 
   // Add the types of the output values to the function's argument list.
@@ -842,9 +847,11 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
     if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
       AggParamTy.push_back(output->getType());
       StructValues.insert(output);
-    } else
+    } else {
       ParamTy.push_back(
           PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
+      NumberedOutputs.emplace_back(ArgNum++, output);
+    }
   }
 
   assert(
@@ -1053,15 +1060,10 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
   }
 
   // Set names for input and output arguments.
-  if (NumScalarParams) {
-    ScalarAI = newFunction->arg_begin();
-    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
-      if (!StructValues.contains(inputs[i]))
-        ScalarAI->setName(inputs[i]->getName());
-    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
-      if (!StructValues.contains(outputs[i]))
-        ScalarAI->setName(outputs[i]->getName() + ".out");
-  }
+  for (auto [i, argVal] : NumberedInputs)
+    newFunction->getArg(i)->setName(argVal->getName());
+  for (auto [i, argVal] : NumberedOutputs)
+    newFunction->getArg(i)->setName(argVal->getName() + ".out");
 
   // Rewrite branches to basic blocks outside of the loop to new dummy blocks
   // within the new function. This must be done before we lose track of which

>From 9073212ec68c89e9932acb3864bbd4990fd8cfb5 Mon Sep 17 00:00:00 2001
From: Tom Eccles <tom.eccles at arm.com>
Date: Thu, 31 Oct 2024 15:14:15 +0000
Subject: [PATCH 2/2] Add test

---
 .../Transforms/Utils/CodeExtractorTest.cpp    | 47 +++++++++++++++++++
 1 file changed, 47 insertions(+)

diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 046010716862f6..80c2a23a957963 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -556,6 +556,53 @@ TEST(CodeExtractor, PartialAggregateArgs) {
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+/// Regression test to ensure we don't crash trying to set the name of the ptr
+/// argument
+TEST(CodeExtractor, PartialAggregateArgs2) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"ir(
+    declare void @usei(i32)
+    declare void @usep(ptr)
+
+    define void @foo(i32 %a, i32 %b, ptr %p) {
+    entry:
+      br label %extract
+
+    extract:
+      call void @usei(i32 %a)
+      call void @usei(i32 %b)
+      call void @usep(ptr %p)
+      br label %exit
+
+    exit:
+      ret void
+    }
+  )ir",
+                                                Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
+
+  // Create the CodeExtractor with arguments aggregation enabled.
+  CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
+                   /* AggregateArgs */ true);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
+  CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+  CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
+  // Exclude the last input from the argument aggregate.
+  CE.excludeArgFromAggregate(Inputs[2]);
+
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, OpenMPAggregateArgs) {
   LLVMContext Ctx;
   SMDiagnostic Err;



More information about the llvm-commits mailing list