[llvm] Fix CodeExtractor when using aggregated arguments. (PR #94294)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 3 17:09:35 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-llvm-transforms

Author: Rodrigo Rocha (rcorcs)

<details>
<summary>Changes</summary>

The LLVM CodeExtractor utility class is used to extract a list of basic blocks into a separate function (outlining). The input and output values can be passed to the extracted function either as separate arguments or aggregated together into a single struct argument. Aggregated arguments do not seem to be widely used by LLVM passes, nor are they thoroughly tested. I encountered a bug in the CodeExtractor when using aggregated arguments to extract a region that needs to update output values.

In CodeExtractor::emitCallAndSwitchStatement, ScalarInputArgNo was incremented for every input, so aggregated inputs were counted both as aggregated and as scalar (individual) arguments. Because ScalarInputArgNo is later used as an offset to access the output arguments, this over-count caused an out-of-bounds access. Both counts are still needed, since an extracted function can mix scalar and aggregated arguments. However, when all arguments are forced into the aggregate structure, ScalarInputArgNo ended up equal to NumAggregatedInputs, even though it should be zero. The fix increments ScalarInputArgNo only for inputs that are actually passed as scalar arguments.

The new unit test, AggInputOutputMonitoring, exercises this issue. It uses the same example as InputOutputMonitoring but forces the CodeExtractor to use aggregated arguments. It then checks that the outlined function has a single aggregated argument and that the output value is stored into the correct element of the aggregate.

With this fix, the new unit test passes and all existing tests continue to pass.

---
Full diff: https://github.com/llvm/llvm-project/pull/94294.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+1-1) 
- (modified) llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (+68) 


``````````diff
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index f2672b8e9118f..e4965a1788dbd 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1173,8 +1173,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
       params.push_back(input);
       if (input->isSwiftError())
         SwiftErrorArgs.push_back(ScalarInputArgNo);
+      ++ScalarInputArgNo;
     }
-    ++ScalarInputArgNo;
   }
 
   // Create allocas for the outputs
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 046010716862f..ce72d345b7bb5 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -136,6 +136,74 @@ TEST(CodeExtractor, InputOutputMonitoring) {
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+TEST(CodeExtractor, AggInputOutputMonitoring) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo(i32 %x, i32 %y, i32 %z) {
+    header:
+      %0 = icmp ugt i32 %x, %y
+      br i1 %0, label %body1, label %body2
+
+    body1:
+      %1 = add i32 %z, 2
+      br label %notExtracted
+
+    body2:
+      %2 = mul i32 %z, 7
+      br label %notExtracted
+
+    notExtracted:
+      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
+      %4 = add i32 %3, %x
+      ret i32 %4
+    }
+  )invalid",
+                                                Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
+                                          getBlockByName(Func, "body1"),
+                                          getBlockByName(Func, "body2")};
+
+  CodeExtractor CE(Candidates, nullptr, /*AggregateArgs=*/true);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs;
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  // Ensure that the outlined function has a single argument with
+  // the input and output values in an aggregated structure.
+  EXPECT_EQ(Outlined->arg_size(), 1u);
+
+  EXPECT_EQ(Inputs.size(), 3u);
+  EXPECT_EQ(Inputs[0], Func->getArg(2));
+  EXPECT_EQ(Inputs[1], Func->getArg(0));
+  EXPECT_EQ(Inputs[2], Func->getArg(1));
+
+  EXPECT_EQ(Outputs.size(), 1u);
+  // The output value must be stored in the appropriate element inside the
+  // aggregated structure.
+  GetElementPtrInst *GEP = cast<GetElementPtrInst>(Outlined->getArg(0)->user_back());
+  APInt Offset(M->getDataLayout().getMaxIndexSizeInBits(), 0);
+  EXPECT_TRUE(GEP->accumulateConstantOffset(M->getDataLayout(), Offset));
+  EXPECT_EQ(Offset, 3u * 4u); // Fourth i32 element, 4 bytes each.
+  StoreInst *SI = cast<StoreInst>(GEP->user_back());
+  Value *OutputVal = SI->getValueOperand();
+  EXPECT_EQ(Outputs[0], OutputVal);
+  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
+  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
+  // Ensure that PHI in exit block has only one incoming value (from code
+  // replacer block).
+  EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
+  // Ensure that there is a PHI in outlined function with 2 incoming values.
+  EXPECT_TRUE(ExitSplit &&
+              cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, ExitBlockOrderingPhis) {
   LLVMContext Ctx;
   SMDiagnostic Err;

``````````

</details>


https://github.com/llvm/llvm-project/pull/94294

