[llvm] [FuncComp] Compare MDNodes in cmpMetadata using cmpMDNode. (PR #128878)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Wed Feb 26 06:10:34 PST 2025
https://github.com/fhahn created https://github.com/llvm/llvm-project/pull/128878
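Use cmpMDNode in cmpMetadata to structurally compare MDNodes passed as metadata arguments. This fixes a mis-compile in which cmpMetadata incorrectly returned 0 (equal) for different nodes, allowing MergeFunctions to merge functions whose calls pass different metadata arguments.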
Note that metadata can contain cycles, so the comparison must not recurse forever: cmpMDNode records the nodes it has already visited on each side and stops recursing when it reaches one of them again.
From 87baa8f733f003f19f1915ef7f9738a4a1295120 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Wed, 26 Feb 2025 13:53:22 +0000
Subject: [PATCH] [FuncComp] Compare MDNodes in cmpMetadata using cmpMDNode.
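Use cmpMDNode in cmpMetadata to structurally compare MDNodes passed as
metadata arguments. This fixes a mis-compile in which cmpMetadata
incorrectly returned 0 (equal) for different nodes, allowing
MergeFunctions to merge functions whose calls pass different metadata
arguments.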
Note that metadata can contain cycles, so the comparison must not
recurse forever: cmpMDNode records the nodes it has already visited on
each side and stops recursing when it reaches one of them again.
---
.../Transforms/Utils/FunctionComparator.h | 8 ++-
.../Transforms/Utils/FunctionComparator.cpp | 69 +++++++++++++++----
.../MergeFunc/metadata-call-arguments.ll | 42 +++++------
3 files changed, 81 insertions(+), 38 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
index 19c5f7449f23e..47871719b497a 100644
--- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
+++ b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
@@ -328,8 +328,12 @@ class FunctionComparator {
int cmpOrderings(AtomicOrdering L, AtomicOrdering R) const;
int cmpInlineAsm(const InlineAsm *L, const InlineAsm *R) const;
int cmpAttrs(const AttributeList L, const AttributeList R) const;
- int cmpMDNode(const MDNode *L, const MDNode *R) const;
- int cmpMetadata(const Metadata *L, const Metadata *R) const;
+ int cmpMDNode(const MDNode *L, const MDNode *R,
+ SmallPtrSetImpl<const MDNode *> &SeenL,
+ SmallPtrSetImpl<const MDNode *> &SeenR) const;
+ int cmpMetadata(const Metadata *L, const Metadata *R,
+ SmallPtrSetImpl<const MDNode *> &SeenL,
+ SmallPtrSetImpl<const MDNode *> &SeenR) const;
int cmpInstMetadata(Instruction const *L, Instruction const *R) const;
int cmpOperandBundlesSchema(const CallBase &LCS, const CallBase &RCS) const;
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
index 6d4026e8209de..8176518d82011 100644
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -185,8 +185,13 @@ int FunctionComparator::cmpAttrs(const AttributeList L,
return 0;
}
-int FunctionComparator::cmpMetadata(const Metadata *L,
- const Metadata *R) const {
+int FunctionComparator::cmpMetadata(
+ const Metadata *L, const Metadata *R,
+ SmallPtrSetImpl<const MDNode *> &SeenL,
+ SmallPtrSetImpl<const MDNode *> &SeenR) const {
+ if (L == R)
+ return 0;
+
// TODO: the following routine coerce the metadata contents into constants
// or MDStrings before comparison.
// It ignores any other cases, so that the metadata nodes are considered
@@ -207,22 +212,51 @@ int FunctionComparator::cmpMetadata(const Metadata *L,
auto *CL = dyn_cast<ConstantAsMetadata>(L);
auto *CR = dyn_cast<ConstantAsMetadata>(R);
- if (CL == CR)
- return 0;
- if (!CL)
+ if (CL && CR) {
+ if (!CL)
+ return -1;
+ if (!CR)
+ return 1;
+ return cmpConstants(CL->getValue(), CR->getValue());
+ }
+
+ auto *NodeL = dyn_cast<const MDNode>(L);
+ auto *NodeR = dyn_cast<const MDNode>(R);
+ if (NodeL && NodeR)
+ return cmpMDNode(NodeL, NodeR, SeenL, SeenR);
+
+ if (NodeR)
return -1;
- if (!CR)
+
+ if (NodeL)
return 1;
- return cmpConstants(CL->getValue(), CR->getValue());
+
+ assert(false);
+
+ return 0;
}
-int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
+int FunctionComparator::cmpMDNode(
+ const MDNode *L, const MDNode *R, SmallPtrSetImpl<const MDNode *> &SeenL,
+ SmallPtrSetImpl<const MDNode *> &SeenR) const {
if (L == R)
return 0;
if (!L)
return -1;
if (!R)
return 1;
+
+ // Check if we already checked either L or R previosuly. This can be the case
+ // for metadata nodes with cycles.
+ bool AlreadySeenL = !SeenL.insert(L).second;
+ bool AlreadySeenR = !SeenR.insert(R).second;
+ if (AlreadySeenL && AlreadySeenR)
+ return 0;
+ if (AlreadySeenR)
+ return -1;
+ if (AlreadySeenL)
+ return 1;
+
// TODO: Note that as this is metadata, it is possible to drop and/or merge
// this data when considering functions to merge. Thus this comparison would
// return 0 (i.e. equivalent), but merging would become more complicated
@@ -232,7 +266,7 @@ int FunctionComparator::cmpMDNode(const MDNode *L, const MDNode *R) const {
if (int Res = cmpNumbers(L->getNumOperands(), R->getNumOperands()))
return Res;
for (size_t I = 0; I < L->getNumOperands(); ++I)
- if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I)))
+ if (int Res = cmpMetadata(L->getOperand(I), R->getOperand(I), SeenL, SeenR))
return Res;
return 0;
}
@@ -254,7 +288,10 @@ int FunctionComparator::cmpInstMetadata(Instruction const *L,
auto const [KeyR, MR] = MDR[I];
if (int Res = cmpNumbers(KeyL, KeyR))
return Res;
- if (int Res = cmpMDNode(ML, MR))
+
+ SmallPtrSet<const MDNode *, 4> SeenL;
+ SmallPtrSet<const MDNode *, 4> SeenR;
+ if (int Res = cmpMDNode(ML, MR, SeenL, SeenR))
return Res;
}
return 0;
@@ -721,8 +758,11 @@ int FunctionComparator::cmpOperations(const Instruction *L,
if (int Res = cmpNumbers(CI->getTailCallKind(),
cast<CallInst>(R)->getTailCallKind()))
return Res;
+
+ SmallPtrSet<const MDNode *, 4> SeenL;
+ SmallPtrSet<const MDNode *, 4> SeenR;
return cmpMDNode(L->getMetadata(LLVMContext::MD_range),
- R->getMetadata(LLVMContext::MD_range));
+ R->getMetadata(LLVMContext::MD_range), SeenL, SeenR);
}
if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(L)) {
ArrayRef<unsigned> LIndices = IVI->getIndices();
@@ -895,11 +935,10 @@ int FunctionComparator::cmpValues(const Value *L, const Value *R) const {
const MetadataAsValue *MetadataValueL = dyn_cast<MetadataAsValue>(L);
const MetadataAsValue *MetadataValueR = dyn_cast<MetadataAsValue>(R);
if (MetadataValueL && MetadataValueR) {
- if (MetadataValueL == MetadataValueR)
- return 0;
-
+ SmallPtrSet<const MDNode *, 4> SeenL;
+ SmallPtrSet<const MDNode *, 4> SeenR;
return cmpMetadata(MetadataValueL->getMetadata(),
- MetadataValueR->getMetadata());
+ MetadataValueR->getMetadata(), SeenL, SeenR);
}
if (MetadataValueL)
diff --git a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
index 28263741f2cde..10b4c691d314a 100644
--- a/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
+++ b/llvm/test/Transforms/MergeFunc/metadata-call-arguments.ll
@@ -1,7 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --include-generated-funcs --version 5
; RUN: opt -p mergefunc -S %s | FileCheck %s
-; FIXME: Should not be merged with @call_mdtuple_arg_not_equal_2.
define i64 @call_mdtuple_arg_not_equal_1() {
%r = call i64 @llvm.read_volatile_register.i64(metadata !0)
ret i64 %r
@@ -22,7 +21,6 @@ define i64 @call_mdtuple_arg_with_cycle_equal_2() {
ret i64 %r
}
-; FIXME: Should not be merged with @call_mdtuple_arg_with_cycle_not_equal_2.
define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
%r = call i64 @llvm.read_volatile_register.i64(metadata !3)
ret i64 %r
@@ -55,48 +53,50 @@ declare i64 @llvm.read_volatile_register.i64(metadata)
!5 = !{!"foo", i64 10}
!6 = !{!"foo", i64 10}
+; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
+; CHECK-NEXT: ret i64 [[TMP1]]
+;
+;
+; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
+; CHECK-NEXT: ret i64 [[TMP1]]
+;
+;
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_1() {
-; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META0:![0-9]+]])
+; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
; CHECK-NEXT: ret i64 [[R]]
;
;
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_equal_2() {
-; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1:![0-9]+]])
+; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
; CHECK-NEXT: ret i64 [[R]]
;
;
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_1() {
-; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META1]])
+; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3]])
; CHECK-NEXT: ret i64 [[R]]
;
;
; CHECK-LABEL: define i64 @call_mdtuple_arg_with_cycle_not_equal_2() {
-; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META2:![0-9]+]])
+; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META4:![0-9]+]])
; CHECK-NEXT: ret i64 [[R]]
;
;
; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_1() {
-; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META3:![0-9]+]])
+; CHECK-NEXT: [[R:%.*]] = call i64 @llvm.read_volatile_register.i64(metadata [[META5:![0-9]+]])
; CHECK-NEXT: ret i64 [[R]]
;
;
-; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_2() {
-; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
-; CHECK-NEXT: ret i64 [[TMP1]]
-;
-;
-; CHECK-LABEL: define i64 @call_mdtuple_arg_not_equal_1() {
-; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
-; CHECK-NEXT: ret i64 [[TMP1]]
-;
-;
; CHECK-LABEL: define i64 @call_mdtuple_arg_equal_2() {
; CHECK-NEXT: [[TMP1:%.*]] = tail call i64 @call_mdtuple_arg_equal_1()
; CHECK-NEXT: ret i64 [[TMP1]]
;
;.
-; CHECK: [[META0]] = distinct !{[[META0]], !"foo"}
-; CHECK: [[META1]] = distinct !{[[META1]], !"foo"}
-; CHECK: [[META2]] = distinct !{[[META2]], !"bar"}
-; CHECK: [[META3]] = !{!"foo", i64 10}
+; CHECK: [[META0]] = !{!"foo"}
+; CHECK: [[META1]] = !{!"bar"}
+; CHECK: [[META2]] = distinct !{[[META2]], !"foo"}
+; CHECK: [[META3]] = distinct !{[[META3]], !"foo"}
+; CHECK: [[META4]] = distinct !{[[META4]], !"bar"}
+; CHECK: [[META5]] = !{!"foo", i64 10}
;.
More information about the llvm-commits mailing list