[llvm] [FuncComp] Compare MDNodes in cmpMetadata using cmpMDNode. (PR #128878)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sat Mar 22 14:32:07 PDT 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/128878

>From 3ad4cccacdd40055eb3a935104b233a139e04b5e Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 22 Mar 2025 12:22:26 +0000
Subject: [PATCH] [FuncComp]

---
 .../Transforms/Utils/FunctionComparator.h     |  5 +-
 llvm/lib/IR/Verifier.cpp                      | 38 ++++++++++++
 .../Transforms/Utils/FunctionComparator.cpp   | 61 +++++++++++++------
 .../MergeFunc/metadata-call-arguments.ll      | 40 ++++++------
 4 files changed, 106 insertions(+), 38 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
index 19c5f7449f23e..6035692c86218 100644
--- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
+++ b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
@@ -328,8 +328,9 @@ class FunctionComparator {
   int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const;
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpAttrs(const AttributeList L, const AttributeList R) const;
-  int cmpMDNode(const MDNode *L, const MDNode *R) const;
-  int cmpMetadata(const Metadata *L, const Metadata *R) const;
+  int cmpMDNode(const MDNode *L, const MDNode *R, bool InValueContext) const;
+  int cmpMetadata(const Metadata *L, const Metadata *R,
+                  bool InValueContext) const;
   int cmpInstMetadata(Instruction const *L, Instruction const *R) const;
   int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const;
 
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5934f7adffb93..49d78f184191e 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -3543,6 +3543,31 @@ void Verifier::visitPHINode(PHINode &PN) {
   visitInstruction(PN);
 }
 
+/// Returns true if \p MD is valid as a metadata argument. It must be one of
+/// the following:
+/// * an MDNode without cycles (except self-reference in the first operand),
+/// * MDString,
+/// * ValueAsMetadata.
+static bool isValidMetadataArgument(const Metadata *MD,
+                                    SmallPtrSetImpl<const Metadata *> &Seen) {
+  // Potential cycles are not allowed.
+  if (!Seen.insert(MD).second)
+    return false;
+
+  if (auto *Node = dyn_cast<MDNode>(MD)) {
+    if (Node->getNumOperands() == 0)
+      return true;
+    ArrayRef<MDOperand> Ops = Node->operands();
+    if (Node->getOperand(0) == Node)
+      Ops = Ops.drop_front();
+    return all_of(Ops, [&](const Metadata *MD) {
+      return MD && isValidMetadataArgument(MD, Seen);
+    });
+  }
+
+  return isa<MDString>(MD) || isa<ValueAsMetadata>(MD);
+}
+
 void Verifier::visitCallBase(CallBase &Call) {
   Check(Call.getCalledOperand()->getType()->isPointerTy(),
         "Called function must be a pointer!", Call);
@@ -3562,6 +3587,19 @@ void Verifier::visitCallBase(CallBase &Call) {
           "Call parameter type does not match function signature!",
           Call.getArgOperand(i), FTy->getParamType(i), Call);
 
+  // Verify metadata arguments.
+  for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
+    auto Arg = Call.getArgOperand(i);
+    if (!Arg->getType()->isMetadataTy() || isa<DbgInfoIntrinsic>(Call))
+      continue;
+    SmallPtrSet<const Metadata *, 4> Seen;
+    Check(isValidMetadataArgument(cast<MetadataAsValue>(Arg)->getMetadata(),
+                                  Seen),
+          "Function arguments must be string metadata, value-as-metadata or an "
+          "MDNode!",
+          Call);
+  }
+
   AttributeList Attrs = Call.getAttributes();
 
   Check(verifyAttributeCount(Attrs, Call.arg_size()),
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index 6d4026e8209de..b9d8762d9d4a6 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -185,21 +185,21 @@ int FunctionComparator::cmpAttrs(const AttributeList L,
   return 0;
 }
 
-int FunctionComparator::cmpMetadata(const Metadata *L,
-                                    const Metadata *R) const {
+int FunctionComparator::cmpMetadata(const Metadata *L, const Metadata *R,
+                                    bool InValueContext) const {
   // TODO: the following routine coerce the metadata contents into constants
   // or MDStrings before comparison.
   // It ignores any other cases, so that the metadata nodes are considered
   // equal even though this is not correct.
   // We should structurally compare the metadata nodes to be perfect here.
 
+  if (L == R)
+    return 0;
+
   auto *MDStringL = dyn_cast<MDString>(L);
   auto *MDStringR = dyn_cast<MDString>(R);
-  if (MDStringL && MDStringR) {
-    if (MDStringL == MDStringR)
-      return 0;
+  if (MDStringL && MDStringR)
     return MDStringL->getString().compare(MDStringR->getString());
-  }
   if (MDStringR)
     return -1;
   if (MDStringL)
@@ -207,16 +207,31 @@ int FunctionComparator::cmpMetadata(const Metadata *L,
 
   auto *CL = dyn_cast<ConstantAsMetadata>(L);
   auto *CR = dyn_cast<ConstantAsMetadata>(R);
-  if (CL == CR)
-    return 0;
-  if (!CL)
+  if (CL && CR)
+    return cmpConstants(CL->getValue(), CR->getValue());
+  if (CR)
     return -1;
-  if (!CR)
+  if (CL)
     return 1;
-  return cmpConstants(CL->getValue(), CR->getValue());
+
+  auto *NodeL = dyn_cast<MDNode>(L);
+  auto *NodeR = dyn_cast<MDNode>(R);
+  if (NodeL && NodeR) {
+    if (InValueContext)
+      return cmpMDNode(NodeL, NodeR, InValueContext);
+  } else {
+    if (NodeR)
+      return -1;
+    if (NodeL)
+      return 1;
+  }
+  assert(!InValueContext &&
+         "all cases must be handled when comparing metadata arguments");
+  return 0;
 }
 
-int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
+int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R,
+                                  bool InValueContext) const {
   if (L == R)
     return 0;
   if (!L)
@@ -231,8 +246,20 @@ int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
   // function semantically.
   if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
     return Res;
-  for (size_t I = 0; I < L->getNumOperands(); ++I)
-    if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I)))
+
+  size_t StartIdx = 0;
+  if (L->getNumOperands() > 0) {
+    if (L->getOperand(0) == L) {
+      if (R->getOperand(0) != R)
+        return -1;
+      StartIdx = 1;
+    } else if (R->getOperand(0) == R)
+      return 1;
+  }
+
+  for (size_t I = StartIdx; I < L->getNumOperands(); ++I)
+    if (int Res =
+            cmpMetadata(L->getOperand(I), R->getOperand(I), InValueContext))
       return Res;
   return 0;
 }
@@ -254,7 +281,7 @@ int FunctionComparator::cmpInstMetadata(Instruction const *L,
     auto const [KeyR, MR] = MDR[I];
     if (int Res = cmpNumbers(KeyL, KeyR))
       return Res;
-    if (int Res = cmpMDNode(ML, MR))
+    if (int Res = cmpMDNode(ML, MR, false))
       return Res;
   }
   return 0;
@@ -722,7 +749,7 @@ int FunctionComparator::cmpOperations(const Instruction *L,
                                cast<CallInst>(R)->getTailCallKind()))
         return Res;
     return cmpMDNode(L->getMetadata(LLVMContext::MD_range),
-                     R->getMetadata(LLVMContext::MD_range));
+                     R->getMetadata(LLVMContext::MD_range), false);
   }
   if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
     ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -899,7 +926,7 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const {
       return 0;
 
     return cmpMetadata(MetadataValueL->getMetadata(),
-                       MetadataValueR->getMetadata());
+                       MetadataValueR->getMetadata(), true);
   }
 
   if (MetadataValueL)
