[clang] [CIR] Add musttail thunks and covariant return null-check (PR #191255)

via cfe-commits cfe-commits at lists.llvm.org
Tue Apr 14 10:29:53 PDT 2026


https://github.com/adams381 updated https://github.com/llvm/llvm-project/pull/191255

>From 9b41803ccf1e4be5f6b0968395e33864c3b8493e Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Tue, 7 Apr 2026 14:03:45 -0700
Subject: [PATCH 1/3] [CIR] Null-check pointer returns in covariant thunk
 adjustment

Match classic PerformReturnAdjustment for pointer (non-reference) returns:
if the callee returns null, skip the return adjustment and return null.

Add CovariantReturn coverage to clang/test/CIR/CodeGen/thunks.cpp
(CIR ternary + LLVM/OGCG phi).
---
 clang/lib/CIR/CodeGen/CIRGenVTables.cpp | 39 +++++++++++++++++--------
 clang/test/CIR/CodeGen/thunks.cpp       | 29 ++++++++++++++++++
 2 files changed, 56 insertions(+), 12 deletions(-)

diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
index 56839ca03dbb1..63b81aba8475d 100644
--- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
@@ -559,26 +559,41 @@ uint64_t CIRGenVTables::getSecondaryVirtualPointerIndex(const CXXRecordDecl *rd,
 
 static RValue performReturnAdjustment(CIRGenFunction &cgf, QualType resultType,
                                       RValue rv, const ThunkInfo &thunk) {
-  // Emit the return adjustment.
+  // Emit the return adjustment.  For non-reference pointer returns, match
+  // classic codegen: skip the adjustment when the returned pointer is null.
   bool nullCheckValue = !resultType->isReferenceType();
-
   mlir::Value returnValue = rv.getValue();
 
-  if (nullCheckValue)
-    cgf.cgm.errorNYI(
-        "return adjustment with null check for non-reference types");
-
   const CXXRecordDecl *classDecl =
       resultType->getPointeeType()->getAsCXXRecordDecl();
   CharUnits classAlign = cgf.cgm.getClassPointerAlignment(classDecl);
   mlir::Type pointeeType = cgf.convertTypeForMem(resultType->getPointeeType());
-  returnValue = cgf.cgm.getCXXABI().performReturnAdjustment(
-      cgf, Address(returnValue, pointeeType, classAlign), classDecl,
-      thunk.Return);
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+  mlir::Location loc = returnValue.getLoc();
+
+  if (!nullCheckValue) {
+    returnValue = cgf.cgm.getCXXABI().performReturnAdjustment(
+        cgf, Address(returnValue, pointeeType, classAlign), classDecl,
+        thunk.Return);
+    return RValue::get(returnValue);
+  }
 
-  if (nullCheckValue)
-    cgf.cgm.errorNYI(
-        "return adjustment with null check for non-reference types");
+  mlir::Value isNotNull = builder.createPtrIsNotNull(returnValue);
+  returnValue =
+      cir::TernaryOp::create(
+          builder, loc, isNotNull,
+          [&](mlir::OpBuilder &, mlir::Location) {
+            mlir::Value adjusted = cgf.cgm.getCXXABI().performReturnAdjustment(
+                cgf, Address(returnValue, pointeeType, classAlign), classDecl,
+                thunk.Return);
+            builder.createYield(loc, adjusted);
+          },
+          [&](mlir::OpBuilder &, mlir::Location) {
+            mlir::Value nullVal =
+                builder.getNullPtr(returnValue.getType(), loc).getResult();
+            builder.createYield(loc, nullVal);
+          })
+          .getResult();
 
   return RValue::get(returnValue);
 }
diff --git a/clang/test/CIR/CodeGen/thunks.cpp b/clang/test/CIR/CodeGen/thunks.cpp
index 15c4810738420..23e4bc8026c15 100644
--- a/clang/test/CIR/CodeGen/thunks.cpp
+++ b/clang/test/CIR/CodeGen/thunks.cpp
@@ -91,6 +91,21 @@ void C::g(int x) {}
 
 } // namespace Test4
 
+namespace CovariantReturn {
+// Covariant return with virtual inheritance: return-adjusting thunks use a
+// null check for pointer returns (classic PerformReturnAdjustment).
+struct A {
+  virtual A *f();
+};
+struct B : virtual A {
+  virtual A *f();
+};
+struct C : B {
+  virtual C *f();
+};
+C *C::f() { return 0; }
+} // namespace CovariantReturn
+
 // In CIR, all globals are emitted before functions.
 
 // Test1 vtable: C's vtable references the thunk for B's entry.
@@ -183,6 +198,12 @@ void C::g(int x) {}
 // CIR:   cir.call @_ZN5Test41C1gEi(%[[T4_RESULT]], %[[T4_ARG]])
 // CIR:   cir.return
 
+// --- CovariantReturn: return adjustment with null check on pointer return ---
+
+// CIR-LABEL: cir.func {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// CIR:       cir.call @_ZN15CovariantReturn1C1fEv
+// CIR:       cir.ternary
+
 // --- LLVM checks ---
 
 // LLVM: @_ZTVN5Test11CE = global { [3 x ptr], [3 x ptr] } {
@@ -231,6 +252,10 @@ void C::g(int x) {}
 // LLVM:   %[[L4_ARG:.*]] = load i32, ptr
 // LLVM:   call void @_ZN5Test41C1gEi(ptr{{.*}} %[[L4_ADJ]], i32{{.*}} %[[L4_ARG]])
 
+// LLVM-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// LLVM:       call {{.*}} @_ZN15CovariantReturn1C1fEv
+// LLVM:       phi ptr
+
 // --- OGCG checks ---
 
 // OGCG: @_ZTVN5Test11CE = unnamed_addr constant { [3 x ptr], [3 x ptr] } {
@@ -278,3 +303,7 @@ void C::g(int x) {}
 // OGCG:   %[[O4_ADJ:.*]] = getelementptr inbounds i8, ptr %[[O4_THIS]], i64 -8
 // OGCG:   %[[O4_ARG:.*]] = load i32, ptr
 // OGCG:   {{.*}}call void @_ZN5Test41C1gEi(ptr{{.*}} %[[O4_ADJ]], i32{{.*}} %[[O4_ARG]])
+
+// OGCG-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// OGCG:       {{.*}}call {{.*}} @_ZN15CovariantReturn1C1fEv
+// OGCG:       phi ptr

>From 170de5ffd779772a9e99ac8bfbd3306ca8fef7ab Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Tue, 7 Apr 2026 15:12:21 -0700
Subject: [PATCH 2/3] [CIR] Add musttail attribute and implement variadic thunk
 emission

Add musttail UnitAttr to cir.call/cir.try_call, parsed and printed
as a keyword.  Lower to LLVM::TailCallKind::MustTail so varargs are
forwarded correctly through '...'.

Implement emitMustTailThunk: forward the thunk's entry-block arguments,
with 'this' replaced by the adjusted pointer, via a musttail call.

Add a VarargThunk test to thunks.cpp (the CovariantReturn test and the
return-adjustment null check were added in the previous patch).
---
 .../clang/CIR/Dialect/IR/CIRDialect.td        |  1 +
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |  1 +
 clang/lib/CIR/CodeGen/CIRGenVTables.cpp       | 30 ++++++++++++++--
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       |  8 +++++
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp |  5 ++-
 clang/test/CIR/CodeGen/thunks.cpp             | 34 +++++++++++++++++++
 6 files changed, 76 insertions(+), 3 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index b57f874c34393..5b808ea92f470 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -77,6 +77,7 @@ def CIR_Dialect : Dialect {
     static llvm::StringRef getArgAttrsAttrName() { return "arg_attrs"; }
     static llvm::StringRef getRecordLayoutsAttrName() { return "cir.record_layouts"; }
     static llvm::StringRef getCUDABinaryHandleAttrName() { return "cir.cu.binary_handle"; }
+    static llvm::StringRef getMustTailAttrName() { return "musttail"; }
 
     static llvm::StringRef getAMDGPUCodeObjectVersionAttrName() { return "cir.amdhsa_code_object_version"; }
     static llvm::StringRef getAMDGPUPrintfKindAttrName() { return "cir.amdgpu_printf_kind"; }
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index e10b9102dae78..6f8db65acccc9 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -3921,6 +3921,7 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
   dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
       Variadic<CIR_AnyType>:$args,
       UnitAttr:$nothrow,
+      UnitAttr:$musttail,
       DefaultValuedAttr<CIR_SideEffect, "SideEffect::All">:$side_effect,
       OptionalAttr<DictArrayAttr>:$arg_attrs,
       OptionalAttr<DictArrayAttr>:$res_attrs
diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
index 63b81aba8475d..936e9b6487514 100644
--- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
@@ -758,8 +758,34 @@ void CIRGenFunction::emitCallAndReturnForThunk(cir::FuncOp callee,
 void CIRGenFunction::emitMustTailThunk(GlobalDecl gd,
                                        mlir::Value adjustedThisPtr,
                                        cir::FuncOp callee) {
-  assert(!cir::MissingFeatures::opCallMustTail());
-  cgm.errorNYI("musttail thunk");
+  // Forward all function arguments, replacing 'this' with the adjusted pointer.
+  // The call is marked musttail so varargs are forwarded correctly.
+  auto thunkFn = cast<cir::FuncOp>(curFn);
+  mlir::Block &entryBlock = thunkFn.getBody().front();
+  SmallVector<mlir::Value> args;
+  for (mlir::BlockArgument arg : entryBlock.getArguments())
+    args.push_back(arg);
+
+  // Replace the 'this' argument (first arg) with the adjusted pointer.
+  assert(!args.empty() && "thunk must have at least 'this' argument");
+  if (adjustedThisPtr.getType() != args[0].getType())
+    adjustedThisPtr = builder.createBitcast(adjustedThisPtr, args[0].getType());
+  args[0] = adjustedThisPtr;
+
+  mlir::Location loc = thunkFn.getLoc();
+  cir::FuncType calleeTy = callee.getFunctionType();
+  mlir::Type retTy = calleeTy.getReturnType();
+
+  cir::CallOp call = builder.createCallOp(loc, callee, args);
+  call->setAttr(cir::CIRDialect::getMustTailAttrName(),
+                mlir::UnitAttr::get(builder.getContext()));
+
+  if (isa<cir::VoidType>(retTy))
+    cir::ReturnOp::create(builder, loc);
+  else
+    cir::ReturnOp::create(builder, loc, call->getResult(0));
+
+  finishThunk();
 }
 
 void CIRGenFunction::generateThunk(cir::FuncOp fn,
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index d685a14ef263a..4514b04780746 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -918,6 +918,10 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
     return ::mlir::failure();
   }
 
+  if (parser.parseOptionalKeyword("musttail").succeeded())
+    result.addAttribute(CIRDialect::getMustTailAttrName(),
+                        mlir::UnitAttr::get(parser.getContext()));
+
   if (parser.parseOptionalKeyword("nothrow").succeeded())
     result.addAttribute(CIRDialect::getNoThrowAttrName(),
                         mlir::UnitAttr::get(parser.getContext()));
@@ -1020,6 +1024,9 @@ printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
     printer << tryCall.getUnwindDest();
   }
 
+  if (op->hasAttr(CIRDialect::getMustTailAttrName()))
+    printer << " musttail";
+
   if (isNothrow)
     printer << " nothrow";
 
@@ -1031,6 +1038,7 @@ printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
 
   llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
       CIRDialect::getCalleeAttrName(),
+      CIRDialect::getMustTailAttrName(),
       CIRDialect::getNoThrowAttrName(),
       CIRDialect::getSideEffectAttrName(),
       CIRDialect::getOperandSegmentSizesAttrName(),
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 032971410c64b..b7fd20715287a 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1663,7 +1663,8 @@ static void lowerCallAttributes(cir::CIRCallOpInterface op,
         attr.getName() == CIRDialect::getSideEffectAttrName() ||
         attr.getName() == CIRDialect::getNoThrowAttrName() ||
         attr.getName() == CIRDialect::getNoUnwindAttrName() ||
-        attr.getName() == CIRDialect::getNoReturnAttrName())
+        attr.getName() == CIRDialect::getNoReturnAttrName() ||
+        attr.getName() == CIRDialect::getMustTailAttrName())
       continue;
 
     assert(!cir::MissingFeatures::opFuncExtraAttrs());
@@ -1764,6 +1765,8 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
     newOp.setNoUnwind(noUnwind);
     newOp.setWillReturn(willReturn);
     newOp.setNoreturn(noReturn);
+    if (op->hasAttr(CIRDialect::getMustTailAttrName()))
+      newOp.setTailCallKind(mlir::LLVM::TailCallKind::MustTail);
   }
 
   return mlir::success();
