[compiler-rt] [DFSan] Fix sscanf checking that ordinary characters match. (PR #95333)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 12 16:47:24 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-compiler-rt-sanitizer

Author: Andrew Browne (browneee)

<details>
<summary>Changes</summary>

Fixes: #<!-- -->94769

---

Patch is 23.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/95333.diff


2 Files Affected:

- (modified) compiler-rt/lib/dfsan/dfsan_custom.cpp (+291-231) 
- (modified) compiler-rt/test/dfsan/sscanf.c (+102-9) 


``````````diff
diff --git a/compiler-rt/lib/dfsan/dfsan_custom.cpp b/compiler-rt/lib/dfsan/dfsan_custom.cpp
index af3c1f4d1673c..050f5232c0408 100644
--- a/compiler-rt/lib/dfsan/dfsan_custom.cpp
+++ b/compiler-rt/lib/dfsan/dfsan_custom.cpp
@@ -2198,50 +2198,12 @@ struct Formatter {
     return retval;
   }
 
-  int scan() {
-    char *tmp_fmt = build_format_string(true);
-    int read_count = 0;
-    int retval = sscanf(str + str_off, tmp_fmt, &read_count);
-    if (retval > 0) {
-      if (-1 == num_scanned)
-        num_scanned = 0;
-      num_scanned += retval;
-    }
-    free(tmp_fmt);
-    return read_count;
-  }
-
-  template <typename T>
-  int scan(T arg) {
-    char *tmp_fmt = build_format_string(true);
-    int read_count = 0;
-    int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
-    if (retval > 0) {
-      if (-1 == num_scanned)
-        num_scanned = 0;
-      num_scanned += retval;
-    }
-    free(tmp_fmt);
-    return read_count;
-  }
-
-  // with_n -> toggles adding %n on/off; off by default
-  char *build_format_string(bool with_n = false) {
+  char *build_format_string() {
     size_t fmt_size = fmt_cur - fmt_start + 1;
-    size_t add_size = 0;
-    if (with_n)
-      add_size = 2;
-    char *new_fmt = (char *)malloc(fmt_size + 1 + add_size);
+    char *new_fmt = (char *)malloc(fmt_size + 1);
     assert(new_fmt);
     internal_memcpy(new_fmt, fmt_start, fmt_size);
-    if (!with_n) {
-      new_fmt[fmt_size] = '\0';
-    } else {
-      new_fmt[fmt_size] = '%';
-      new_fmt[fmt_size + 1] = 'n';
-      new_fmt[fmt_size + 2] = '\0';
-    }
-
+    new_fmt[fmt_size] = '\0';
     return new_fmt;
   }
 
@@ -2467,6 +2429,102 @@ static int format_buffer(char *str, size_t size, const char *fmt,
   return formatter.str_off;
 }
 
+// Scans one chunk of the format string at a time: either a run of ordinary
+// characters (see check_match_ordinary) or a single conversion directive
+// (e.g., '%.3f'). The chunk is the inclusive range [fmt_start, fmt_cur]; the
+// actual parsing is delegated to libc sscanf with a trailing %n appended so
+// that the number of consumed input characters can be measured.
+struct Scanner {
+  // Starts at the beginning of both the input and the format string, with no
+  // successful conversions recorded yet.
+  Scanner(char *str_, const char *fmt_, size_t size_)
+      : str(str_),
+        str_off(0),
+        size(size_),
+        fmt_start(fmt_),
+        fmt_cur(fmt_),
+        width(-1),
+        num_scanned(0),
+        skip(false) {}
+
+  // Consumes a chunk of ordinary characters.
+  // Returns number of matching ordinary characters (this may be 0, e.g. when
+  // a whitespace-only chunk meets no whitespace in the input).
+  // Returns -1 if the match failed: read_count starts at -1 and is only
+  // overwritten if sscanf gets as far as the appended %n.
+  // In format strings, a space will match multiple spaces.
+  // On a positive match, str_off is advanced past the matched characters.
+  int check_match_ordinary() {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = -1;
+    sscanf(str + str_off, tmp_fmt, &read_count);
+    free(tmp_fmt);
+    if (read_count > 0) {
+      str_off += read_count;
+    }
+    return read_count;
+  }
+
+  // Scans the current directive without a destination argument (used when
+  // skip is set). Adds sscanf's count of successful assignments to
+  // num_scanned and returns the number of input characters consumed (0 if
+  // sscanf stopped before the appended %n).
+  // Does not advance str_off; the caller is presumably responsible for that
+  // (not visible in this excerpt - confirm).
+  int scan() {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, &read_count);
+    free(tmp_fmt);
+    if (retval > 0) {
+      num_scanned += retval;
+    }
+    return read_count;
+  }
+
+  // Same as scan(), but stores the converted value through arg, which must
+  // point to storage of the type the current directive expects.
+  template <typename T>
+  int scan(T arg) {
+    char *tmp_fmt = build_format_string_with_n();
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
+    free(tmp_fmt);
+    if (retval > 0) {
+      num_scanned += retval;
+    }
+    return read_count;
+  }
+
+  // Adds %n onto current format string to measure length.
+  // Returns a malloc'd, NUL-terminated copy of [fmt_start, fmt_cur] followed
+  // by "%n"; the caller owns it and must free() it.
+  char *build_format_string_with_n() {
+    size_t fmt_size = fmt_cur - fmt_start + 1;
+    // +2 for %n, +1 for \0
+    char *new_fmt = (char *)malloc(fmt_size + 2 + 1);
+    assert(new_fmt);
+    internal_memcpy(new_fmt, fmt_start, fmt_size);
+    new_fmt[fmt_size] = '%';
+    new_fmt[fmt_size + 1] = 'n';
+    new_fmt[fmt_size + 2] = '\0';
+    return new_fmt;
+  }
+
+  // Current read position in the input buffer.
+  char *str_cur() { return str + str_off; }
+
+  // Converts a scan() result into the byte count used when propagating
+  // labels/origins from str_cur(). Returns 0 for a negative retval or when
+  // str_off is at or past size.
+  // NOTE(review): the branch below is inherited from Formatter's
+  // {v,}snprintf path, where a return value of size or more means the output
+  // was truncated. It subtracts num_avail instead of clamping to it, which
+  // looks wrong for scanning (retval counts consumed input characters).
+  // Confirm whether min(retval, num_avail) was intended.
+  size_t num_written_bytes(int retval) {
+    if (retval < 0) {
+      return 0;
+    }
+
+    size_t num_avail = str_off < size ? size - str_off : 0;
+    if (num_avail == 0) {
+      return 0;
+    }
+
+    size_t num_written = retval;
+    // Truncation handling copied from the printf path; see the NOTE(review)
+    // above this function.
+    if (num_written >= num_avail) {
+      num_written -= num_avail;
+    }
+
+    return num_written;
+  }
+
+  // Input buffer being scanned.
+  char *str;
+  // Number of input characters consumed so far.
+  size_t str_off;
+  // Bound used by num_written_bytes; presumably the usable size of str
+  // (confirm against the caller of scan_buffer).
+  size_t size;
+  // First character of the current format chunk.
+  const char *fmt_start;
+  // Last character of the current format chunk (inclusive).
+  const char *fmt_cur;
+  // Reset to -1 per chunk by scan_buffer. NOTE(review): not read anywhere in
+  // this struct - confirm whether it is still needed.
+  int width;
+  // Running total of successful assignments reported by sscanf.
+  int num_scanned;
+  // When set, the directive is scanned via scan() without consuming a
+  // destination argument (presumably '*' assignment suppression - confirm).
+  bool skip;
+};
+
 // This function is an inverse of format_buffer: we take the input buffer,
 // scan it in search for format strings and store the results in the varargs.
 // The labels are propagated from the input buffer to the varargs.
@@ -2474,220 +2532,222 @@ static int scan_buffer(char *str, size_t size, const char *fmt,
                        dfsan_label *va_labels, dfsan_label *ret_label,
                        dfsan_origin *str_origin, dfsan_origin *ret_origin,
                        va_list ap) {
-  Formatter formatter(str, fmt, size);
-  while (*formatter.fmt_cur) {
-    formatter.fmt_start = formatter.fmt_cur;
-    formatter.width = -1;
-    formatter.skip = false;
+  Scanner scanner(str, fmt, size);
+  while (*scanner.fmt_cur) {
+    scanner.fmt_start = scanner.fmt_cur;
+    scanner.width = -1;
+    scanner.skip = false;
     int read_count = 0;
     void *dst_ptr = 0;
     size_t write_size = 0;
-    if (*formatter.fmt_cur != '%') {
-      // Ordinary character. Consume all the characters until a '%' or the end
-      // of the string.
-      for (; *(formatter.fmt_cur + 1) && *(formatter.fmt_cur + 1) != '%';
-           ++formatter.fmt_cur) {
+    if (*scanner.fmt_cur != '%') {
+      // Ordinary character and spaces.
+      // Consume all the characters until a '%' or the end of the string.
+      for (; *(scanner.fmt_cur + 1) && *(scanner.fmt_cur + 1) != '%';
+           ++scanner.fmt_cur) {
+      }
+      if (scanner.check_match_ordinary() < 0) {
+        // The ordinary characters did not match.
+        break;
       }
-      read_count = formatter.scan();
-      dfsan_set_label(0, formatter.str_cur(),
-                      formatter.num_written_bytes(read_count));
     } else {
       // Conversion directive. Consume all the characters until a conversion
       // specifier or the end of the string.
       bool end_fmt = false;
-      for (; *formatter.fmt_cur && !end_fmt;) {
-        switch (*++formatter.fmt_cur) {
-        case 'd':
-        case 'i':
-        case 'o':
-        case 'u':
-        case 'x':
-        case 'X':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            switch (*(formatter.fmt_cur - 1)) {
-            case 'h':
-              // Also covers the 'hh' case (since the size of the arg is still
-              // an int).
-              dst_ptr = va_arg(ap, int *);
-              read_count = formatter.scan((int *)dst_ptr);
-              write_size = sizeof(int);
-              break;
-            case 'l':
-              if (formatter.fmt_cur - formatter.fmt_start >= 2 &&
-                  *(formatter.fmt_cur - 2) == 'l') {
-                dst_ptr = va_arg(ap, long long int *);
-                read_count = formatter.scan((long long int *)dst_ptr);
-                write_size = sizeof(long long int);
-              } else {
-                dst_ptr = va_arg(ap, long int *);
-                read_count = formatter.scan((long int *)dst_ptr);
-                write_size = sizeof(long int);
+      for (; *scanner.fmt_cur && !end_fmt;) {
+        switch (*++scanner.fmt_cur) {
+          case 'd':
+          case 'i':
+          case 'o':
+          case 'u':
+          case 'x':
+          case 'X':
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              switch (*(scanner.fmt_cur - 1)) {
+                case 'h':
+                  // Also covers the 'hh' case (since the size of the arg is
+                  // still an int).
+                  dst_ptr = va_arg(ap, int *);
+                  read_count = scanner.scan((int *)dst_ptr);
+                  write_size = sizeof(int);
+                  break;
+                case 'l':
+                  if (scanner.fmt_cur - scanner.fmt_start >= 2 &&
+                      *(scanner.fmt_cur - 2) == 'l') {
+                    dst_ptr = va_arg(ap, long long int *);
+                    read_count = scanner.scan((long long int *)dst_ptr);
+                    write_size = sizeof(long long int);
+                  } else {
+                    dst_ptr = va_arg(ap, long int *);
+                    read_count = scanner.scan((long int *)dst_ptr);
+                    write_size = sizeof(long int);
+                  }
+                  break;
+                case 'q':
+                  dst_ptr = va_arg(ap, long long int *);
+                  read_count = scanner.scan((long long int *)dst_ptr);
+                  write_size = sizeof(long long int);
+                  break;
+                case 'j':
+                  dst_ptr = va_arg(ap, intmax_t *);
+                  read_count = scanner.scan((intmax_t *)dst_ptr);
+                  write_size = sizeof(intmax_t);
+                  break;
+                case 'z':
+                case 't':
+                  dst_ptr = va_arg(ap, size_t *);
+                  read_count = scanner.scan((size_t *)dst_ptr);
+                  write_size = sizeof(size_t);
+                  break;
+                default:
+                  dst_ptr = va_arg(ap, int *);
+                  read_count = scanner.scan((int *)dst_ptr);
+                  write_size = sizeof(int);
+              }
+              // get the label associated with the string at the corresponding
+              // place
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                dfsan_set_label(l, dst_ptr, write_size);
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
               }
-              break;
-            case 'q':
-              dst_ptr = va_arg(ap, long long int *);
-              read_count = formatter.scan((long long int *)dst_ptr);
-              write_size = sizeof(long long int);
-              break;
-            case 'j':
-              dst_ptr = va_arg(ap, intmax_t *);
-              read_count = formatter.scan((intmax_t *)dst_ptr);
-              write_size = sizeof(intmax_t);
-              break;
-            case 'z':
-            case 't':
-              dst_ptr = va_arg(ap, size_t *);
-              read_count = formatter.scan((size_t *)dst_ptr);
-              write_size = sizeof(size_t);
-              break;
-            default:
-              dst_ptr = va_arg(ap, int *);
-              read_count = formatter.scan((int *)dst_ptr);
-              write_size = sizeof(int);
-            }
-            // get the label associated with the string at the corresponding
-            // place
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            dfsan_set_label(l, dst_ptr, write_size);
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
             }
-          }
-          end_fmt = true;
+            end_fmt = true;
 
-          break;
+            break;
 
-        case 'a':
-        case 'A':
-        case 'e':
-        case 'E':
-        case 'f':
-        case 'F':
-        case 'g':
-        case 'G':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            if (*(formatter.fmt_cur - 1) == 'L') {
-            dst_ptr = va_arg(ap, long double *);
-            read_count = formatter.scan((long double *)dst_ptr);
-            write_size = sizeof(long double);
-            } else if (*(formatter.fmt_cur - 1) == 'l') {
-            dst_ptr = va_arg(ap, double *);
-            read_count = formatter.scan((double *)dst_ptr);
-            write_size = sizeof(double);
+          case 'a':
+          case 'A':
+          case 'e':
+          case 'E':
+          case 'f':
+          case 'F':
+          case 'g':
+          case 'G':
+            if (scanner.skip) {
+              read_count = scanner.scan();
             } else {
-            dst_ptr = va_arg(ap, float *);
-            read_count = formatter.scan((float *)dst_ptr);
-            write_size = sizeof(float);
-            }
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            dfsan_set_label(l, dst_ptr, write_size);
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+              if (*(scanner.fmt_cur - 1) == 'L') {
+                dst_ptr = va_arg(ap, long double *);
+                read_count = scanner.scan((long double *)dst_ptr);
+                write_size = sizeof(long double);
+              } else if (*(scanner.fmt_cur - 1) == 'l') {
+                dst_ptr = va_arg(ap, double *);
+                read_count = scanner.scan((double *)dst_ptr);
+                write_size = sizeof(double);
+              } else {
+                dst_ptr = va_arg(ap, float *);
+                read_count = scanner.scan((float *)dst_ptr);
+                write_size = sizeof(float);
+              }
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                dfsan_set_label(l, dst_ptr, write_size);
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+              }
             }
-          }
-          end_fmt = true;
-          break;
+            end_fmt = true;
+            break;
 
-        case 'c':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            dst_ptr = va_arg(ap, char *);
-            read_count = formatter.scan((char *)dst_ptr);
-            write_size = sizeof(char);
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+          case 'c':
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              dst_ptr = va_arg(ap, char *);
+              read_count = scanner.scan((char *)dst_ptr);
+              write_size = sizeof(char);
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+              }
             }
-          }
-          end_fmt = true;
-          break;
+            end_fmt = true;
+            break;
 
-        case 's': {
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            dst_ptr = va_arg(ap, char *);
-            read_count = formatter.scan((char *)dst_ptr);
-            if (1 == read_count) {
-            // special case: we have parsed a single string and we need to
-            // update read_count with the string size
-            read_count = strlen((char *)dst_ptr);
+          case 's': {
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              dst_ptr = va_arg(ap, char *);
+              read_count = scanner.scan((char *)dst_ptr);
+              if (1 == read_count) {
+                // special case: we have parsed a single string and we need to
+                // update read_count with the string size
+                read_count = strlen((char *)dst_ptr);
+              }
+              if (str_origin)
+                dfsan_mem_origin_transfer(
+                    dst_ptr, scanner.str_cur(),
+                    scanner.num_written_bytes(read_count));
+              va_labels++;
+              dfsan_mem_shadow_transfer(dst_ptr, scanner.str_cur(),
+                                        scanner.num_written_bytes(read_count));
             }
-            if (str_origin)
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(),
-                                      formatter.num_written_bytes(read_count));
-            va_labels++;
-            dfsan_mem_shadow_transfer(dst_ptr, formatter.str_cur(),
-                                      formatter.num_written_bytes(read_count));
+            end_fmt = true;
+            break;
           }
-          end_fmt = true;
-          break;
-        }
 
-        case 'p':
-          if (formatter.skip) {
-            read_count = formatter.scan();
-          } else {
-            dst_ptr = va_arg(ap, void *);
-            read_count =
-                formatter.scan((int *)dst_ptr);  // note: changing void* to int*
+          case 'p':
+            if (scanner.skip) {
+              read_count = scanner.scan();
+            } else {
+              dst_ptr = va_arg(ap, void *);
+              read_count =
+                  scanner.scan((int *)dst_ptr);  // note: changing void* to int*
                                                  // since we need to call sizeof
-            write_size = sizeof(int);
-
-            dfsan_label l = dfsan_read_label(
-                formatter.str_cur(), formatter.num_written_bytes(read_count));
-            dfsan_set_label(l, dst_ptr, write_size);
-            if (str_origin != nullptr) {
-            dfsan_set_label(l, dst_ptr, write_size);
-            size_t scan_count = formatter.num_written_bytes(read_count);
-            size_t size = scan_count > write_size ? write_size : scan_count;
-            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+              write_size = sizeof(int);
+
+              dfsan_label l = dfsan_read_label(
+                  scanner.str_cur(), scanner.num_written_bytes(read_count));
+              dfsan_set_label(l, dst_ptr, write_size);
+              if (str_origin != nullptr) {
+                dfsan_set_label(l, dst_ptr, write_size);
+                size_t scan_count = scanner.num_written_bytes(read_count);
+                size_t size = scan_count > write_size ? write_size : scan_count;
+                dfsan_mem_origin_transfer(dst_ptr, scanner.str_cur(), size);
+              }
             }
-          }
-          end_fmt = true;
-          break;
+            end_fmt = true;
+            break;
 
-        case 'n': {
-          if (!formatter.skip) {
-            int *ptr = va_arg(ap, int *);
-            *ptr = (int)formatter.str_off;
-            *va_labels++ = 0;
-            dfsan_set_label(0, ptr, sizeof(*ptr));
-            if (str_origin != nullptr)
-            *str_origin++ = 0;
+          case 'n': {
+            if (!scanner.skip) {
+              int *ptr = va_arg(ap, int *);
+              *ptr = (int)scanner.str_off;
+              *va_labels++ = 0;
+              dfsan_set_label(0, ptr, sizeof(*ptr));
+              if (str_origin != nullptr)
+                *str_origin++ = 0;
+            }
+            end_fmt = true;
+            break;
     ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/95333


More information about the llvm-commits mailing list