[clang] a36e9d1 - [CIR] Add musttail thunks and covariant return null-check (#191255)
via cfe-commits
cfe-commits at lists.llvm.org
Tue Apr 14 12:01:05 PDT 2026
Author: adams381
Date: 2026-04-14T14:00:59-05:00
New Revision: a36e9d1d57b12de3674689a617ab7452ed43d9a2
URL: https://github.com/llvm/llvm-project/commit/a36e9d1d57b12de3674689a617ab7452ed43d9a2
DIFF: https://github.com/llvm/llvm-project/commit/a36e9d1d57b12de3674689a617ab7452ed43d9a2.diff
LOG: [CIR] Add musttail thunks and covariant return null-check (#191255)
Implement variadic thunk emission via musttail and null-check
pointer returns in covariant thunk adjustment, matching classic
codegen behavior.
Adds a musttail UnitAttr to cir.call/cir.try_call, with lowering
to LLVM::TailCallKind::MustTail.
Made with [Cursor](https://cursor.com)
Added:
Modified:
clang/include/clang/CIR/Dialect/IR/CIRDialect.td
clang/include/clang/CIR/Dialect/IR/CIROps.td
clang/lib/CIR/CodeGen/CIRGenVTables.cpp
clang/lib/CIR/Dialect/IR/CIRDialect.cpp
clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
clang/test/CIR/CodeGen/thunks.cpp
Removed:
################################################################################
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index b57f874c34393..5b808ea92f470 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -77,6 +77,7 @@ def CIR_Dialect : Dialect {
static llvm::StringRef getArgAttrsAttrName() { return "arg_attrs"; }
static llvm::StringRef getRecordLayoutsAttrName() { return "cir.record_layouts"; }
static llvm::StringRef getCUDABinaryHandleAttrName() { return "cir.cu.binary_handle"; }
+ static llvm::StringRef getMustTailAttrName() { return "musttail"; }
static llvm::StringRef getAMDGPUCodeObjectVersionAttrName() { return "cir.amdhsa_code_object_version"; }
static llvm::StringRef getAMDGPUPrintfKindAttrName() { return "cir.amdgpu_printf_kind"; }
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index e10b9102dae78..6f8db65acccc9 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -3921,6 +3921,7 @@ class CIR_CallOpBase<string mnemonic, list<Trait> extra_traits = []>
dag commonArgs = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$args,
UnitAttr:$nothrow,
+ UnitAttr:$musttail,
DefaultValuedAttr<CIR_SideEffect, "SideEffect::All">:$side_effect,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
diff --git a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
index 56839ca03dbb1..6e1a80926f679 100644
--- a/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenVTables.cpp
@@ -559,26 +559,41 @@ uint64_t CIRGenVTables::getSecondaryVirtualPointerIndex(const CXXRecordDecl *rd,
+// Apply the thunk's return adjustment (thunk.Return) to the value returned by
+// the target function, using the C++ ABI's performReturnAdjustment hook.
static RValue performReturnAdjustment(CIRGenFunction &cgf, QualType resultType,
RValue rv, const ThunkInfo &thunk) {
- // Emit the return adjustment.
+ // Emit the return adjustment. For pointer (non-reference) returns, match
+ // classic codegen: skip the adjustment when the returned pointer is null.
bool nullCheckValue = !resultType->isReferenceType();
-
mlir::Value returnValue = rv.getValue();
- if (nullCheckValue)
- cgf.cgm.errorNYI(
- "return adjustment with null check for non-reference types");
-
const CXXRecordDecl *classDecl =
resultType->getPointeeType()->getAsCXXRecordDecl();
CharUnits classAlign = cgf.cgm.getClassPointerAlignment(classDecl);
mlir::Type pointeeType = cgf.convertTypeForMem(resultType->getPointeeType());
- returnValue = cgf.cgm.getCXXABI().performReturnAdjustment(
- cgf, Address(returnValue, pointeeType, classAlign), classDecl,
- thunk.Return);
+ CIRGenBuilderTy &builder = cgf.getBuilder();
+ mlir::Location loc = returnValue.getLoc();
+
+ // Reference returns are not null-checked: adjust unconditionally.
+ if (!nullCheckValue) {
+ returnValue = cgf.cgm.getCXXABI().performReturnAdjustment(
+ cgf, Address(returnValue, pointeeType, classAlign), classDecl,
+ thunk.Return);
+ return RValue::get(returnValue);
+ }
- if (nullCheckValue)
- cgf.cgm.errorNYI(
- "return adjustment with null check for non-reference types");
+ // Pointer returns: yield the adjusted pointer when it is non-null, and a
+ // null pointer of the returned pointer's type otherwise (the LLVM/OGCG
+ // checks expect this to become a phi). Both region builders run inside
+ // create(), before the assignment to returnValue, so they still see the
+ // unadjusted pointer.
+ mlir::Value isNotNull = builder.createPtrIsNotNull(returnValue);
+ returnValue =
+ cir::TernaryOp::create(
+ builder, loc, isNotNull,
+ [&](mlir::OpBuilder &, mlir::Location) {
+ mlir::Value adjusted = cgf.cgm.getCXXABI().performReturnAdjustment(
+ cgf, Address(returnValue, pointeeType, classAlign), classDecl,
+ thunk.Return);
+ builder.createYield(loc, adjusted);
+ },
+ [&](mlir::OpBuilder &, mlir::Location) {
+ mlir::Value nullVal =
+ builder.getNullPtr(returnValue.getType(), loc).getResult();
+ builder.createYield(loc, nullVal);
+ })
+ .getResult();
return RValue::get(returnValue);
}
@@ -743,8 +758,33 @@ void CIRGenFunction::emitCallAndReturnForThunk(cir::FuncOp callee,
+// Emit the body of a thunk that forwards its own incoming arguments to
+// `callee` through a single musttail call, substituting `adjustedThisPtr` for
+// 'this', and then returns the call's result (if any). `gd` is not used here.
+// NOTE(review): presumably selected for variadic thunks, whose '...'
+// arguments can only be forwarded this way -- confirm against the caller.
void CIRGenFunction::emitMustTailThunk(GlobalDecl gd,
mlir::Value adjustedThisPtr,
cir::FuncOp callee) {
- assert(!cir::MissingFeatures::opCallMustTail());
- cgm.errorNYI("musttail thunk");
+ // Forward all of the thunk's named arguments, replacing 'this' with the
+ // adjusted pointer. Variadic ('...') arguments are not entry-block arguments;
+ // marking the call musttail lets the lowered LLVM call forward them as well.
+ mlir::Block *entryBlock = getCurFunctionEntryBlock();
+ SmallVector<mlir::Value> args;
+ for (mlir::BlockArgument arg : entryBlock->getArguments())
+ args.push_back(arg);
+
+ // Replace the 'this' argument (first arg) with the adjusted pointer,
+ // bitcasting it first if its type differs from the parameter's type.
+ assert(!args.empty() && "thunk must have at least 'this' argument");
+ if (adjustedThisPtr.getType() != args[0].getType())
+ adjustedThisPtr = builder.createBitcast(adjustedThisPtr, args[0].getType());
+ args[0] = adjustedThisPtr;
+
+ mlir::Location loc = curFn->getLoc();
+ cir::FuncType calleeTy = callee.getFunctionType();
+ mlir::Type retTy = calleeTy.getReturnType();
+
+ cir::CallOp call = builder.createCallOp(loc, callee, args);
+ call->setAttr(cir::CIRDialect::getMustTailAttrName(),
+ mlir::UnitAttr::get(builder.getContext()));
+
+ // A musttail call must be immediately followed by a return of its result
+ // (or a void return), so emit the return right after the call.
+ if (isa<cir::VoidType>(retTy))
+ cir::ReturnOp::create(builder, loc);
+ else
+ cir::ReturnOp::create(builder, loc, call->getResult(0));
+
+ finishThunk();
}
void CIRGenFunction::generateThunk(cir::FuncOp fn,
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index d685a14ef263a..4514b04780746 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -918,6 +918,10 @@ static mlir::ParseResult parseCallCommon(mlir::OpAsmParser &parser,
return ::mlir::failure();
}
+ if (parser.parseOptionalKeyword("musttail").succeeded())
+ result.addAttribute(CIRDialect::getMustTailAttrName(),
+ mlir::UnitAttr::get(parser.getContext()));
+
if (parser.parseOptionalKeyword("nothrow").succeeded())
result.addAttribute(CIRDialect::getNoThrowAttrName(),
mlir::UnitAttr::get(parser.getContext()));
@@ -1020,6 +1024,9 @@ printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
printer << tryCall.getUnwindDest();
}
+ if (op->hasAttr(CIRDialect::getMustTailAttrName()))
+ printer << " musttail";
+
if (isNothrow)
printer << " nothrow";
@@ -1031,6 +1038,7 @@ printCallCommon(mlir::Operation *op, mlir::FlatSymbolRefAttr calleeSym,
llvm::SmallVector<::llvm::StringRef> elidedAttrs = {
CIRDialect::getCalleeAttrName(),
+ CIRDialect::getMustTailAttrName(),
CIRDialect::getNoThrowAttrName(),
CIRDialect::getSideEffectAttrName(),
CIRDialect::getOperandSegmentSizesAttrName(),
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 032971410c64b..b7fd20715287a 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1663,7 +1663,8 @@ static void lowerCallAttributes(cir::CIRCallOpInterface op,
attr.getName() == CIRDialect::getSideEffectAttrName() ||
attr.getName() == CIRDialect::getNoThrowAttrName() ||
attr.getName() == CIRDialect::getNoUnwindAttrName() ||
- attr.getName() == CIRDialect::getNoReturnAttrName())
+ attr.getName() == CIRDialect::getNoReturnAttrName() ||
+ attr.getName() == CIRDialect::getMustTailAttrName())
continue;
assert(!cir::MissingFeatures::opFuncExtraAttrs());
@@ -1764,6 +1765,8 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
newOp.setNoUnwind(noUnwind);
newOp.setWillReturn(willReturn);
newOp.setNoreturn(noReturn);
+ if (op->hasAttr(CIRDialect::getMustTailAttrName()))
+ newOp.setTailCallKind(mlir::LLVM::TailCallKind::MustTail);
}
return mlir::success();
diff --git a/clang/test/CIR/CodeGen/thunks.cpp b/clang/test/CIR/CodeGen/thunks.cpp
index 15c4810738420..b36e8a9805516 100644
--- a/clang/test/CIR/CodeGen/thunks.cpp
+++ b/clang/test/CIR/CodeGen/thunks.cpp
@@ -91,6 +91,36 @@ void C::g(int x) {}
} // namespace Test4
+namespace CovariantReturn {
+// Covariant return through a virtual base: C::f returns C* while overriding
+// A::f, which returns A*, and A is a virtual base of B (and hence of C). The
+// return-adjusting thunk must skip the adjustment when the returned pointer
+// is null, as classic PerformReturnAdjustment does.
+struct A {
+ virtual A *f();
+};
+struct B : virtual A {
+ virtual A *f();
+};
+struct C : B {
+ virtual C *f();
+};
+C *C::f() { return 0; }
+} // namespace CovariantReturn
+
+namespace VarargThunk {
+// Variadic this-adjusting thunk: B is C's second base, so the B-in-C vtable
+// slot for f needs a thunk that adjusts 'this' before calling C::f. Both
+// classic codegen and CIR emit that call as musttail so the variadic
+// arguments are forwarded unchanged.
+struct A {
+ virtual void f(int x, ...);
+};
+struct B {
+ virtual void f(int x, ...);
+};
+struct C : A, B {
+ void f(int x, ...) override;
+};
+void C::f(int x, ...) {}
+} // namespace VarargThunk
+
// In CIR, all globals are emitted before functions.
// Test1 vtable: C's vtable references the thunk for B's entry.
@@ -183,6 +213,23 @@ void C::g(int x) {}
// CIR: cir.call @_ZN5Test41C1gEi(%[[T4_RESULT]], %[[T4_ARG]])
// CIR: cir.return
+// --- CovariantReturn: return adjustment with null check on pointer return ---
+
+// CIR-LABEL: cir.func {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// CIR: cir.call @_ZN15CovariantReturn1C1fEv
+// CIR: cir.ternary
+
+// --- VarargThunk: variadic this-adjusting thunk ---
+
+// CIR: cir.func {{.*}} @_ZThn8_N11VarargThunk1C1fEiz(%arg0: !cir.ptr<
+// CIR: %[[VT_THIS:.*]] = cir.load
+// CIR: %[[VT_CAST:.*]] = cir.cast bitcast %[[VT_THIS]] : !cir.ptr<{{.*}}> -> !cir.ptr<!u8i>
+// CIR: %[[VT_OFFSET:.*]] = cir.const #cir.int<-8> : !s64i
+// CIR: %[[VT_ADJUSTED:.*]] = cir.ptr_stride %[[VT_CAST]], %[[VT_OFFSET]] : (!cir.ptr<!u8i>, !s64i) -> !cir.ptr<!u8i>
+// CIR: %[[VT_RESULT:.*]] = cir.cast bitcast %[[VT_ADJUSTED]] : !cir.ptr<!u8i> -> !cir.ptr<
+// CIR: cir.call @_ZN11VarargThunk1C1fEiz(%[[VT_RESULT]], %arg1) musttail
+// CIR: cir.return
+
// --- LLVM checks ---
// LLVM: @_ZTVN5Test11CE = global { [3 x ptr], [3 x ptr] } {
@@ -231,6 +278,14 @@ void C::g(int x) {}
// LLVM: %[[L4_ARG:.*]] = load i32, ptr
// LLVM: call void @_ZN5Test41C1gEi(ptr{{.*}} %[[L4_ADJ]], i32{{.*}} %[[L4_ARG]])
+// LLVM-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// LLVM: call {{.*}} @_ZN15CovariantReturn1C1fEv
+// LLVM: phi ptr
+
+// LLVM-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+// LLVM: getelementptr i8, ptr {{.*}}, i64 -8
+// LLVM: musttail call void (ptr, i32, ...) @_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+
// --- OGCG checks ---
// OGCG: @_ZTVN5Test11CE = unnamed_addr constant { [3 x ptr], [3 x ptr] } {
@@ -278,3 +333,11 @@ void C::g(int x) {}
// OGCG: %[[O4_ADJ:.*]] = getelementptr inbounds i8, ptr %[[O4_THIS]], i64 -8
// OGCG: %[[O4_ARG:.*]] = load i32, ptr
// OGCG: {{.*}}call void @_ZN5Test41C1gEi(ptr{{.*}} %[[O4_ADJ]], i32{{.*}} %[[O4_ARG]])
+
+// OGCG-LABEL: define {{.*}} @_ZTch0_v0_n32_N15CovariantReturn1C1fEv
+// OGCG: {{.*}}call {{.*}} @_ZN15CovariantReturn1C1fEv
+// OGCG: phi ptr
+
+// OGCG-LABEL: define {{.*}} void @_ZThn8_N11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
+// OGCG: getelementptr inbounds i8, ptr {{.*}}, i64 -8
+// OGCG: musttail call void (ptr, i32, ...) @_ZN11VarargThunk1C1fEiz(ptr{{.*}}, i32{{.*}}, ...)
More information about the cfe-commits
mailing list