diff --git a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
index 28263741f2cde..875160090bed5 100644
--- a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
+++ b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
@@ -55,48 +55,50 @@ declare i64 @llvm.read_volatile_register.i64(metadata)
 
 !5 = !{!"foo", i64 10}
 !6 = !{!"foo", i64 10}
+; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+;
+; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_1() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_2() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_2() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META4:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_1() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META5:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
-; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
-; CHECK-NEXT:    ret i64 [[TMP1]]
-;
-;
-; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
-; CHECK-NEXT:    ret i64 [[TMP1]]
-;
-;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_2() {
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
 ; CHECK-NEXT:    ret i64 [[TMP1]]
 ;
 ;.
-; CHECK: [[META0]] = distinct !{[[META0]], !"foo"}
-; CHECK: [[META1]] = distinct !{[[META1]], !"foo"}
-; CHECK: [[META2]] = distinct !{[[META2]], !"bar"}
-; CHECK: [[META3]] = !{!"foo", i64 10}
+; CHECK: [[META0]] = !{!"foo"}
+; CHECK: [[META1]] = !{!"bar"}
+; CHECK: [[META2]] = distinct !{[[META2]], !"foo"}
+; CHECK: [[META3]] = distinct !{[[META3]], !"foo"}
+; CHECK: [[META4]] = distinct !{[[META4]], !"bar"}
+; CHECK: [[META5]] = !{!"foo", i64 10}
 ;.



More information about the llvm-commits mailing list