[llvm] Fix CodeExtractor when using aggregated arguments. (PR #94294)

Rodrigo Rocha via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 4 17:32:30 PDT 2024


https://github.com/rcorcs updated https://github.com/llvm/llvm-project/pull/94294

>From 94e7095320c91233f3873cb0998fb8ca03866aa3 Mon Sep 17 00:00:00 2001
From: Rodrigo Rocha <rcor.cs at gmail.com>
Date: Tue, 4 Jun 2024 01:04:29 +0100
Subject: [PATCH 1/2] Fix CodeExtractor when using aggregated arguments.

---
 llvm/lib/Transforms/Utils/CodeExtractor.cpp   |  2 +-
 .../Transforms/Utils/CodeExtractorTest.cpp    | 68 +++++++++++++++++++
 2 files changed, 69 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index f2672b8e9118f..e4965a1788dbd 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -1173,8 +1173,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
       params.push_back(input);
       if (input->isSwiftError())
         SwiftErrorArgs.push_back(ScalarInputArgNo);
+      ++ScalarInputArgNo;
     }
-    ++ScalarInputArgNo;
   }
 
   // Create allocas for the outputs
diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index 046010716862f..ce72d345b7bb5 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -136,6 +136,74 @@ TEST(CodeExtractor, InputOutputMonitoring) {
   EXPECT_FALSE(verifyFunction(*Func));
 }
 
+TEST(CodeExtractor, AggInputOutputMonitoring) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
+    define i32 @foo(i32 %x, i32 %y, i32 %z) {
+    header:
+      %0 = icmp ugt i32 %x, %y
+      br i1 %0, label %body1, label %body2
+
+    body1:
+      %1 = add i32 %z, 2
+      br label %notExtracted
+
+    body2:
+      %2 = mul i32 %z, 7
+      br label %notExtracted
+
+    notExtracted:
+      %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
+      %4 = add i32 %3, %x
+      ret i32 %4
+    }
+  )invalid",
+                                                Err, Ctx));
+
+  Function *Func = M->getFunction("foo");
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
+                                          getBlockByName(Func, "body1"),
+                                          getBlockByName(Func, "body2")};
+
+  CodeExtractor CE(Candidates, nullptr, true);
+  EXPECT_TRUE(CE.isEligible());
+
+  CodeExtractorAnalysisCache CEAC(*Func);
+  SetVector<Value *> Inputs, Outputs;
+  Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
+  EXPECT_TRUE(Outlined);
+  //Ensure that the outlined function has a single argument with
+  //the input and output values in an aggregated structure.
+  EXPECT_EQ(Outlined->arg_size(), 1u);
+
+  EXPECT_EQ(Inputs.size(), 3u);
+  EXPECT_EQ(Inputs[0], Func->getArg(2));
+  EXPECT_EQ(Inputs[1], Func->getArg(0));
+  EXPECT_EQ(Inputs[2], Func->getArg(1));
+
+  EXPECT_EQ(Outputs.size(), 1u);
+  //The output value must be stored in the appropriate element inside the
+  //aggregated structure.
+  GetElementPtrInst *GEP = cast<GetElementPtrInst>(Outlined->getArg(0)->user_back());
+  APInt Offset(M->getDataLayout().getMaxIndexSizeInBits(), 0);
+  EXPECT_TRUE(GEP->accumulateConstantOffset(M->getDataLayout(), Offset));
+  EXPECT_EQ(Offset, 3u*4u); //Fourth i32 element, with 4-bytes each.
+  StoreInst *SI = cast<StoreInst>(GEP->user_back());
+  Value *OutputVal = SI->getValueOperand();
+  EXPECT_EQ(Outputs[0], OutputVal);
+  BasicBlock *Exit = getBlockByName(Func, "notExtracted");
+  BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
+  // Ensure that PHI in exit block has only one incoming value (from code
+  // replacer block).
+  EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
+  // Ensure that there is a PHI in outlined function with 2 incoming values.
+  EXPECT_TRUE(ExitSplit &&
+              cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
+  EXPECT_FALSE(verifyFunction(*Outlined));
+  EXPECT_FALSE(verifyFunction(*Func));
+}
+
 TEST(CodeExtractor, ExitBlockOrderingPhis) {
   LLVMContext Ctx;
   SMDiagnostic Err;

>From ebda416577c59482284e1748ae49d9994adf390d Mon Sep 17 00:00:00 2001
From: Rodrigo Rocha <rcor.cs at gmail.com>
Date: Wed, 5 Jun 2024 01:30:26 +0100
Subject: [PATCH 2/2] Formatted with clang-format.

---
 .../Transforms/Utils/CodeExtractorTest.cpp    | 62 +++++++++----------
 1 file changed, 30 insertions(+), 32 deletions(-)

diff --git a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
index ce72d345b7bb5..37ea65ff9a349 100644
--- a/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
+++ b/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp
@@ -7,8 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/CodeExtractor.h"
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Dominators.h"
@@ -56,9 +56,9 @@ TEST(CodeExtractor, ExitStub) {
                                                 Err, Ctx));
 
   Function *Func = M->getFunction("foo");
-  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
-                                           getBlockByName(Func, "body1"),
-                                           getBlockByName(Func, "body2") };
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
+                                          getBlockByName(Func, "body1"),
+                                          getBlockByName(Func, "body2")};
 
   CodeExtractor CE(Candidates);
   EXPECT_TRUE(CE.isEligible());
