[llvm] [FuncComp] Compare MDNodes in cmpMetadata using cmpMDNode. (PR #128878)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 26 06:10:34 PST 2025


https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/128878

Use cmpMDNode in cmpMetadata to structurally compare MDNodes for metadata arguments. This fixes a mis-compile caused by cmpMetadata incorrectly returning 0 for different nodes.

Note that metadata can contain cycles, so we need to make sure the comparison doesn't get stuck in infinite recursion.

>From 87baa8f733f003f19f1915ef7f9738a4a1295120 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 26 Feb 2025 13:53:22 +0000
Subject: [PATCH] [FuncComp] Compare MDNodes in cmpMetadata using cmpMDNode.

Use cmpMDNode in cmpMetadata to structurally compare MDNodes for
metadata arguments. This fixes a mis-compile caused by cmpMetadata
incorrectly returning 0 for different nodes.

Note that metadata can contain cycles, so we need to make sure the
comparison doesn't get stuck in infinite recursion.
---
 .../Transforms/Utils/FunctionComparator.h     |  8 ++-
 .../Transforms/Utils/FunctionComparator.cpp   | 69 +++++++++++++++----
 .../MergeFunc/metadata-call-arguments.ll      | 42 +++++------
 3 files changed, 81 insertions(+), 38 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
index 19c5f7449f23e..47871719b497a 100644
--- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
+++ b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
@@ -328,8 +328,12 @@ class FunctionComparator {
   int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const;
   int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
   int cmpAttrs(const AttributeList L, const AttributeList R) const;
-  int cmpMDNode(const MDNode *L, const MDNode *R) const;
-  int cmpMetadata(const Metadata *L, const Metadata *R) const;
+  int cmpMDNode(const MDNode *L, const MDNode *R,
+                SmallPtrSetImpl<const MDNode *> &SeenL,
+                SmallPtrSetImpl<const MDNode *> &SeenR) const;
+  int cmpMetadata(const Metadata *L, const Metadata *R,
+                  SmallPtrSetImpl<const MDNode *> &SeenL,
+                  SmallPtrSetImpl<const MDNode *> &SeenR) const;
   int cmpInstMetadata(Instruction const *L, Instruction const *R) const;
   int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const;
 
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index 6d4026e8209de..8176518d82011 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -185,8 +185,13 @@ int FunctionComparator::cmpAttrs(const AttributeList L,
   return 0;
 }
 
-int FunctionComparator::cmpMetadata(const Metadata *L,
-                                    const Metadata *R) const {
+int FunctionComparator::cmpMetadata(
+    const Metadata *L, const Metadata *R,
+    SmallPtrSetImpl<const MDNode *> &SeenL,
+    SmallPtrSetImpl<const MDNode *> &SeenR) const {
+  if (L == R)
+    return 0;
+
   // TODO: the following routine coerce the metadata contents into constants
   // or MDStrings before comparison.
   // It ignores any other cases, so that the metadata nodes are considered
@@ -207,22 +212,51 @@ int FunctionComparator::cmpMetadata(const Metadata *L,
 
   auto *CL = dyn_cast<ConstantAsMetadata>(L);
   auto *CR = dyn_cast<ConstantAsMetadata>(R);
-  if (CL == CR)
-    return 0;
-  if (!CL)
+  if (CL && CR) {
+    if (!CL)
+      return -1;
+    if (!CR)
+      return 1;
+    return cmpConstants(CL->getValue(), CR->getValue());
+  }
+
+  auto *NodeL = dyn_cast<const MDNode>(L);
+  auto *NodeR = dyn_cast<const MDNode>(R);
+  if (NodeL && NodeR)
+    return cmpMDNode(NodeL, NodeR, SeenL, SeenR);
+
+  if (NodeR)
     return -1;
-  if (!CR)
+
+  if (NodeL)
     return 1;
-  return cmpConstants(CL->getValue(), CR->getValue());
+
+  assert(false);
+
+  return 0;
 }
 
-int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
+int FunctionComparator::cmpMDNode(
+    const MDNode *L, const MDNode *R, SmallPtrSetImpl<const MDNode *> &SeenL,
+    SmallPtrSetImpl<const MDNode *> &SeenR) const {
   if (L == R)
     return 0;
   if (!L)
     return -1;
   if (!R)
     return 1;
+
+  // Check if we already checked either L or R previously. This can be the case
+  // for metadata nodes with cycles.
+  bool AlreadySeenL = !SeenL.insert(L).second;
+  bool AlreadySeenR = !SeenR.insert(R).second;
+  if (AlreadySeenL && AlreadySeenR)
+    return 0;
+  if (AlreadySeenR)
+    return -1;
+  if (AlreadySeenL)
+    return 1;
+
   // TODO: Note that as this is metadata, it is possible to drop and/or merge
   // this data when considering functions to merge. Thus this comparison would
   // return 0 (i.e. equivalent), but merging would become more complicated
@@ -232,7 +266,7 @@ int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
   if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
     return Res;
   for (size_t I = 0; I < L->getNumOperands(); ++I)
-    if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I)))
+    if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I), SeenL, SeenR))
       return Res;
   return 0;
 }
