[clang] [CIR] Add support for exact dynamic casts (PR #164007)
via cfe-commits
cfe-commits at lists.llvm.org
Fri Oct 17 12:49:31 PDT 2025
llvmbot wrote:
@llvm/pr-subscribers-clang
Author: Andy Kaylor (andykaylor)
Changes:
This adds support for handling exact dynamic casts (a `dynamic_cast` whose destination class is effectively final, so a successful cast implies the dynamic type is exactly the destination type) when optimizations are enabled.
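
For reference, a minimal sketch of the pattern this optimization targets, taken from the added test `clang/test/CIR/CodeGen/dynamic-cast-exact.cpp`: because the destination class is `final`, the compiler knows a successful cast means the object's dynamic type is exactly that class, so the generic RTTI walk can be replaced with a single vptr comparison.

```cpp
// Mirrors the new test case dynamic-cast-exact.cpp.
struct Base1 {
  virtual ~Base1();
};

struct Derived final : Base1 {};

// Derived is final, so this cast succeeds only when the dynamic type of
// *ptr is exactly Derived; the cast can therefore be lowered to a vptr
// comparison instead of a call into the runtime.
Derived *ptr_cast(Base1 *ptr) {
  return dynamic_cast<Derived *>(ptr);
}
```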
---
Full diff: https://github.com/llvm/llvm-project/pull/164007.diff
5 Files Affected:
- (modified) clang/lib/CIR/CodeGen/CIRGenCall.cpp (+16)
- (modified) clang/lib/CIR/CodeGen/CIRGenFunction.h (+3)
- (modified) clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp (+160-2)
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+9)
- (added) clang/test/CIR/CodeGen/dynamic-cast-exact.cpp (+114)
``````````diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 61072f0883728..88aef89ddd2b9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
isUsed = true;
}
+mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
+ cir::FuncOp callee,
+ ArrayRef<mlir::Value> args) {
+ // TODO(cir): set the calling convention to this runtime call.
+ assert(!cir::MissingFeatures::opFuncCallingConv());
+
+ cir::CallOp call = builder.createCallOp(loc, callee, args);
+ assert(call->getNumResults() <= 1 &&
+ "runtime functions have at most 1 result");
+
+ if (call->getNumResults() == 0)
+ return nullptr;
+
+ return call->getResult(0);
+}
+
void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
clang::QualType argType) {
assert(argType->isReferenceType() == e->isGLValue() &&
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 3c36f5c697118..84b4ba293b3aa 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache {
void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);
+ mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
+ llvm::ArrayRef<mlir::Value> args = {});
+
/// Emit the computation of the specified expression of scalar type.
mlir::Value emitScalarExpr(const clang::Expr *e);
diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
index d54d2e9cb29e5..ef91288ab6155 100644
--- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
@@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
}
+static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
+ // TODO(cir): set the calling convention to the runtime function.
+ assert(!cir::MissingFeatures::opFuncCallingConv());
+
+ cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
+ cir::UnreachableOp::create(cgf.getBuilder(), loc);
+ cgf.getBuilder().clearInsertionPoint();
+}
+
// TODO(cir): This could be shared with classic codegen.
static CharUnits computeOffsetHint(ASTContext &astContext,
const CXXRecordDecl *src,
@@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
return Address{ptr, src.getAlignment()};
}
+static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
+ CIRGenFunction &cgf, mlir::Location loc,
+ QualType srcRecordTy,
+ QualType destRecordTy,
+ cir::PointerType destCIRTy,
+ bool isRefCast, Address src) {
+ // Find all the inheritance paths from SrcRecordTy to DestRecordTy.
+ const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
+ const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
+ CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
+ /*DetectVirtual=*/false);
+ (void)destDecl->isDerivedFrom(srcDecl, paths);
+
+ // Find an offset within `destDecl` where a `srcDecl` instance and its vptr
+ // might appear.
+ std::optional<CharUnits> offset;
+ for (const CXXBasePath &path : paths) {
+ // dynamic_cast only finds public inheritance paths.
+ if (path.Access != AS_public)
+ continue;
+
+ CharUnits pathOffset;
+ for (const CXXBasePathElement &pathElement : path) {
+ // Find the offset along this inheritance step.
+ const CXXRecordDecl *base =
+ pathElement.Base->getType()->getAsCXXRecordDecl();
+ if (pathElement.Base->isVirtual()) {
+ // For a virtual base class, we know that the derived class is exactly
+ // destDecl, so we can use the vbase offset from its layout.
+ const ASTRecordLayout &layout =
+ cgf.getContext().getASTRecordLayout(destDecl);
+ pathOffset = layout.getVBaseClassOffset(base);
+ } else {
+ const ASTRecordLayout &layout =
+ cgf.getContext().getASTRecordLayout(pathElement.Class);
+ pathOffset += layout.getBaseClassOffset(base);
+ }
+ }
+
+ if (!offset) {
+ offset = pathOffset;
+ } else if (offset != pathOffset) {
+ // base appears in at least two different places. Find the most-derived
+ // object and see if it's a DestDecl. Note that the most-derived object
+ // must be at least as aligned as this base class subobject, and must
+ // have a vptr at offset 0.
+ src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
+ srcDecl = destDecl;
+ offset = CharUnits::Zero();
+ break;
+ }
+ }
+
+ CIRGenBuilderTy &builder = cgf.getBuilder();
+
+ if (!offset) {
+ // If there are no public inheritance paths, the cast always fails.
+ mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+ if (isRefCast) {
+ mlir::Region *currentRegion = builder.getBlock()->getParent();
+ emitCallToBadCast(cgf, loc);
+
+ // The call to bad_cast will terminate the block. Create a new block to
+ // hold any follow up code.
+ builder.createBlock(currentRegion, currentRegion->end());
+ }
+
+ return nullPtrValue;
+ }
+
+ // Compare the vptr against the expected vptr for the destination type at
+ // this offset. Note that we do not know what type src points to in the case
+ // where the derived class multiply inherits from the base class so we can't
+ // use getVTablePtr, so we load the vptr directly instead.
+
+ mlir::Value expectedVPtr =
+ abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);
+
+ // TODO(cir): handle address space here.
+ assert(!cir::MissingFeatures::addressSpace());
+ mlir::Type vptrTy = expectedVPtr.getType();
+ mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
+ Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
+ src.getAlignment());
+ mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);
+
+ // TODO(cir): decorate SrcVPtr with TBAA info.
+ assert(!cir::MissingFeatures::opTBAA());
+
+ mlir::Value success =
+ builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);
+
+ auto emitCastResult = [&] {
+ if (offset->isZero())
+ return builder.createBitcast(src.getPointer(), destCIRTy);
+
+ // TODO(cir): handle address space here.
+ assert(!cir::MissingFeatures::addressSpace());
+ mlir::Type u8PtrTy = builder.getUInt8PtrTy();
+
+ mlir::Value strideToApply =
+ builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity());
+ mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
+ mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
+ srcU8Ptr, strideToApply);
+ return builder.createBitcast(resultU8Ptr, destCIRTy);
+ };
+
+ if (isRefCast) {
+ mlir::Value failed = builder.createNot(success);
+ cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
+ [&](mlir::OpBuilder &, mlir::Location) {
+ emitCallToBadCast(cgf, loc);
+ });
+ return emitCastResult();
+ }
+
+ return cir::TernaryOp::create(
+ builder, loc, success,
+ [&](mlir::OpBuilder &, mlir::Location) {
+ auto result = emitCastResult();
+ builder.createYield(loc, result);
+ },
+ [&](mlir::OpBuilder &, mlir::Location) {
+ mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+ builder.createYield(loc, nullPtrValue);
+ })
+ .getResult();
+}
+
static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
mlir::Location loc,
QualType srcRecordTy,
@@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
// if the dynamic type of the pointer is exactly the destination type.
if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
- cgm.errorNYI(loc, "emitExactDynamicCast");
- return {};
+ CIRGenBuilderTy &builder = cgf.getBuilder();
+ // If this isn't a reference cast, check the pointer to see if it's null.
+ if (!isRefCast) {
+ mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
+ return cir::TernaryOp::create(
+ builder, loc, srcPtrIsNull,
+ [&](mlir::OpBuilder, mlir::Location) {
+ builder.createYield(
+ loc, builder.getNullPtr(destCIRTy, loc).getResult());
+ },
+ [&](mlir::OpBuilder &, mlir::Location) {
+ mlir::Value exactCast = emitExactDynamicCast(
+ *this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
+ isRefCast, src);
+ builder.createYield(loc, exactCast);
+ })
+ .getResult();
+ }
+
+ return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
+ destCIRTy, isRefCast, src);
}
cir::DynamicCastInfoAttr castInfo =
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 0243bf120f396..51dba33338cd6 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
return mlir::success();
}
+ if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
+ // !cir.vptr is a special case, but it's just a pointer to LLVM.
+ auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
+ /* isSigned=*/false);
+ rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
+ cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+ return mlir::success();
+ }
+
if (mlir::isa<cir::FPTypeInterface>(type)) {
mlir::LLVM::FCmpPredicate kind =
convertCmpKindToFCmpPredicate(cmpOp.getKind());
diff --git a/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
new file mode 100644
index 0000000000000..41a70ce53db5e
--- /dev/null
+++ b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
@@ -0,0 +1,114 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s
+// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s
+
+struct Base1 {
+ virtual ~Base1();
+};
+
+struct Base2 {
+ virtual ~Base2();
+};
+
+struct Derived final : Base1 {};
+
+Derived *ptr_cast(Base1 *ptr) {
+ return dynamic_cast<Derived *>(ptr);
+}
+
+// CIR: cir.func {{.*}} @_Z8ptr_castP5Base1
+// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT: %[[NULL_PTR:.*]] = cir.const #cir.ptr<null>
+// CIR-NEXT: %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]])
+// CIR-NEXT: %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true {
+// CIR-NEXT: %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[NULL_PTR_DEST]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }, false {
+// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT: %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true {
+// CIR-NEXT: %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[RES]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }, false {
+// CIR-NEXT: %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[NULL]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+// CIR-NEXT: cir.yield %[[EXACT_RESULT]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT: }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+
+// Note: The LLVM output omits the label for the entry block (which is
+// implicitly %1), so we use %{{.*}} to match the implicit label in the
+// phi check.
+
+// LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT: %[[SRC_IS_NULL:.*]] = icmp eq ptr %0, null
+// LLVM-NEXT: br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]]
+// LLVM: [[LABEL_NOTNULL]]:
+// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT: %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT: %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null
+// LLVM-NEXT: br label %[[LABEL_END]]
+// LLVM: [[LABEL_END]]:
+// LLVM-NEXT: %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ]
+// LLVM-NEXT: ret ptr %[[RESULT]]
+// LLVM-NEXT: }
+
+// OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT: %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
+// OGCG-NEXT: br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
+// OGCG: [[LABEL_NOTNULL]]:
+// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
+// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT: br label %[[LABEL_END]]
+// OGCG: [[LABEL_END]]:
+// OGCG-NEXT: %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
+// OGCG-NEXT: ret ptr %[[RESULT]]
+// OGCG-NEXT: }
+
+Derived &ref_cast(Base1 &ref) {
+ return dynamic_cast<Derived &>(ref);
+}
+
+// CIR: cir.func {{.*}} @_Z8ref_castR5Base1
+// CIR: %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT: %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT: %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT: %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT: %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT: %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool
+// CIR-NEXT: cir.if %[[FAILED]] {
+// CIR-NEXT: cir.call @__cxa_bad_cast() : () -> ()
+// CIR-NEXT: cir.unreachable
+// CIR-NEXT: }
+// CIR-NEXT: %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+
+// LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT: %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT: %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT: br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]]
+// LLVM: [[LABEL_FAIL]]:
+// LLVM-NEXT: tail call void @__cxa_bad_cast()
+// LLVM-NEXT: unreachable
+// LLVM: [[LABEL_OK]]:
+// LLVM-NEXT: ret ptr %[[SRC]]
+// LLVM-NEXT: }
+
+// OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT: %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
+// OGCG-NEXT: %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT: br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
+// OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT: {{.*}}call void @__cxa_bad_cast()
+// OGCG-NEXT: unreachable
+// OGCG: [[LABEL_END]]:
+// OGCG-NEXT: ret ptr %[[REF]]
+// OGCG-NEXT: }
``````````
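Conceptually, the check emitted for the pointer case is equivalent to the following hand-written C++. This is only a sketch grounded in the LLVM CHECK lines above, not the actual implementation: `derived_vtable_address_point` is a hypothetical placeholder for the vtable address-point constant (`@_ZTV7Derived + 16` in the LLVM output), and the real code path also handles reference casts, bad-cast calls, and non-zero base offsets.

```cpp
struct Base1 { virtual ~Base1(); };
struct Derived final : Base1 {};

// Placeholder for the expected vptr value for Derived (the vtable
// address point that CIR materializes via cir.vtable.address_point).
extern const void *derived_vtable_address_point;

Derived *exact_ptr_cast(Base1 *ptr) {
  if (ptr == nullptr)
    return nullptr;
  // Load the object's vptr (at offset 0 in the Itanium ABI) and compare
  // it against the expected vptr for Derived; equality means the dynamic
  // type is exactly Derived, so the cast succeeds.
  const void *vptr = *reinterpret_cast<const void *const *>(ptr);
  return vptr == derived_vtable_address_point
             ? static_cast<Derived *>(ptr)
             : nullptr;
}
```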
https://github.com/llvm/llvm-project/pull/164007