diff --git a/clang/test/CIR/CodeGen/thunks.cpp b/clang/test/CIR/CodeGen/thunks.cpp
index 23e4bc8026c15..b36e8a9805516 100644
--- a/clang/test/CIR/CodeGen/thunks.cpp
+++ b/clang/test/CIR/CodeGen/thunks.cpp
@@ -106,6 +106,21 @@ struct C : B {
 C *C::f() { return 0; }
 } // namespace CovariantReturn
 
+namespace VarargThunk {
+// Variadic this-adjusting thunk.  On x86_64, both classic codegen and CIR
+// forward the arguments (including '...') to the target via musttail.
+struct A {
+  virtual void f(int x, ...);
+};
+struct B {
+  virtual void f(int x, ...);
+};
+struct C : A, B {
+  void f(int x, ...) override;
+};
+void C::f(int x, ...) {}
+} // namespace VarargThunk
+
 // In CIR, all globals are emitted before functions.
 
 // Test1 vtable: C's vtable references the thunk for B's entry.
@@ -204,6 +219,17 @@ C *C::f() { return 0; }
 // CIR:       cir.call @_ZN15CovariantReturn1C1fEv
 // CIR:       cir.ternary
 
+// --- VarargThunk: variadic this-adjusting thunk ---
+
+// CIR: cir.func {{.*}} @_ZThn8_N11VarargThunk1C1fEiz(%arg0: !cir.ptr<
+// CIR:   %[[VT_THIS:.*]] = cir.load
+// CIR:   %[[VT_CAST:.*]] = cir.cast bitcast %[[VT_THIS]] : !cir.ptr<{{.*}}> -> !cir.ptr<!u8i>
+// CIR:   %[[VT_OFFSET:.*]] = cir.const #cir.int<-8> : !s64i
+// CIR:   %[[VT_ADJUSTED:.*]] = cir.ptr_stride %[[VT_CAST]], %[[VT_OFFSET]] : (!cir.ptr<!u8i>, !s64i) -> !cir.ptr<!u8i>
+// CIR:   %[[VT_RESULT:.*]] = cir.cast bitcast %[[VT_ADJUSTED]] : !cir.ptr<!u8i> -> !cir.ptr<
+// CIR:   cir.call @_ZN11VarargThunk1C1fEiz(%[[VT_RESULT]], %arg1) musttail
+// CIR:   cir.return
+
 // --- LLVM checks ---
 
 // LLVM: @_ZTVN5Test11CE = global { [3 x ptr], [3 x ptr] } {
