[llvm] [SimplifyLibCalls] Fix memchr misoptimization (PR #106121)
Sergei Barannikov via llvm-commits
llvm-commits at lists.llvm.org
Mon Aug 26 12:18:16 PDT 2024
https://github.com/s-barannikov created https://github.com/llvm/llvm-project/pull/106121
The `ch` argument of memchr should be truncated to `unsigned char` before it is used in comparisons. This didn't happen on all code paths. The following program was miscompiled at -O1 and higher:
```C++
#include <cstring>
#include <iostream>
// Minimal reproducer for the strchr/memchr miscompile.
// Where plain `char` is signed, '\x81' is a negative value and stays negative
// when converted to the `int` parameter of strchr, so the search has to
// compare only its low 8 bits (i.e. the value converted to `unsigned char`).
char ch = '\x81';
int main() {
// The byte 0x81 is present in the literal, so `found` is expected to be true.
bool found = std::strchr("\x80\x81\x82", ch) != nullptr;
std::cout << std::boolalpha << found << '\n';
}
```
>From 1a9bb34adef4263368a4954ab26fc4d9f4eb4536 Mon Sep 17 00:00:00 2001
From: Sergei Barannikov <barannikov88 at gmail.com>
Date: Mon, 26 Aug 2024 22:14:33 +0300
Subject: [PATCH] [SimplifyLibCalls] Fix memchr misoptimization
The `ch` argument of memchr should be truncated to `unsigned char`
before using it in comparisons. This didn't happen on all code paths.
The following program miscompiled at -O1 and higher:
```C++
#include <cstring>
#include <iostream>
// Minimal reproducer for the strchr/memchr miscompile.
// Where plain `char` is signed, '\x81' is a negative value and stays negative
// when converted to the `int` parameter of strchr, so the search has to
// compare only its low 8 bits (i.e. the value converted to `unsigned char`).
char ch = '\x81';
int main() {
// The byte 0x81 is present in the literal, so `found` is expected to be true.
bool found = std::strchr("\x80\x81\x82", ch) != nullptr;
std::cout << std::boolalpha << found << '\n';
}
```
---
.../lib/Transforms/Utils/SimplifyLibCalls.cpp | 6 +-
llvm/test/Transforms/InstCombine/memchr-7.ll | 55 ++++++++++---------
2 files changed, 34 insertions(+), 27 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index fb2efe581ac6bb..1e6dc88ed93532 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -1454,10 +1454,12 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
if (NonContRanges > 2)
return nullptr;
+ // Slice off the character's high end bits.
+ CharVal = B.CreateTrunc(CharVal, B.getInt8Ty());
+
SmallVector<Value *> CharCompares;
for (unsigned char C : SortedStr)
- CharCompares.push_back(
- B.CreateICmpEQ(CharVal, ConstantInt::get(CharVal->getType(), C)));
+ CharCompares.push_back(B.CreateICmpEQ(CharVal, B.getInt8(C)));
return B.CreateIntToPtr(B.CreateOr(CharCompares), CI->getType());
}
diff --git a/llvm/test/Transforms/InstCombine/memchr-7.ll b/llvm/test/Transforms/InstCombine/memchr-7.ll
index 0b364cce656d77..61f1093279f834 100644
--- a/llvm/test/Transforms/InstCombine/memchr-7.ll
+++ b/llvm/test/Transforms/InstCombine/memchr-7.ll
@@ -12,11 +12,12 @@ declare ptr @memchr(ptr, i32, i64)
define zeroext i1 @strchr_to_memchr_n_equals_len(i32 %c) {
; CHECK-LABEL: @strchr_to_memchr_n_equals_len(
-; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[C:%.*]], 0
-; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[C]], -97
-; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 26
-; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP1]], [[TMP3]]
-; CHECK-NEXT: ret i1 [[TMP4]]
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = add i8 [[TMP1]], -97
+; CHECK-NEXT: [[TMP4:%.*]] = icmp ult i8 [[TMP3]], 26
+; CHECK-NEXT: [[TMP5:%.*]] = or i1 [[TMP2]], [[TMP4]]
+; CHECK-NEXT: ret i1 [[TMP5]]
;
%call = tail call ptr @strchr(ptr nonnull dereferenceable(27) @.str, i32 %c)
%cmp = icmp ne ptr %call, null
@@ -38,9 +39,10 @@ define zeroext i1 @memchr_n_equals_len(i32 %c) {
define zeroext i1 @memchr_n_less_than_len(i32 %c) {
; CHECK-LABEL: @memchr_n_less_than_len(
-; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[C:%.*]], -97
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 15
-; CHECK-NEXT: ret i1 [[TMP2]]
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = add i8 [[TMP1]], -97
+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i8 [[TMP2]], 15
+; CHECK-NEXT: ret i1 [[TMP3]]
;
%call = tail call ptr @memchr(ptr @.str, i32 %c, i64 15)
%cmp = icmp ne ptr %call, null
@@ -50,11 +52,12 @@ define zeroext i1 @memchr_n_less_than_len(i32 %c) {
define zeroext i1 @memchr_n_more_than_len(i32 %c) {
; CHECK-LABEL: @memchr_n_more_than_len(
-; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i32 [[C:%.*]], 0
-; CHECK-NEXT: [[TMP2:%.*]] = add i32 [[C]], -97
-; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 26
-; CHECK-NEXT: [[TMP4:%.*]] = or i1 [[TMP1]], [[TMP3]]
-; CHECK-NEXT: ret i1 [[TMP4]]
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[TMP1]], 0
+; CHECK-NEXT: [[TMP3:%.*]] = add i8 [[TMP1]], -97
+; CHECK-NEXT: [[TMP4:%.*]] = icmp ult i8 [[TMP3]], 26
+; CHECK-NEXT: [[TMP5:%.*]] = or i1 [[TMP2]], [[TMP4]]
+; CHECK-NEXT: ret i1 [[TMP5]]
;
%call = tail call ptr @memchr(ptr @.str, i32 %c, i64 30)
%cmp = icmp ne ptr %call, null
@@ -114,12 +117,13 @@ define zeroext i1 @memchr_n_equals_len2_minsize(i32 %c) minsize {
; Positive test - 2 non-contiguous ranges
define zeroext i1 @strchr_to_memchr_2_non_cont_ranges(i32 %c) {
; CHECK-LABEL: @strchr_to_memchr_2_non_cont_ranges(
-; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[C:%.*]], -97
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 6
-; CHECK-NEXT: [[TMP3:%.*]] = add i32 [[C]], -109
-; CHECK-NEXT: [[TMP4:%.*]] = icmp ult i32 [[TMP3]], 3
-; CHECK-NEXT: [[TMP5:%.*]] = or i1 [[TMP2]], [[TMP4]]
-; CHECK-NEXT: ret i1 [[TMP5]]
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = add i8 [[TMP1]], -97
+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i8 [[TMP2]], 6
+; CHECK-NEXT: [[TMP4:%.*]] = add i8 [[TMP1]], -109
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ult i8 [[TMP4]], 3
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP3]], [[TMP5]]
+; CHECK-NEXT: ret i1 [[TMP6]]
;
%call = tail call ptr @memchr(ptr @.str.2, i32 %c, i64 9)
%cmp = icmp ne ptr %call, null
@@ -129,12 +133,13 @@ define zeroext i1 @strchr_to_memchr_2_non_cont_ranges(i32 %c) {
; Positive test - 2 non-contiguous ranges with char duplication
define zeroext i1 @strchr_to_memchr_2_non_cont_ranges_char_dup(i32 %c) {
; CHECK-LABEL: @strchr_to_memchr_2_non_cont_ranges_char_dup(
-; CHECK-NEXT: [[TMP1:%.*]] = add i32 [[C:%.*]], -97
-; CHECK-NEXT: [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 3
-; CHECK-NEXT: [[TMP3:%.*]] = add i32 [[C]], -109
-; CHECK-NEXT: [[TMP4:%.*]] = icmp ult i32 [[TMP3]], 2
-; CHECK-NEXT: [[TMP5:%.*]] = or i1 [[TMP2]], [[TMP4]]
-; CHECK-NEXT: ret i1 [[TMP5]]
+; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT: [[TMP2:%.*]] = add i8 [[TMP1]], -97
+; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i8 [[TMP2]], 3
+; CHECK-NEXT: [[TMP4:%.*]] = add i8 [[TMP1]], -109
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ult i8 [[TMP4]], 2
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP3]], [[TMP5]]
+; CHECK-NEXT: ret i1 [[TMP6]]
;
%call = tail call ptr @memchr(ptr @.str.4, i32 %c, i64 6)
%cmp = icmp ne ptr %call, null
More information about the llvm-commits
mailing list