[clang] [clang][CodeGen] Fix templated constructors in base classes introduce bugs. (PR #87332)

via cfe-commits cfe-commits at lists.llvm.org
Tue Apr 2 03:55:45 PDT 2024


https://github.com/idler66 created https://github.com/llvm/llvm-project/pull/87332

For example, struct base { public : base() {} template <typename T> base(T x) {} }; struct derived : public base { public: derived() {} derived(derived& that): base(that) {} }; int main() { derived d1; derived d2 = d1; return 0;}

The copy constructor of base is not chosen because it is not an exact match for an argument of type derived. The templated constructor base(T x) accepts an argument of any type T, so it is selected instead, with T = derived. In the line derived(derived& that): base(that), the object that is therefore copied twice — once during the initialization of the derived class and again when it is passed by value to the base class constructor. The copy-initialization derived d2 = d1 thus recurses endlessly through base(that) and eventually overflows the stack.

This repeated execution of the copy is what leads to the stack overflow. So, for the templated constructor base(T x),
if T is a class derived from base, the argument should be passed by reference.

>From e36764032dd624f26ab44a7bc3f1126eb98617ae Mon Sep 17 00:00:00 2001
From: wangjufan <wangjufan at gmail.com>
Date: Sun, 31 Mar 2024 23:49:30 +0800
Subject: [PATCH] [clang][CodeGen] Fix templated constructors in base classes
 introduce bugs.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

For example, struct base { public : base() {} template <typename T> base(T x) {} };
struct derived : public base { public: derived() {} derived(derived& that): base(that) {} };
int main() { derived d1; derived d2 = d1; return 0;}

The copy constructor of base is not chosen because it is not an exact match for an argument of type derived.
The templated constructor base(T x) accepts an argument of any type T, so it is selected instead, with T = derived.
In the line derived(derived& that): base(that), the object that is therefore copied twice — once during the initialization of the derived class and again when it is passed by value to the base class constructor.
The copy-initialization derived d2 = d1 thus recurses endlessly through base(that) and eventually overflows the stack.

This repeated execution of the copy is what leads to the stack overflow.
So, for the templated constructor base(T x),
if T is a class derived from base, the argument should be passed by reference.
---
 clang/lib/CodeGen/CGCall.cpp                  |  24 +-
 clang/lib/CodeGen/CodeGenFunction.h           |   2 +-
 clang/unittests/CodeGen/CMakeLists.txt        |   1 +
 .../CodeGen/TemplateInstantiationTest.cpp     | 214 ++++++++++++++++++
 4 files changed, 238 insertions(+), 3 deletions(-)
 create mode 100644 clang/unittests/CodeGen/TemplateInstantiationTest.cpp

diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index 9308528ac93823..76f70bf6b8a8e2 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4573,7 +4573,7 @@ void CodeGenFunction::EmitCallArgs(
             (isa<ObjCMethodDecl>(AC.getDecl()) &&
              isObjCMethodWithTypeParams(cast<ObjCMethodDecl>(AC.getDecl())))) &&
            "Argument and parameter types don't match");
