[compiler-rt] [DFSan] Fix sscanf checking that ordinary characters match. (PR #95333)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 12 16:47:24 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-compiler-rt-sanitizer
Author: Andrew Browne (browneee)
<details>
<summary>Changes</summary>
Fixes: #94769
---
Patch is 23.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95333.diff
2 Files Affected:
- (modified) compiler-rt/lib/dfsan/dfsan_custom.cpp (+291-231)
- (modified) compiler-rt/test/dfsan/sscanf.c (+102-9)
``````````diff
diff --git a/compiler-rt/lib/dfsan/dfsan_custom.cpp b/compiler-rt/lib/dfsan/dfsan_custom.cpp
index af3c1f4d1673c..050f5232c0408 100644
--- a/compiler-rt/lib/dfsan/dfsan_custom.cpp
+++ b/compiler-rt/lib/dfsan/dfsan_custom.cpp
@@ -2198,50 +2198,12 @@ struct Formatter {
return retval;
}
- int scan() {
- char *tmp_fmt = build_format_string(true);
- int read_count = 0;
- int retval = sscanf(str + str_off, tmp_fmt, &read_count);
- if (retval > 0) {
- if (-1 == num_scanned)
- num_scanned = 0;
- num_scanned += retval;
- }
- free(tmp_fmt);
- return read_count;
- }
-
- template <typename T>
- int scan(T arg) {
- char *tmp_fmt = build_format_string(true);
- int read_count = 0;
- int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
- if (retval > 0) {
- if (-1 == num_scanned)
- num_scanned = 0;
- num_scanned += retval;
- }
- free(tmp_fmt);
- return read_count;
- }
-
- // with_n -> toggles adding %n on/off; off by default
- char *build_format_string(bool with_n = false) {
+ char *build_format_string() {
size_t fmt_size = fmt_cur - fmt_start + 1;
- size_t add_size = 0;
- if (with_n)
- add_size = 2;
- char *new_fmt = (char *)malloc(fmt_size + 1 + add_size);
+ char *new_fmt = (char *)malloc(fmt_size + 1);
assert(new_fmt);
internal_memcpy(new_fmt, fmt_start, fmt_size);
- if (!with_n) {
- new_fmt[fmt_size] = '\0';
- } else {
- new_fmt[fmt_size] = '%';
- new_fmt[fmt_size + 1] = 'n';
- new_fmt[fmt_size + 2] = '\0';
- }
-
+ new_fmt[fmt_size] = '\0';
return new_fmt;
}
@@ -2467,6 +2429,102 @@ static int format_buffer(char *str, size_t size, const char *fmt,
return formatter.str_off;
}
+// Scans one chunk of a scanf format string: either a run of ordinary
+// characters or a single conversion directive (e.g., '%.3f').
+struct Scanner {
+  Scanner(char *str_, const char *fmt_, size_t size_)
+      : str(str_),
+        str_off(0),
+        size(size_),
+        fmt_start(fmt_),
+        fmt_cur(fmt_),
+        width(-1),
+        num_scanned(0),
+        skip(false) {}
+
+  // Consumes a chunk of ordinary characters, advancing str_off past the
+  // matched input. Returns the number of input characters matched, or -1 if
+  // the match failed (sscanf only stores the trailing %n on a full match).
+  // A whitespace character in the format matches any amount of whitespace.
+  int check_match_ordinary() {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = -1;
+    sscanf(str + str_off, tmp_fmt, &read_count);
+    free(tmp_fmt);
+    if (read_count > 0) {
+      str_off += read_count;
+    }
+    return read_count;
+  }
+
+  int scan() {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, &read_count);
+    free(tmp_fmt);
+    if (retval > 0) {
+      num_scanned += retval;
+    }
+    return read_count;
+  }
+
+  template <typename T>
+  int scan(T arg) {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
+    free(tmp_fmt);
+    if (retval > 0) {
+      num_scanned += retval;
+    }
+    return read_count;
+  }
+
+  // Copies the current format chunk with "%n" appended; caller must free().
+  char *build_format_string_with_n() {
+    size_t fmt_size = fmt_cur - fmt_start + 1;
+    // +2 for %n, +1 for \0
+    char *new_fmt = (char *)malloc(fmt_size + 2 + 1);
+    assert(new_fmt);
+    internal_memcpy(new_fmt, fmt_start, fmt_size);
+    new_fmt[fmt_size] = '%';
+    new_fmt[fmt_size + 1] = 'n';
+    new_fmt[fmt_size + 2] = '\0';
+    return new_fmt;
+  }
+
+  char *str_cur() { return str + str_off; }
+
+  size_t num_written_bytes(int retval) {
+    if (retval < 0) {
+      return 0;
+    }
+
+    size_t num_avail = str_off < size ? size - str_off : 0;
+    if (num_avail == 0) {
+      return 0;
+    }
+
+    size_t num_written = retval;
+    // NOTE(review): copied from Formatter's snprintf truncation handling;
+    // confirm this adjustment is meaningful for sscanf read counts.
+    if (num_written >= num_avail) {
+      num_written -= num_avail;
+    }
+
+    return num_written;
+  }
+
+  char *str;
+  size_t str_off;
+  size_t size;
+  const char *fmt_start;
+  const char *fmt_cur;
+  int width;
+  int num_scanned;
+  bool skip;
+};
+
// This function is an inverse of format_buffer: we take the input buffer,
// scan it in search for format strings and store the results in the varargs.
// The labels are propagated from the input buffer to the varargs.
@@ -2474,220 +2532,222 @@ static int scan_buffer(char *str, size_t size, const char *fmt,
dfsan_label *va_labels, dfsan_label *ret_label,
dfsan_origin *str_origin, dfsan_origin *ret_origin,
va_list ap) {
- Formatter formatter(str, fmt, size);
- while (*formatter.fmt_cur) {
- formatter.fmt_start = formatter.fmt_cur;
- formatter.width = -1;
- formatter.skip = false;
+ Scanner scanner(str, fmt, size);
+ while (*scanner.fmt_cur) {
+ scanner.fmt_start = scanner.fmt_cur;
+ scanner.width = -1;
+ scanner.skip = false;
int read_count = 0;
void *dst_ptr = 0;
size_t write_size = 0;
- if (*formatter.fmt_cur != '%') {
- // Ordinary character. Consume all the characters until a '%' or the end
- // of the string.
- for (; *(formatter.fmt_cur + 1) && *(formatter.fmt_cur + 1) != '%';
- ++formatter.fmt_cur) {
+ if (*scanner.fmt_cur != '%') {
+ // Ordinary character and spaces.
+ // Consume all the characters until a '%' or the end of the string.
+ for (; *(scanner.fmt_cur + 1) && *(scanner.fmt_cur + 1) != '%';
+ ++scanner.fmt_cur) {
+ }
+ if (scanner.check_match_ordinary() < 0) {
+ // The ordinary characters did not match.
+ break;
}
- read_count = formatter.scan();
- dfsan_set_label(0, formatter.str_cur(),
- formatter.num_written_bytes(read_count));
} else {
// Conversion directive. Consume all the characters until a conversion
// specifier or the end of the string.
bool end_fmt = false;
- for (; *formatter.fmt_cur && !end_fmt;) {
- switch (*++formatter.fmt_cur) {
- case 'd':
- case 'i':
- case 'o':
- case 'u':
- case 'x':
- case 'X':
- if (formatter.skip) {
- read_count = formatter.scan();
- } else {
- switch (*(formatter.fmt_cur - 1)) {
- case 'h':
- // Also covers the 'hh' case (since the size of the arg is still
- // an int).
- dst_ptr = va_arg(ap, int *);
- read_count = formatter.scan((int *)dst_ptr);
- write_size = sizeof(int);
- break;
- case 'l':
- if (formatter.fmt_cur - formatter.fmt_start >= 2 &&
- *(formatter.fmt_cur - 2) == 'l') {
- dst_ptr = va_arg(ap, long long int *);
- read_count = formatter.scan((long long int *)dst_ptr);
- write_size = sizeof(long long int);
- } else {
- dst_ptr = va_arg(ap, long int *);
- read_count = formatter.scan((long int *)dst_ptr);
- write_size = sizeof(long int);
+ for (; *scanner.fmt_cur && !end_fmt;) {
+ switch (*++scanner.fmt_cur) {
+ case 'd':
+ case 'i':
+ case 'o':
+ case 'u':
+ case 'x':
+ case 'X':
+ if (scanner.skip) {
+ read_count = scanner.scan();
+ } else {
+ switch (*(scanner.fmt_cur - 1)) {
+ case 'h':
+ // Also covers the 'hh' case (since the size of the arg is
+ // still an int).
+ dst_ptr = va_arg(ap, int *);
+ read_count = scanner.scan((int *)dst_ptr);
+ write_size = sizeof(int);
+ break;
+ case 'l':
+ if (scanner.fmt_cur - scanner.fmt_start >= 2 &&
+ *(scanner.fmt_cur - 2) == 'l') {
+ dst_ptr = va_arg(ap, long long int *);
+ read_count = scanner.scan((long long int *)dst_ptr);
+ write_size = sizeof(long long int);
+ } else {
+ dst_ptr = va_arg(ap, long int *);
+ read_count = scanner.scan((long int *)dst_ptr);
+ write_size = sizeof(long int);
+ }
+ break;
+ case 'q':
+ dst_ptr = va_arg(ap, long long int *);
+ read_count = scanner.scan((long long int *)dst_ptr);
+ write_size = sizeof(long long int);
+ break;
+ case 'j':
+ dst_ptr = va_arg(ap, intmax_t *);
+ read_count = scanner.scan((intmax_t *)dst_ptr);
+ write_size = sizeof(intmax_t);
+ break;
+ case 'z':
+ case 't':
+ dst_ptr = va_arg(ap, size_t *);
+ read_count = scanner.scan((size_t *)dst_ptr);
+ write_size = sizeof(size_t);
+ break;
+ default:
+ dst_ptr = va_arg(ap, int *);
+ read_count = scanner.scan((int *)dst_ptr);
+ write_size = sizeof(int);
+ }
+ // get the label associated with the string at the corresponding
+ // place
+ dfsan_label l = dfsan_read_label(
+ scanner.str_cur(), scanner.num_written_bytes(read_count));
+ dfsan_set_label(l, dst_ptr, write_size);
+ if (str_origin != nullptr) {
+ dfsan_set_label(l, dst_ptr, write_size);
+ size_t scan_count = scanner.num_written_bytes(read_count);
+ size_t size = scan_count > write_size ? write_size : scan_count;
+ dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
}
- break;
- case 'q':
- dst_ptr = va_arg(ap, long long int *);
- read_count = formatter.scan((long long int *)dst_ptr);
- write_size = sizeof(long long int);
- break;
- case 'j':
- dst_ptr = va_arg(ap, intmax_t *);
- read_count = formatter.scan((intmax_t *)dst_ptr);
- write_size = sizeof(intmax_t);
- break;
- case 'z':
- case 't':
- dst_ptr = va_arg(ap, size_t *);
- read_count = formatter.scan((size_t *)dst_ptr);
- write_size = sizeof(size_t);
- break;
- default:
- dst_ptr = va_arg(ap, int *);
- read_count = formatter.scan((int *)dst_ptr);
- write_size = sizeof(int);
- }
- // get the label associated with the string at the corresponding
- // place
- dfsan_label l = dfsan_read_label(
- formatter.str_cur(), formatter.num_written_bytes(read_count));
- dfsan_set_label(l, dst_ptr, write_size);
- if (str_origin != nullptr) {
- dfsan_set_label(l, dst_ptr, write_size);
- size_t scan_count = formatter.num_written_bytes(read_count);
- size_t size = scan_count > write_size ? write_size : scan_count;
- dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
}
- }
- end_fmt = true;
+ end_fmt = true;
- break;
+ break;
- case 'a':
- case 'A':
- case 'e':
- case 'E':
- case 'f':
- case 'F':
- case 'g':
- case 'G':
- if (formatter.skip) {
- read_count = formatter.scan();
- } else {
- if (*(formatter.fmt_cur - 1) == 'L') {
- dst_ptr = va_arg(ap, long double *);
- read_count = formatter.scan((long double *)dst_ptr);
- write_size = sizeof(long double);
- } else if (*(formatter.fmt_cur - 1) == 'l') {
- dst_ptr = va_arg(ap, double *);
- read_count = formatter.scan((double *)dst_ptr);
- write_size = sizeof(double);
+ case 'a':
+ case 'A':
+ case 'e':
+ case 'E':
+ case 'f':
+ case 'F':
+ case 'g':
+ case 'G':
+ if (scanner.skip) {
+ read_count = scanner.scan();
} else {
- dst_ptr = va_arg(ap, float *);
- read_count = formatter.scan((float *)dst_ptr);
- write_size = sizeof(float);
- }
- dfsan_label l = dfsan_read_label(
- formatter.str_cur(), formatter.num_written_bytes(read_count));
- dfsan_set_label(l, dst_ptr, write_size);
- if (str_origin != nullptr) {
- dfsan_set_label(l, dst_ptr, write_size);
- size_t scan_count = formatter.num_written_bytes(read_count);
- size_t size = scan_count > write_size ? write_size : scan_count;
- dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+ if (*(scanner.fmt_cur - 1) == 'L') {
+ dst_ptr = va_arg(ap, long double *);
+ read_count = scanner.scan((long double *)dst_ptr);
+ write_size = sizeof(long double);
+ } else if (*(scanner.fmt_cur - 1) == 'l') {
+ dst_ptr = va_arg(ap, double *);
+ read_count = scanner.scan((double *)dst_ptr);
+ write_size = sizeof(double);
+ } else {
+ dst_ptr = va_arg(ap, float *);
+ read_count = scanner.scan((float *)dst_ptr);
+ write_size = sizeof(float);
+ }
+ dfsan_label l = dfsan_read_label(
+ scanner.str_cur(), scanner.num_written_bytes(read_count));
+ dfsan_set_label(l, dst_ptr, write_size);
+ if (str_origin != nullptr) {
+ dfsan_set_label(l, dst_ptr, write_size);
+ size_t scan_count = scanner.num_written_bytes(read_count);
+ size_t size = scan_count > write_size ? write_size : scan_count;
+ dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+ }
}
- }
- end_fmt = true;
- break;
+ end_fmt = true;
+ break;
- case 'c':
- if (formatter.skip) {
- read_count = formatter.scan();
- } else {
- dst_ptr = va_arg(ap, char *);
- read_count = formatter.scan((char *)dst_ptr);
- write_size = sizeof(char);
- dfsan_label l = dfsan_read_label(
- formatter.str_cur(), formatter.num_written_bytes(read_count));
- dfsan_set_label(l, dst_ptr, write_size);
- if (str_origin != nullptr) {
- size_t scan_count = formatter.num_written_bytes(read_count);
- size_t size = scan_count > write_size ? write_size : scan_count;
- dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+ case 'c':
+ if (scanner.skip) {
+ read_count = scanner.scan();
+ } else {
+ dst_ptr = va_arg(ap, char *);
+ read_count = scanner.scan((char *)dst_ptr);
+ write_size = sizeof(char);
+ dfsan_label l = dfsan_read_label(
+ scanner.str_cur(), scanner.num_written_bytes(read_count));
+ dfsan_set_label(l, dst_ptr, write_size);
+ if (str_origin != nullptr) {
+ size_t scan_count = scanner.num_written_bytes(read_count);
+ size_t size = scan_count > write_size ? write_size : scan_count;
+ dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+ }
}
- }
- end_fmt = true;
- break;
+ end_fmt = true;
+ break;
- case 's': {
- if (formatter.skip) {
- read_count = formatter.scan();
- } else {
- dst_ptr = va_arg(ap, char *);
- read_count = formatter.scan((char *)dst_ptr);
- if (1 == read_count) {
- // special case: we have parsed a single string and we need to
- // update read_count with the string size
- read_count = strlen((char *)dst_ptr);
+ case 's': {
+ if (scanner.skip) {
+ read_count = scanner.scan();
+ } else {
+ dst_ptr = va_arg(ap, char *);
+ read_count = scanner.scan((char *)dst_ptr);
+ if (1 == read_count) {
+ // special case: we have parsed a single string and we need to
+ // update read_count with the string size
+ read_count = strlen((char *)dst_ptr);
+ }
+ if (str_origin)
+ dfsan_mem_origin_transfer(
+ dst_ptr, scanner.str_cur(),
+ scanner.num_written_bytes(read_count));
+ va_labels++;
+ dfsan_mem_shadow_transfer(dst_ptr, scanner.str_cur(),
+ scanner.num_written_bytes(read_count));
}
- if (str_origin)
- dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(),
- formatter.num_written_bytes(read_count));
- va_labels++;
- dfsan_mem_shadow_transfer(dst_ptr, formatter.str_cur(),
- formatter.num_written_bytes(read_count));
+ end_fmt = true;
+ break;
}
- end_fmt = true;
- break;
- }
- case 'p':
- if (formatter.skip) {
- read_count = formatter.scan();
- } else {
- dst_ptr = va_arg(ap, void *);
- read_count =
- formatter.scan((int *)dst_ptr); // note: changing void* to int*
+ case 'p':
+ if (scanner.skip) {
+ read_count = scanner.scan();
+ } else {
+ dst_ptr = va_arg(ap, void *);
+ read_count =
+ scanner.scan((int *)dst_ptr); // note: changing void* to int*
// since we need to call sizeof
- write_size = sizeof(int);
-
- dfsan_label l = dfsan_read_label(
- formatter.str_cur(), formatter.num_written_bytes(read_count));
- dfsan_set_label(l, dst_ptr, write_size);
- if (str_origin != nullptr) {
- dfsan_set_label(l, dst_ptr, write_size);
- size_t scan_count = formatter.num_written_bytes(read_count);
- size_t size = scan_count > write_size ? write_size : scan_count;
- dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+ write_size = sizeof(int);
+
+ dfsan_label l = dfsan_read_label(
+ scanner.str_cur(), scanner.num_written_bytes(read_count));
+ dfsan_set_label(l, dst_ptr, write_size);
+ if (str_origin != nullptr) {
+ dfsan_set_label(l, dst_ptr, write_size);
+ size_t scan_count = scanner.num_written_bytes(read_count);
+ size_t size = scan_count > write_size ? write_size : scan_count;
+ dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+ }
}
- }
- end_fmt = true;
- break;
+ end_fmt = true;
+ break;
- case 'n': {
- if (!formatter.skip) {
- int *ptr = va_arg(ap, int *);
- *ptr = (int)formatter.str_off;
- *va_labels++ = 0;
- dfsan_set_label(0, ptr, sizeof(*ptr));
- if (str_origin != nullptr)
- *str_origin++ = 0;
+ case 'n': {
+ if (!scanner.skip) {
+ int *ptr = va_arg(ap, int *);
+ *ptr = (int)scanner.str_off;
+ *va_labels++ = 0;
+ dfsan_set_label(0, ptr, sizeof(*ptr));
+ if (str_origin != nullptr)
+ *str_origin++ = 0;
+ }
+ end_fmt = true;
+ break;
...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/95333
More information about the llvm-commits mailing list