[llvm] [SampleProfileLoader] Fix integer overflow in generateMDProfMetadata (PR #90217)

via llvm-commits llvm-commits at lists.llvm.org
Fri Apr 26 07:24:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Nabeel Omer (omern1)

<details>
<summary>Changes</summary>

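This patch fixes an integer overflow in the SampleProfileLoader pass.
The issue occurs when Profi isn't being used and a branch weight has been
saturated to the uint32_t maximum: the subsequent "+1" adjustment (applied
to avoid 0 weights) wraps the saturated weight around to 0.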

This patch also adds a newline to a debug message to make it more
readable.

---
Full diff: https://github.com/llvm/llvm-project/pull/90217.diff


3 Files Affected:

- (modified) llvm/lib/Transforms/IPO/SampleProfile.cpp (+5-2) 
- (added) llvm/test/Transforms/SampleProfile/Inputs/overflow.proftext (+2) 
- (added) llvm/test/Transforms/SampleProfile/overflow.ll (+70) 


``````````diff
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index 0b3a6931e779b6..7018e3feded8ab 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -1713,13 +1713,16 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
       // if needed. Sample counts in profiles are 64-bit unsigned values,
       // but internally branch weights are expressed as 32-bit values.
       if (Weight > std::numeric_limits<uint32_t>::max()) {
-        LLVM_DEBUG(dbgs() << " (saturated due to uint32_t overflow)");
+        LLVM_DEBUG(dbgs() << " (saturated due to uint32_t overflow)\n");
         Weight = std::numeric_limits<uint32_t>::max();
       }
       if (!SampleProfileUseProfi) {
         // Weight is added by one to avoid propagation errors introduced by
         // 0 weights.
-        Weights.push_back(static_cast<uint32_t>(Weight + 1));
+        if (Weight != std::numeric_limits<uint32_t>::max())
+          Weight += 1;
+
+        Weights.push_back(static_cast<uint32_t>(Weight));
       } else {
         // Profi creates proper weights that do not require "+1" adjustments but
         // we evenly split the weight among branches with the same destination.
diff --git a/llvm/test/Transforms/SampleProfile/Inputs/overflow.proftext b/llvm/test/Transforms/SampleProfile/Inputs/overflow.proftext
new file mode 100644
index 00000000000000..753294a49e99c5
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/Inputs/overflow.proftext
@@ -0,0 +1,2 @@
+_Z3testi:29600000000:29600000000
+ 5: 29600000000
diff --git a/llvm/test/Transforms/SampleProfile/overflow.ll b/llvm/test/Transforms/SampleProfile/overflow.ll
new file mode 100644
index 00000000000000..76c7ce874093a1
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/overflow.ll
@@ -0,0 +1,70 @@
+; RUN: opt < %s -passes=sample-profile -sample-profile-file=%S/Inputs/overflow.proftext | opt -passes='print<branch-prob>' -disable-output 2>&1 | FileCheck %s
+
+; Original Source:
+; int sqrt(int);
+; int lol(int i) {
+;    if (i == 5) {
+;        return 42;
+;    }
+;    else {
+;        return sqrt(i);
+;    }
+;}
+
+define dso_local noundef i32 @_Z3testi(i32 noundef %i) local_unnamed_addr #0 !dbg !10 {
+entry:
+  tail call void @llvm.dbg.value(metadata i32 %i, metadata !16, metadata !DIExpression()), !dbg !17
+  %cmp = icmp eq i32 %i, 5, !dbg !18
+  br i1 %cmp, label %return, label %if.else, !dbg !20
+
+if.else:                                          ; preds = %entry
+  %call = tail call noundef i32 @_Z4sqrti(i32 noundef %i), !dbg !21
+  br label %return, !dbg !23
+
+return:                                           ; preds = %entry, %if.else
+  %retval.0 = phi i32 [ %call, %if.else ], [ 42, %entry ], !dbg !24
+  ret i32 %retval.0, !dbg !25
+}
+
+declare !dbg !26 noundef i32 @_Z4sqrti(i32 noundef)
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare void @llvm.dbg.value(metadata, metadata, metadata)
+
+attributes #0 = { "use-sample-profile" }
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!2, !3, !4, !5, !6, !7, !8}
+!llvm.ident = !{!9}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang", isOptimized: true, runtimeVersion: 0, emissionKind: FullDebug, splitDebugInlining: false, nameTableKind: None)
+!1 = !DIFile(filename: "test.cpp", directory: "/", checksumkind: CSK_MD5, checksum: "cb38d90153a7ebdd6ecf3058eb0524c7")
+!2 = !{i32 7, !"Dwarf Version", i32 5}
+!3 = !{i32 2, !"Debug Info Version", i32 3}
+!4 = !{i32 1, !"wchar_size", i32 4}
+!5 = !{i32 8, !"PIC Level", i32 2}
+!6 = !{i32 7, !"PIE Level", i32 2}
+!7 = !{i32 7, !"uwtable", i32 2}
+!8 = !{i32 7, !"debug-info-assignment-tracking", i1 true}
+!9 = !{!"clang"}
+!10 = distinct !DISubprogram(name: "test", linkageName: "_Z3loli", scope: !11, file: !11, line: 3, type: !12, scopeLine: 3, flags: DIFlagPrototyped | DIFlagAllCallsDescribed, spFlags: DISPFlagDefinition | DISPFlagOptimized, unit: !0, retainedNodes: !15)
+!11 = !DIFile(filename: "./test.cpp", directory: "/", checksumkind: CSK_MD5, checksum: "cb38d90153a7ebdd6ecf3058eb0524c7")
+!12 = !DISubroutineType(types: !13)
+!13 = !{!14, !14}
+!14 = !DIBasicType(name: "int", size: 32, encoding: DW_ATE_signed)
+!15 = !{!16}
+!16 = !DILocalVariable(name: "i", arg: 1, scope: !10, file: !11, line: 3, type: !14)
+!17 = !DILocation(line: 0, scope: !10)
+!18 = !DILocation(line: 4, column: 11, scope: !19)
+!19 = distinct !DILexicalBlock(scope: !10, file: !11, line: 4, column: 9)
+!20 = !DILocation(line: 4, column: 9, scope: !10)
+!21 = !DILocation(line: 8, column: 16, scope: !22)
+!22 = distinct !DILexicalBlock(scope: !19, file: !11, line: 7, column: 10)
+!23 = !DILocation(line: 8, column: 9, scope: !22)
+!24 = !DILocation(line: 0, scope: !19)
+!25 = !DILocation(line: 10, column: 1, scope: !10)
+!26 = !DISubprogram(name: "sqrt", linkageName: "_Z4sqrti", scope: !11, file: !11, line: 1, type: !12, flags: DIFlagPrototyped, spFlags: DISPFlagOptimized)
+
+; CHECK: edge %entry -> %return probability is 0x00000000 / 0x80000000 = 0.00%
+; CHECK: edge %entry -> %if.else probability is 0x80000000 / 0x80000000 = 100.00% [HOT edge]
+; CHECK: edge %if.else -> %return probability is 0x80000000 / 0x80000000 = 100.00% [HOT edge]

``````````

</details>


https://github.com/llvm/llvm-project/pull/90217


More information about the llvm-commits mailing list