-    EmitCallArg(Args, *Arg, ArgTypes[Idx]);
+    EmitCallArg(Args, *Arg, ArgTypes[Idx], AC);
     // In particular, we depend on it being the last arg in Args, and the
     // objectsize bits depend on there only being one arg if !LeftToRight.
     assert(InitialArgSize + 1 == Args.size() &&
@@ -4664,7 +4664,7 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const {
 }
 
 void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
-                                  QualType type) {
+                                  QualType type, const AbstractCallee& AC) {
   DisableDebugLocationUpdates Dis(*this, E);
   if (const ObjCIndirectCopyRestoreExpr *CRE
         = dyn_cast<ObjCIndirectCopyRestoreExpr>(E)) {
@@ -4680,6 +4680,26 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
     return args.add(EmitReferenceBindingToExpr(E), type);
   }
 
+  auto ShouldPassParametersByReferenceToTemplatedConstructors = [&]() { // true when a one-parameter constructor takes an argument whose class derives from the constructor's own class
+    if(1 != AC.getNumParams()) return false; // only callees with exactly one parameter qualify
+    if (const CXXRecordDecl* SubRecordDecl = type->getAsCXXRecordDecl()) { // 'type' is the QualType handed to EmitCallArg
+      if (const CXXConstructorDecl* ConstructorDecl = dyn_cast<clang::CXXConstructorDecl>(AC.getDecl())) { // NOTE(review): presumably getDecl() can be null (dyn_cast would assert) — confirm
+        if(const CXXRecordDecl* BaseRecordDecl = dyn_cast<CXXRecordDecl>(ConstructorDecl->getParent())) { // class that declares the constructor
+          if(SubRecordDecl->isDerivedFrom(BaseRecordDecl)) {
+            return true; // NOTE(review): despite the lambda's name, nothing above checks that the constructor is a template
+          }
+        }
+      }
+    }
+    return false;
+  };
+  if(ShouldPassParametersByReferenceToTemplatedConstructors()) {
+    AggValueSlot Slot = args.isUsingInAlloca() // NOTE(review): E is not evaluated anywhere on this path — confirm that is intended
+        ? createPlaceholderSlot(*this, type) : CreateAggTemp(type, "agg.tmp");
+    RValue RV = Slot.asRValue();
+    return args.add(RV, type); // the slot is added as the argument exactly as created above
+  }
+
   bool HasAggregateEvalKind = hasAggregateEvaluationKind(type);
 
   // In the Microsoft C++ ABI, aggregate arguments are destructed by the callee.
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index e2a7e28c8211ea..cf713d6fc07dd0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4958,7 +4958,7 @@ class CodeGenFunction : public CodeGenTypeCache {
                            unsigned ParmNum);
 
   /// EmitCallArg - Emit a single call argument.
-  void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType);
+  void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType, const AbstractCallee& AC);
 
   /// EmitDelegateCallArg - We are performing a delegate call; that
   /// is, the current function is delegating to another one.  Produce