@@ -256,6 +282,10 @@ C *C::f() { return 0; }
 // LLVM:       call {{.*}} @_ZN15CovariantReturn1C1fEv
 // LLVM:       phi ptr
 
+// LLVM-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+// LLVM:   getelementptr i8, ptr {{.*}}, i64 -8
+// LLVM:   musttail call void (ptr, i32, ...) @_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+
 // --- OGCG checks ---
 
 // OGCG: @_ZTVN5Test11CE = unnamed_addr constant { [3 x ptr], [3 x ptr] } {
@@ -307,3 +337,7 @@ C *C::f() { return 0; }
 // OGCG-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
 // OGCG:       {{.*}}call {{.*}} @_ZN15CovariantReturn1C1fEv
 // OGCG:       phi ptr
+
+// OGCG-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+// OGCG:   getelementptr inbounds i8, ptr {{.*}}, i64 -8
+// OGCG:   musttail call void (ptr, i32, ...) @_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)

>From b9966d00597001c3d5005bda4ee8e00f23966ed8 Mon Sep 17 00:00:00 2001
From: Adam Smith <adams at nvidia.com>
Date: Mon, 13 Apr 2026 13:08:58 -0700
Subject: [PATCH 3/3] [CIR] Address review nits in emitMustTailThunk

Use getCurFunctionEntryBlock() and curFn->getLoc() instead of
casting curFn to FuncOp and accessing body/loc directly.