@@ -254,7 +288,10 @@ int FunctionComparator::cmpInstMetadata(Instruction const *L,
     auto const [KeyR, MR] = MDR[I];
     if (int Res = cmpNumbers(KeyL, KeyR))
       return Res;
-    if (int Res = cmpMDNode(ML, MR))
+
+    SmallPtrSet<const MDNode *, 4> SeenL;
+    SmallPtrSet<const MDNode *, 4> SeenR;
+    if (int Res = cmpMDNode(ML, MR, SeenL, SeenR))
       return Res;
   }
   return 0;
@@ -721,8 +758,11 @@ int FunctionComparator::cmpOperations(const Instruction *L,
       if (int Res = cmpNumbers(CI->getTailCallKind(),
                                cast<CallInst>(R)->getTailCallKind()))
         return Res;
+
+    SmallPtrSet<const MDNode *, 4> SeenL;
+    SmallPtrSet<const MDNode *, 4> SeenR;
     return cmpMDNode(L->getMetadata(LLVMContext::MD_range),
-                     R->getMetadata(LLVMContext::MD_range));
+                     R->getMetadata(LLVMContext::MD_range), SeenL, SeenR);
   }
   if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
     ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -895,11 +935,10 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const {
   const MetadataAsValue *MetadataValueL = dyn_cast<MetadataAsValue>(L);
   const MetadataAsValue *MetadataValueR = dyn_cast<MetadataAsValue>(R);
   if (MetadataValueL && MetadataValueR) {
-    if (MetadataValueL == MetadataValueR)
-      return 0;
-
+    SmallPtrSet<const MDNode *, 4> SeenL;
+    SmallPtrSet<const MDNode *, 4> SeenR;
     return cmpMetadata(MetadataValueL->getMetadata(),
-                       MetadataValueR->getMetadata());
+                       MetadataValueR->getMetadata(), SeenL, SeenR);
   }
 
   if (MetadataValueL)
diff --git a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
index 28263741f2cde..10b4c691d314a 100644
--- a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
+++ b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
@@ -1,7 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 5
 ; RUN: opt -p mergefunc -S %s | FileCheck %s
 
-; FIXME: Should not be merged with @call_mdtuple_arg_not_equal_2.
 define i64 @call_mdtuple_arg_not_equal_1() {
   %r = call i64 @llvm.read_volatile_register.i64(metadata !0)
   ret i64 %r
@@ -22,7 +21,6 @@ define i64 @call_mdtuple_arg_with_cycle_equal_2() {
   ret i64 %r
 }
 
-; FIXME: Should not be merged with @call_mdtuple_arg_with_cycle_not_equal_2.
 define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
   %r = call i64 @llvm.read_volatile_register.i64(metadata !3)
   ret i64 %r
@@ -55,48 +53,50 @@ declare i64 @llvm.read_volatile_register.i64(metadata)
 
 !5 = !{!"foo", i64 10}
 !6 = !{!"foo", i64 10}
+; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+;
+; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
+; CHECK-NEXT:    ret i64 [[TMP1]]
+;
+;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_1() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_2() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_2() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META4:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_1() {
-; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
+; CHECK-NEXT:    [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META5:![0-9]+]])
 ; CHECK-NEXT:    ret i64 [[R]]
 ;
 ;
-; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
-; CHECK-NEXT:    ret i64 [[TMP1]]
-;
-;
-; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
-; CHECK-NEXT:    ret i64 [[TMP1]]
-;
-;
 ; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_2() {
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
 ; CHECK-NEXT:    ret i64 [[TMP1]]
 ;
 ;.
-; CHECK: [[META0]] = distinct !{[[META0]], !"foo"}
-; CHECK: [[META1]] = distinct !{[[META1]], !"foo"}
-; CHECK: [[META2]] = distinct !{[[META2]], !"bar"}
-; CHECK: [[META3]] = !{!"foo", i64 10}
+; CHECK: [[META0]] = !{!"foo"}
+; CHECK: [[META1]] = !{!"bar"}
+; CHECK: [[META2]] = distinct !{[[META2]], !"foo"}
+; CHECK: [[META3]] = distinct !{[[META3]], !"foo"}
+; CHECK: [[META4]] = distinct !{[[META4]], !"bar"}
+; CHECK: [[META5]] = !{!"foo", i64 10}
 ;.



More information about the llvm-commits mailing list