[clang] [CIR] Add support for exact dynamic casts (PR #164007)

Andy Kaylor via cfe-commits cfe-commits at lists.llvm.org
Fri Oct 17 12:49:00 PDT 2025


https://github.com/andykaylor created https://github.com/llvm/llvm-project/pull/164007

This adds support for handling exact dynamic casts when optimizations are enabled.

>From a8a601980d45a6e55020527a7a1bd590af2dc574 Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akaylor at nvidia.com>
Date: Thu, 9 Oct 2025 16:18:05 -0700
Subject: [PATCH] [CIR] Add support for exact dynamic casts

This adds support for handling exact dynamic casts when optimizations are
enabled.
---
 clang/lib/CIR/CodeGen/CIRGenCall.cpp          |  16 ++
 clang/lib/CIR/CodeGen/CIRGenFunction.h        |   3 +
 clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp | 162 +++++++++++++++++-
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp |   9 +
 clang/test/CIR/CodeGen/dynamic-cast-exact.cpp | 114 ++++++++++++
 5 files changed, 302 insertions(+), 2 deletions(-)
 create mode 100644 clang/test/CIR/CodeGen/dynamic-cast-exact.cpp

diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index 61072f0883728..88aef89ddd2b9 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -690,6 +690,22 @@ void CallArg::copyInto(CIRGenFunction &cgf, Address addr,
   isUsed = true;
 }
 
+mlir::Value CIRGenFunction::emitRuntimeCall(mlir::Location loc,
+                                            cir::FuncOp callee,
+                                            ArrayRef<mlir::Value> args) {
+  // TODO(cir): set the calling convention for this runtime call.
+  assert(!cir::MissingFeatures::opFuncCallingConv());
+
+  cir::CallOp call = builder.createCallOp(loc, callee, args);
+  assert(call->getNumResults() <= 1 &&
+         "runtime functions have at most 1 result");
+
+  if (call->getNumResults() == 0)
+    return nullptr; // The call produced no result; return a null value.
+
+  return call->getResult(0);
+}
+
 void CIRGenFunction::emitCallArg(CallArgList &args, const clang::Expr *e,
                                  clang::QualType argType) {
   assert(argType->isReferenceType() == e->isGLValue() &&
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 3c36f5c697118..84b4ba293b3aa 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -1380,6 +1380,9 @@ class CIRGenFunction : public CIRGenTypeCache {
 
   void emitReturnOfRValue(mlir::Location loc, RValue rv, QualType ty);
 
+  mlir::Value emitRuntimeCall(mlir::Location loc, cir::FuncOp callee,
+                              llvm::ArrayRef<mlir::Value> args = {});
+
   /// Emit the computation of the specified expression of scalar type.
   mlir::Value emitScalarExpr(const clang::Expr *e);
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
index d54d2e9cb29e5..ef91288ab6155 100644
--- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
@@ -1869,6 +1869,15 @@ static cir::FuncOp getBadCastFn(CIRGenFunction &cgf) {
   return cgf.cgm.createRuntimeFunction(fnTy, "__cxa_bad_cast");
 }
 
+static void emitCallToBadCast(CIRGenFunction &cgf, mlir::Location loc) {
+  // TODO(cir): set the calling convention for the runtime function call.
+  assert(!cir::MissingFeatures::opFuncCallingConv());
+
+  cgf.emitRuntimeCall(loc, getBadCastFn(cgf));
+  cir::UnreachableOp::create(cgf.getBuilder(), loc);
+  cgf.getBuilder().clearInsertionPoint(); // Block ends at unreachable.
+}
+
 // TODO(cir): This could be shared with classic codegen.
 static CharUnits computeOffsetHint(ASTContext &astContext,
                                    const CXXRecordDecl *src,
@@ -1954,6 +1963,136 @@ static Address emitDynamicCastToVoid(CIRGenFunction &cgf, mlir::Location loc,
   return Address{ptr, src.getAlignment()};
 }
 
+static mlir::Value emitExactDynamicCast(CIRGenItaniumCXXABI &abi,
+                                        CIRGenFunction &cgf, mlir::Location loc,
+                                        QualType srcRecordTy,
+                                        QualType destRecordTy,
+                                        cir::PointerType destCIRTy,
+                                        bool isRefCast, Address src) {
+  // Find all the inheritance paths from srcRecordTy to destRecordTy.
+  const CXXRecordDecl *srcDecl = srcRecordTy->getAsCXXRecordDecl();
+  const CXXRecordDecl *destDecl = destRecordTy->getAsCXXRecordDecl();
+  CXXBasePaths paths(/*FindAmbiguities=*/true, /*RecordPaths=*/true,
+                     /*DetectVirtual=*/false);
+  (void)destDecl->isDerivedFrom(srcDecl, paths);
+
+  // Find an offset within `destDecl` where a `srcDecl` instance and its vptr
+  // might appear.
+  std::optional<CharUnits> offset;
+  for (const CXXBasePath &path : paths) {
+    // dynamic_cast only finds public inheritance paths.
+    if (path.Access != AS_public)
+      continue;
+
+    CharUnits pathOffset;
+    for (const CXXBasePathElement &pathElement : path) {
+      // Find the offset along this inheritance step.
+      const CXXRecordDecl *base =
+          pathElement.Base->getType()->getAsCXXRecordDecl();
+      if (pathElement.Base->isVirtual()) {
+        // For a virtual base class, we know that the derived class is exactly
+        // destDecl, so we can use the vbase offset from its layout.
+        const ASTRecordLayout &layout =
+            cgf.getContext().getASTRecordLayout(destDecl);
+        pathOffset = layout.getVBaseClassOffset(base);
+      } else {
+        const ASTRecordLayout &layout =
+            cgf.getContext().getASTRecordLayout(pathElement.Class);
+        pathOffset += layout.getBaseClassOffset(base);
+      }
+    }
+
+    if (!offset) {
+      offset = pathOffset;
+    } else if (offset != pathOffset) {
+      // srcDecl appears at two or more different offsets. Find the
+      // most-derived object and see if it's a destDecl. Note that the
+      // most-derived object must be at least as aligned as this base class
+      // subobject, and must have a vptr at offset 0.
+      src = emitDynamicCastToVoid(cgf, loc, srcRecordTy, src);
+      srcDecl = destDecl;
+      offset = CharUnits::Zero();
+      break;
+    }
+  }
+
+  CIRGenBuilderTy &builder = cgf.getBuilder();
+
+  if (!offset) {
+    // If there are no public inheritance paths, the cast always fails.
+    mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+    if (isRefCast) {
+      mlir::Region *currentRegion = builder.getBlock()->getParent();
+      emitCallToBadCast(cgf, loc);
+
+      // emitCallToBadCast ends the block with an unreachable op. Create a new
+      // block to hold any follow-up code.
+      builder.createBlock(currentRegion, currentRegion->end());
+    }
+
+    return nullPtrValue;
+  }
+
+  // Compare the vptr against the expected vptr for the destination type at
+  // this offset. Note that we do not know what type src points to in the case
+  // where the derived class multiply inherits from the base class so we can't
+  // use getVTablePtr, so we load the vptr directly instead.
+
+  mlir::Value expectedVPtr =
+      abi.getVTableAddressPoint(BaseSubobject(srcDecl, *offset), destDecl);
+
+  // TODO(cir): handle address space here.
+  assert(!cir::MissingFeatures::addressSpace());
+  mlir::Type vptrTy = expectedVPtr.getType();
+  mlir::Type vptrPtrTy = builder.getPointerTo(vptrTy);
+  Address srcVPtrPtr(builder.createBitcast(src.getPointer(), vptrPtrTy),
+                     src.getAlignment());
+  mlir::Value srcVPtr = builder.createLoad(loc, srcVPtrPtr);
+
+  // TODO(cir): decorate srcVPtr with TBAA info.
+  assert(!cir::MissingFeatures::opTBAA());
+
+  mlir::Value success =
+      builder.createCompare(loc, cir::CmpOpKind::eq, srcVPtr, expectedVPtr);
+
+  auto emitCastResult = [&] {
+    if (offset->isZero())
+      return builder.createBitcast(src.getPointer(), destCIRTy);
+
+    // TODO(cir): handle address space here.
+    assert(!cir::MissingFeatures::addressSpace());
+    mlir::Type u8PtrTy = builder.getUInt8PtrTy();
+
+    mlir::Value strideToApply =
+        builder.getConstInt(loc, builder.getUInt64Ty(), offset->getQuantity());
+    mlir::Value srcU8Ptr = builder.createBitcast(src.getPointer(), u8PtrTy);
+    mlir::Value resultU8Ptr = cir::PtrStrideOp::create(builder, loc, u8PtrTy,
+                                                       srcU8Ptr, strideToApply);
+    return builder.createBitcast(resultU8Ptr, destCIRTy);
+  };
+
+  if (isRefCast) {
+    mlir::Value failed = builder.createNot(success);
+    cir::IfOp::create(builder, loc, failed, /*withElseRegion=*/false,
+                      [&](mlir::OpBuilder &, mlir::Location) {
+                        emitCallToBadCast(cgf, loc);
+                      });
+    return emitCastResult();
+  }
+
+  return cir::TernaryOp::create(
+             builder, loc, success,
+             [&](mlir::OpBuilder &, mlir::Location) {
+               auto result = emitCastResult();
+               builder.createYield(loc, result);
+             },
+             [&](mlir::OpBuilder &, mlir::Location) {
+               mlir::Value nullPtrValue = builder.getNullPtr(destCIRTy, loc);
+               builder.createYield(loc, nullPtrValue);
+             })
+      .getResult();
+}
+
 static cir::DynamicCastInfoAttr emitDynamicCastInfo(CIRGenFunction &cgf,
                                                     mlir::Location loc,
                                                     QualType srcRecordTy,
@@ -1995,8 +2134,27 @@ mlir::Value CIRGenItaniumCXXABI::emitDynamicCast(CIRGenFunction &cgf,
   // if the dynamic type of the pointer is exactly the destination type.
   if (destRecordTy->getAsCXXRecordDecl()->isEffectivelyFinal() &&
       cgf.cgm.getCodeGenOpts().OptimizationLevel > 0) {
-    cgm.errorNYI(loc, "emitExactDynamicCast");
-    return {};
+    CIRGenBuilderTy &builder = cgf.getBuilder();
+    // If this isn't a reference cast, check the pointer to see if it's null.
+    if (!isRefCast) {
+      mlir::Value srcPtrIsNull = builder.createPtrIsNull(src.getPointer());
+      return cir::TernaryOp::create(
+                 builder, loc, srcPtrIsNull,
+                 [&](mlir::OpBuilder, mlir::Location) {
+                   builder.createYield(
+                       loc, builder.getNullPtr(destCIRTy, loc).getResult());
+                 },
+                 [&](mlir::OpBuilder &, mlir::Location) {
+                   mlir::Value exactCast = emitExactDynamicCast(
+                       *this, cgf, loc, srcRecordTy, destRecordTy, destCIRTy,
+                       isRefCast, src);
+                   builder.createYield(loc, exactCast);
+                 })
+          .getResult();
+    }
+
+    return emitExactDynamicCast(*this, cgf, loc, srcRecordTy, destRecordTy,
+                                destCIRTy, isRefCast, src);
   }
 
   cir::DynamicCastInfoAttr castInfo =
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 0243bf120f396..51dba33338cd6 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -2223,6 +2223,15 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
     return mlir::success();
   }
 
+  if (auto vptrTy = mlir::dyn_cast<cir::VPtrType>(type)) {
+    // !cir.vptr is a special case, but it's just a pointer to LLVM.
+    auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
+                                              /* isSigned=*/false);
+    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
+        cmpOp, kind, adaptor.getLhs(), adaptor.getRhs());
+    return mlir::success();
+  }
+
   if (mlir::isa<cir::FPTypeInterface>(type)) {
     mlir::LLVM::FCmpPredicate kind =
         convertCmpKindToFCmpPredicate(cmpOp.getKind());
diff --git a/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
new file mode 100644
index 0000000000000..41a70ce53db5e
--- /dev/null
+++ b/clang/test/CIR/CodeGen/dynamic-cast-exact.cpp
@@ -0,0 +1,114 @@
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -clangir-disable-passes -emit-cir -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s --check-prefix=CIR
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -fclangir -emit-llvm -o %t-cir.ll %s
+// RUN: FileCheck --input-file=%t-cir.ll --check-prefix=LLVM %s
+// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -std=c++20 -O1 -emit-llvm -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll --check-prefix=OGCG %s
+
+struct Base1 {
+  virtual ~Base1(); // Virtual destructor makes Base1 polymorphic.
+};
+
+struct Base2 { // NOTE(review): not referenced by any test in this file.
+  virtual ~Base2();
+};
+
+struct Derived final : Base1 {}; // 'final' enables the exact-cast path.
+
+Derived *ptr_cast(Base1 *ptr) { // On failure, yields a null pointer.
+  return dynamic_cast<Derived *>(ptr);
+}
+
+//      CIR: cir.func {{.*}} @_Z8ptr_castP5Base1
+//      CIR:   %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT:   %[[NULL_PTR:.*]] = cir.const #cir.ptr<null>
+// CIR-NEXT:   %[[SRC_IS_NULL:.*]] = cir.cmp(eq, %[[SRC]], %[[NULL_PTR]])
+// CIR-NEXT:   %[[RESULT:.*]] = cir.ternary(%[[SRC_IS_NULL]], true {
+// CIR-NEXT:     %[[NULL_PTR_DEST:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT:     cir.yield %[[NULL_PTR_DEST]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT:   }, false {
+// CIR-NEXT:     %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT:     %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT:     %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT:     %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT:     %[[EXACT_RESULT:.*]] = cir.ternary(%[[SUCCESS]], true {
+// CIR-NEXT:       %[[RES:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+// CIR-NEXT:       cir.yield %[[RES]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT:     }, false {
+// CIR-NEXT:       %[[NULL:.*]] = cir.const #cir.ptr<null> : !cir.ptr<!rec_Derived>
+// CIR-NEXT:       cir.yield %[[NULL]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT:     }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+// CIR-NEXT:     cir.yield %[[EXACT_RESULT]] : !cir.ptr<!rec_Derived>
+// CIR-NEXT:     }) : (!cir.bool) -> !cir.ptr<!rec_Derived>
+
+// Note: The LLVM output omits the label for the entry block (which is
+//       implicitly %1), so we use %{{.*}} to match the implicit label in the
+//       phi check.
+
+//      LLVM: define dso_local ptr @_Z8ptr_castP5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT:   %[[SRC_IS_NULL:.*]] = icmp eq ptr %[[SRC]], null
+// LLVM-NEXT:   br i1 %[[SRC_IS_NULL]], label %[[LABEL_END:.*]], label %[[LABEL_NOTNULL:.*]]
+//      LLVM: [[LABEL_NOTNULL]]:
+// LLVM-NEXT:   %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT:   %[[SUCCESS:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT:   %[[EXACT_RESULT:.*]] = select i1 %[[SUCCESS]], ptr %[[SRC]], ptr null
+// LLVM-NEXT:   br label %[[LABEL_END]]
+//      LLVM: [[LABEL_END]]:
+// LLVM-NEXT:   %[[RESULT:.*]] = phi ptr [ %[[EXACT_RESULT]], %[[LABEL_NOTNULL]] ], [ null, %{{.*}} ]
+// LLVM-NEXT:   ret ptr %[[RESULT]]
+// LLVM-NEXT: }
+
+//      OGCG: define{{.*}} ptr @_Z8ptr_castP5Base1(ptr {{.*}} %[[SRC:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT:   %[[NULL_CHECK:.*]] = icmp eq ptr %[[SRC]], null
+// OGCG-NEXT:   br i1 %[[NULL_CHECK]], label %[[LABEL_NULL:.*]], label %[[LABEL_NOTNULL:.*]]
+//      OGCG: [[LABEL_NOTNULL]]:
+// OGCG-NEXT:   %[[VTABLE:.*]] = load ptr, ptr %[[SRC]], align 8
+// OGCG-NEXT:   %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT:   br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL]]
+//      OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT:   br label %[[LABEL_END]]
+//      OGCG: [[LABEL_END]]:
+// OGCG-NEXT:   %[[RESULT:.*]] = phi ptr [ %[[SRC]], %[[LABEL_NOTNULL]] ], [ null, %[[LABEL_NULL]] ]
+// OGCG-NEXT:   ret ptr %[[RESULT]]
+// OGCG-NEXT: }
+
+Derived &ref_cast(Base1 &ref) { // On failure, calls __cxa_bad_cast.
+  return dynamic_cast<Derived &>(ref);
+}
+
+//      CIR: cir.func {{.*}} @_Z8ref_castR5Base1
+//      CIR:   %[[SRC:.*]] = cir.load{{.*}} %{{.+}} : !cir.ptr<!cir.ptr<!rec_Base1>>, !cir.ptr<!rec_Base1>
+// CIR-NEXT:   %[[EXPECTED_VPTR:.*]] = cir.vtable.address_point(@_ZTV7Derived, address_point = <index = 0, offset = 2>) : !cir.vptr
+// CIR-NEXT:   %[[SRC_VPTR_PTR:.*]] = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!cir.vptr>
+// CIR-NEXT:   %[[SRC_VPTR:.*]] = cir.load{{.*}} %[[SRC_VPTR_PTR]] : !cir.ptr<!cir.vptr>, !cir.vptr
+// CIR-NEXT:   %[[SUCCESS:.*]] = cir.cmp(eq, %[[SRC_VPTR]], %[[EXPECTED_VPTR]]) : !cir.vptr, !cir.bool
+// CIR-NEXT:   %[[FAILED:.*]] = cir.unary(not, %[[SUCCESS]]) : !cir.bool, !cir.bool
+// CIR-NEXT:   cir.if %[[FAILED]] {
+// CIR-NEXT:     cir.call @__cxa_bad_cast() : () -> ()
+// CIR-NEXT:     cir.unreachable
+// CIR-NEXT:   }
+// CIR-NEXT:   %{{.+}} = cir.cast bitcast %[[SRC]] : !cir.ptr<!rec_Base1> -> !cir.ptr<!rec_Derived>
+
+//      LLVM: define{{.*}} ptr @_Z8ref_castR5Base1(ptr{{.*}} %[[SRC:.*]])
+// LLVM-NEXT:   %[[VPTR:.*]] = load ptr, ptr %[[SRC]], align 8
+// LLVM-NEXT:   %[[OK:.*]] = icmp eq ptr %[[VPTR]], getelementptr inbounds nuw (i8, ptr @_ZTV7Derived, i64 16)
+// LLVM-NEXT:   br i1 %[[OK]], label %[[LABEL_OK:.*]], label %[[LABEL_FAIL:.*]]
+//      LLVM: [[LABEL_FAIL]]:
+// LLVM-NEXT:   tail call void @__cxa_bad_cast()
+// LLVM-NEXT:   unreachable
+//      LLVM: [[LABEL_OK]]:
+// LLVM-NEXT:   ret ptr %[[SRC]]
+// LLVM-NEXT: }
+
+//      OGCG: define{{.*}} ptr @_Z8ref_castR5Base1(ptr {{.*}} %[[REF:.*]])
+// OGCG-NEXT: entry:
+// OGCG-NEXT:   %[[VTABLE:.*]] = load ptr, ptr %[[REF]], align 8
+// OGCG-NEXT:   %[[VTABLE_CHECK:.*]] = icmp eq ptr %[[VTABLE]], getelementptr inbounds {{.*}} (i8, ptr @_ZTV7Derived, i64 16)
+// OGCG-NEXT:   br i1 %[[VTABLE_CHECK]], label %[[LABEL_END:.*]], label %[[LABEL_NULL:.*]]
+//      OGCG: [[LABEL_NULL]]:
+// OGCG-NEXT:   {{.*}}call void @__cxa_bad_cast()
+// OGCG-NEXT:   unreachable
+//      OGCG: [[LABEL_END]]:
+// OGCG-NEXT:   ret ptr %[[REF]]
+// OGCG-NEXT: }



More information about the cfe-commits mailing list