Made-with: Cursor
---
 clang/lib/CIR/CodeGen/CIRGenVTables.cpp | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
index 936e9b6487514..6e1a80926f679 100644
--- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
@@ -760,10 +760,9 @@ void CIRGenFunction::emitMustTailThunk(GlobalDecl gd,
                                        cir::FuncOp callee) {
   // Forward all function arguments, replacing 'this' with the adjusted pointer.
   // The call is marked musttail so varargs are forwarded correctly.
-  auto thunkFn = cast<cir::FuncOp>(curFn);
-  mlir::Block &entryBlock = thunkFn.getBody().front();
+  mlir::Block *entryBlock = getCurFunctionEntryBlock();
   SmallVector<mlir::Value> args;
-  for (mlir::BlockArgument arg : entryBlock.getArguments())
+  for (mlir::BlockArgument arg : entryBlock->getArguments())
     args.push_back(arg);
 
   // Replace the 'this' argument (first arg) with the adjusted pointer.
@@ -772,7 +771,7 @@ void CIRGenFunction::emitMustTailThunk(GlobalDecl gd,
     adjustedThisPtr = builder.createBitcast(adjustedThisPtr, args[0].getType());
   args[0] = adjustedThisPtr;
 
-  mlir::Location loc = thunkFn.getLoc();
+  mlir::Location loc = curFn->getLoc();
   cir::FuncType calleeTy = callee.getFunctionType();
   mlir::Type retTy = calleeTy.getReturnType();
 



More information about the cfe-commits mailing list