diff --git a/clang/unittests/CodeGen/CMakeLists.txt b/clang/unittests/CodeGen/CMakeLists.txt
index a437f441568f27..8870237c85539d 100644
--- a/clang/unittests/CodeGen/CMakeLists.txt
+++ b/clang/unittests/CodeGen/CMakeLists.txt
@@ -9,6 +9,7 @@ add_clang_unittest(ClangCodeGenTests
   CodeGenExternalTest.cpp
   TBAAMetadataTest.cpp
   CheckTargetFeaturesTest.cpp
+  TemplateInstantiationTest.cpp
   )
 
 clang_target_link_libraries(ClangCodeGenTests
diff --git a/clang/unittests/CodeGen/TemplateInstantiationTest.cpp b/clang/unittests/CodeGen/TemplateInstantiationTest.cpp
new file mode 100644
index 00000000000000..08d8f673bb1d43
--- /dev/null
+++ b/clang/unittests/CodeGen/TemplateInstantiationTest.cpp
@@ -0,0 +1,214 @@
+//===- unittests/CodeGen/TemplateInstantiationTest.cpp - template instantiation test -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestCompiler.h"
+
+#include "clang/AST/ASTConsumer.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/GlobalDecl.h"
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/Basic/TargetInfo.h"
+#include "clang/CodeGen/CodeGenABITypes.h"
+#include "clang/CodeGen/ModuleBuilder.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Lex/Preprocessor.h"
+#include "clang/Parse/ParseAST.h"
+#include "clang/Sema/Sema.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/TargetParser/Host.h"
+#include "llvm/TargetParser/Triple.h"
+#include "gtest/gtest.h"
+
+#include "llvm/Analysis/CallGraph.h"
+#include <unordered_set>
+
+using namespace llvm;
+using namespace clang;
+
+namespace {
+
+// forward declarations
+struct TemplateInstantiationASTConsumer;
+static void test_instantiation_fns(TemplateInstantiationASTConsumer *my); // check hook, called from HandleTranslationUnit
+static bool test_instantiation_fns_ran; // set by the hook: true once it ran, reset to false if it found a call-graph cycle
+
+// This forwards the calls to the Clang CodeGenerator
+// so that we can test CodeGen functions while it is open.
+// It accumulates toplevel decls in HandleTopLevelDecl and
+// calls test_instantiation_fns() in HandleTranslationUnit
+// after forwarding that function to the CodeGenerator.
+
+struct TemplateInstantiationASTConsumer : public ASTConsumer {
+  std::unique_ptr<CodeGenerator> Builder; // wrapped generator; each override forwards its call to it
+  std::vector<Decl*> toplevel_decls; // every Decl passed to HandleTopLevelDecl, in arrival order
+
+  TemplateInstantiationASTConsumer(std::unique_ptr<CodeGenerator> Builder_in)
+    : ASTConsumer(), Builder(std::move(Builder_in)) // takes ownership of the generator
+  {
+  }
+
+  ~TemplateInstantiationASTConsumer() { }
+
+  void Initialize(ASTContext &Context) override;
+  void HandleCXXStaticMemberVarInstantiation(VarDecl *VD) override;
+  bool HandleTopLevelDecl(DeclGroupRef D) override; // also records the decls in toplevel_decls
+  void HandleInlineFunctionDefinition(FunctionDecl *D) override;
+  void HandleInterestingDecl(DeclGroupRef D) override;
+  void HandleTranslationUnit(ASTContext &Ctx) override; // also runs test_instantiation_fns after forwarding
+  void HandleTagDeclDefinition(TagDecl *D) override;
+  void HandleTagDeclRequiredDefinition(const TagDecl *D) override;
+  void HandleCXXImplicitFunctionInstantiation(FunctionDecl *D) override;
+  void HandleTopLevelDeclInObjCContainer(DeclGroupRef D) override;
+  void HandleImplicitImportDecl(ImportDecl *D) override;
+  void CompleteTentativeDefinition(VarDecl *D) override;
+  void AssignInheritanceModel(CXXRecordDecl *RD) override;
+  void HandleVTable(CXXRecordDecl *RD) override;
+  ASTMutationListener *GetASTMutationListener() override;
+  ASTDeserializationListener *GetASTDeserializationListener() override;
+  void PrintStats() override;
+  bool shouldSkipFunctionBody(Decl *D) override;
+};
+
+void TemplateInstantiationASTConsumer::Initialize(ASTContext &Context) {
+  Builder->Initialize(Context); // pass-through to the wrapped CodeGenerator
+}
+
+bool TemplateInstantiationASTConsumer::HandleTopLevelDecl(DeclGroupRef DG) {
+  // Record every declaration of the group, then forward the whole group.
+  for (DeclGroupRef::iterator I = DG.begin(), E = DG.end(); I != E; ++I) {
+    toplevel_decls.push_back(*I);
+  }
+
+  return Builder->HandleTopLevelDecl(DG); // result comes from the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleInlineFunctionDefinition(FunctionDecl *D) {
+  Builder->HandleInlineFunctionDefinition(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleInterestingDecl(DeclGroupRef D) {
+  Builder->HandleInterestingDecl(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleTranslationUnit(ASTContext &Context) {
+  // Forward first (HandleTranslationUnit can close the module), then run the check hook.
+  Builder->HandleTranslationUnit(Context);
+  test_instantiation_fns(this); // the hook reads the module back through Builder
+}
+
+void TemplateInstantiationASTConsumer::HandleTagDeclDefinition(TagDecl *D) {
+  Builder->HandleTagDeclDefinition(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleTagDeclRequiredDefinition(const TagDecl *D) {
+  Builder->HandleTagDeclRequiredDefinition(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleCXXImplicitFunctionInstantiation(FunctionDecl *D) {
+  Builder->HandleCXXImplicitFunctionInstantiation(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleTopLevelDeclInObjCContainer(DeclGroupRef D) {
+  Builder->HandleTopLevelDeclInObjCContainer(D); // pass-through; unlike HandleTopLevelDecl, nothing is recorded here
+}
+
+void TemplateInstantiationASTConsumer::HandleImplicitImportDecl(ImportDecl *D) {
+  Builder->HandleImplicitImportDecl(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::CompleteTentativeDefinition(VarDecl *D) {
+  Builder->CompleteTentativeDefinition(D); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::AssignInheritanceModel(CXXRecordDecl *RD) {
+  Builder->AssignInheritanceModel(RD); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleCXXStaticMemberVarInstantiation(VarDecl *VD) {
+   Builder->HandleCXXStaticMemberVarInstantiation(VD); // pass-through to the wrapped CodeGenerator
+}
+
+void TemplateInstantiationASTConsumer::HandleVTable(CXXRecordDecl *RD) {
+   Builder->HandleVTable(RD); // pass-through to the wrapped CodeGenerator
+ }
+
+ASTMutationListener *TemplateInstantiationASTConsumer::GetASTMutationListener() {
+  return Builder->GetASTMutationListener(); // exposes the wrapped CodeGenerator's listener unchanged
+}
+
+ASTDeserializationListener *TemplateInstantiationASTConsumer::GetASTDeserializationListener() {
+  return Builder->GetASTDeserializationListener(); // exposes the wrapped CodeGenerator's listener unchanged
+}
+
+void TemplateInstantiationASTConsumer::PrintStats() {
+  Builder->PrintStats(); // pass-through to the wrapped CodeGenerator
+}
+
+bool TemplateInstantiationASTConsumer::shouldSkipFunctionBody(Decl *D) {
+  return Builder->shouldSkipFunctionBody(D); // the decision is left entirely to the wrapped CodeGenerator
+}
+
+const char TestProgram[] = "struct base { public : base() {} template <typename T> base(T x) {} }; struct derived : public base { public: derived() {} derived(derived& that): base(that) {} }; int main() { derived d1; derived d2 = d1; return 0;}"; // input compiled by the TEST: derived's copy constructor hands 'that' to base's by-value templated constructor
+
+bool hasCycles(const Function *CurrentFunction, // depth-first step: returns true if a call cycle is reachable from this function
+               std::unordered_set<const Function *> &VisitedFunctions, // every function entered so far; never shrinks
+               std::unordered_set<const Function *> &RecursionStack, // functions on the path currently being explored
+               const CallGraphNode* CurrentNode) { // node whose call edges are walked; callers pass CurrentFunction's node
+  VisitedFunctions.insert(CurrentFunction);
+  RecursionStack.insert(CurrentFunction);
+  for (CallGraphNode::const_iterator IT = CurrentNode->begin(), END = CurrentNode->end(); IT != END; ++IT) {
+    if (const Function *CalleeFunction = IT->second->getFunction()) { // edges whose callee node has no Function are skipped
+      if (RecursionStack.count(CalleeFunction)) { // callee is already on the current path, so the calls form a cycle
+        return true;
+      }
+      if (VisitedFunctions.count(CalleeFunction) == 0 && hasCycles(CalleeFunction, VisitedFunctions, RecursionStack, IT->second)) {
+        return true; // note: on the 'true' paths RecursionStack is left as it is, not unwound
+      }
+    }
+  }
+  RecursionStack.erase(CurrentFunction); // finished with this function: it leaves the current path
+  return false;
+}
+
+static void test_instantiation_fns(TemplateInstantiationASTConsumer *InstantiationASTConsumer) {
+  test_instantiation_fns_ran = true; // marks that the hook ran; cleared again below if a cycle is found
+  llvm::Module* Mdl = InstantiationASTConsumer->Builder->GetModule();
+  CallGraph Graph(*Mdl); // call graph built from the module the CodeGenerator returned
+  std::unordered_set<const Function *> VisitedFunctions; // shared by all searches so no function is explored twice
+  std::unordered_set<const Function *> RecursionStack;
+  for (llvm::CallGraph::const_iterator IT = Graph.begin(), END = Graph.end();
+       IT != END; ++IT) {
+    const Function* Fnc = IT->first;
+    const CallGraphNode* GraphNode = IT->second.get();
+    if (Fnc && VisitedFunctions.count(Fnc) == 0){ // skip entries without a Function and functions already explored
+      if(hasCycles(Fnc, VisitedFunctions, RecursionStack, GraphNode)) {
+        test_instantiation_fns_ran = false; // the flag doubles as the pass/fail result the TEST asserts on
+        break;
+      }
+    }
+  }
+}
+ 
+TEST(TemplatedConstructorTemplateInstantiationTest, TemplatedConstructorTemplateInstantiationTest) {
+  clang::LangOptions LO;
+  LO.CPlusPlus = 1; // TestProgram is C++ source
+  TestCompiler Compiler(LO);
+  auto CustomASTConsumer
+    = std::make_unique<TemplateInstantiationASTConsumer>(std::move(Compiler.CG)); // the consumer takes over Compiler.CG
+
+  Compiler.init(TestProgram, std::move(CustomASTConsumer));
+  ParseAST(Compiler.compiler.getSema(), false, false); // presumably drives the consumer callbacks, incl. HandleTranslationUnit — confirm
+
+  ASSERT_TRUE(test_instantiation_fns_ran); // true only if the check hook ran and found no call-graph cycle
+}
+
+} // end anonymous namespace
+



More information about the cfe-commits mailing list