@@ -173,8 +173,8 @@ TEST(CodeExtractor, AggInputOutputMonitoring) {
   SetVector<Value *> Inputs, Outputs;
   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
   EXPECT_TRUE(Outlined);
-  //Ensure that the outlined function has a single argument with
-  //the input and output values in an aggregated structure.
+  // Ensure that the outlined function has a single argument with
+  // the input and output values in an aggregated structure.
   EXPECT_EQ(Outlined->arg_size(), 1u);
 
   EXPECT_EQ(Inputs.size(), 3u);
@@ -183,12 +183,13 @@ TEST(CodeExtractor, AggInputOutputMonitoring) {
   EXPECT_EQ(Inputs[2], Func->getArg(1));
 
   EXPECT_EQ(Outputs.size(), 1u);
-  //The output value must be stored in the appropriate element inside the
-  //aggregated structure.
-  GetElementPtrInst *GEP = cast<GetElementPtrInst>(Outlined->getArg(0)->user_back());
+  // The output value must be stored in the appropriate element inside the
+  // aggregated structure.
+  GetElementPtrInst *GEP =
+      cast<GetElementPtrInst>(Outlined->getArg(0)->user_back());
   APInt Offset(M->getDataLayout().getMaxIndexSizeInBits(), 0);
   EXPECT_TRUE(GEP->accumulateConstantOffset(M->getDataLayout(), Offset));
-  EXPECT_EQ(Offset, 3u*4u); //Fourth i32 element, with 4-bytes each.
+  EXPECT_EQ(Offset, 3u * 4u); // Fourth i32 element, with 4-bytes each.
   StoreInst *SI = cast<StoreInst>(GEP->user_back());
   Value *OutputVal = SI->getValueOperand();
   EXPECT_EQ(Outputs[0], OutputVal);
@@ -232,9 +233,9 @@ TEST(CodeExtractor, ExitBlockOrderingPhis) {
   )invalid",
                                                 Err, Ctx));
   Function *Func = M->getFunction("foo");
-  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
-                                           getBlockByName(Func, "test1"),
-                                           getBlockByName(Func, "test") };
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "test0"),
+                                          getBlockByName(Func, "test1"),
+                                          getBlockByName(Func, "test")};
 
   CodeExtractor CE(Candidates);
   EXPECT_TRUE(CE.isEligible());
@@ -289,9 +290,9 @@ TEST(CodeExtractor, ExitBlockOrdering) {
   )invalid",
                                                 Err, Ctx));
   Function *Func = M->getFunction("foo");
-  SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
-                                           getBlockByName(Func, "test1"),
-                                           getBlockByName(Func, "test") };
+  SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "test0"),
+                                          getBlockByName(Func, "test1"),
+                                          getBlockByName(Func, "test")};
 
   CodeExtractor CE(Candidates);
   EXPECT_TRUE(CE.isEligible());
@@ -344,13 +345,12 @@ TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
       ret i32 %1
     }
-  )invalid", Err, Ctx));
+  )invalid",
+                                                Err, Ctx));
 
   Function *Func = M->getFunction("foo");
   SmallVector<BasicBlock *, 2> ExtractedBlocks{
-    getBlockByName(Func, "extracted1"),
-    getBlockByName(Func, "extracted2")
-  };
+      getBlockByName(Func, "extracted1"), getBlockByName(Func, "extracted2")};
 
   CodeExtractor CE(ExtractedBlocks);
   EXPECT_TRUE(CE.isEligible());
@@ -363,9 +363,9 @@ TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
   // Ensure that PHIs in exits are not splitted (since that they have only one
   // incoming value from extracted region).
   EXPECT_TRUE(Exit1 &&
-          cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
+              cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
   EXPECT_TRUE(Exit2 &&
-          cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
+              cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
   EXPECT_FALSE(verifyFunction(*Outlined));
   EXPECT_FALSE(verifyFunction(*Func));
 }
@@ -410,9 +410,10 @@ TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
         unreachable
     }
-  )invalid", Err, Ctx));
+  )invalid",
+                                                Err, Ctx));
 
-	if (!M) {
+  if (!M) {
     Err.print("unit", errs());
     exit(1);
   }
@@ -421,11 +422,8 @@ TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
   EXPECT_FALSE(verifyFunction(*Func, &errs()));
 
   SmallVector<BasicBlock *, 2> ExtractedBlocks{
-    getBlockByName(Func, "catch"),
-    getBlockByName(Func, "invoke.cont2"),
-    getBlockByName(Func, "invoke.cont3"),
-    getBlockByName(Func, "lpad2")
-  };
+      getBlockByName(Func, "catch"), getBlockByName(Func, "invoke.cont2"),
+      getBlockByName(Func, "invoke.cont3"), getBlockByName(Func, "lpad2")};
 
   CodeExtractor CE(ExtractedBlocks);
   EXPECT_TRUE(CE.isEligible());
@@ -459,8 +457,8 @@ TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
                                                 Err, Ctx));
 
   Function *Func = M->getFunction("foo");
-  SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
-                                       getBlockByName(Func, "lpad") };
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "entry"),
+                                      getBlockByName(Func, "lpad")};
 
   CodeExtractor CE(Blocks);
   EXPECT_TRUE(CE.isEligible());
@@ -512,7 +510,7 @@ TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
 
   assert(M && "Could not parse module?");
   Function *Func = M->getFunction("test");
-  SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
+  SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "if.else")};
   AssumptionCache AC(*Func);
   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
   EXPECT_TRUE(CE.isEligible());



More information about the llvm-commits mailing list