[compiler-rt] 8dbcf8e - [DFSAN] Add support for sscanf.

Andrew Browne via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 5 18:16:53 PDT 2023


Author: Tomasz Kuchta
Date: 2023-09-06T01:16:31Z
New Revision: 8dbcf8eba780581a830ad0fabaa6151ad6afc302

URL: https://github.com/llvm/llvm-project/commit/8dbcf8eba780581a830ad0fabaa6151ad6afc302
DIFF: https://github.com/llvm/llvm-project/commit/8dbcf8eba780581a830ad0fabaa6151ad6afc302.diff

LOG: [DFSAN] Add support for sscanf.

Reviewed By: browneee

Differential Revision: https://reviews.llvm.org/D153775

Added: 
    

Modified: 
    compiler-rt/lib/dfsan/dfsan_custom.cpp
    compiler-rt/lib/dfsan/done_abilist.txt
    compiler-rt/test/dfsan/custom.cpp

Removed: 
    


################################################################################
diff --git a/compiler-rt/lib/dfsan/dfsan_custom.cpp b/compiler-rt/lib/dfsan/dfsan_custom.cpp
index f41dd50617fbc8..7e7af8434b9c95 100644
--- a/compiler-rt/lib/dfsan/dfsan_custom.cpp
+++ b/compiler-rt/lib/dfsan/dfsan_custom.cpp
@@ -2240,8 +2240,13 @@ typedef int dfsan_label_va;
 // '%.3f').
 struct Formatter {
   Formatter(char *str_, const char *fmt_, size_t size_)
-      : str(str_), str_off(0), size(size_), fmt_start(fmt_), fmt_cur(fmt_),
-        width(-1) {}
+      : str(str_),
+        str_off(0),
+        size(size_),
+        fmt_start(fmt_),
+        fmt_cur(fmt_),
+        width(-1),
+        num_scanned(-1) {}
 
   int format() {
     char *tmp_fmt = build_format_string();
@@ -2266,12 +2271,50 @@ struct Formatter {
     return retval;
   }
 
-  char *build_format_string() {
+  int scan() {
+    char *tmp_fmt = build_format_string(true);
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, &read_count);
+    if (retval > 0) {
+      if (-1 == num_scanned)
+        num_scanned = 0;
+      num_scanned += retval;
+    }
+    free(tmp_fmt);
+    return read_count;
+  }
+
+  template <typename T>
+  int scan(T arg) {
+    char *tmp_fmt = build_format_string(true);
+    int read_count = 0;
+    int retval = sscanf(str + str_off, tmp_fmt, arg, &read_count);
+    if (retval > 0) {
+      if (-1 == num_scanned)
+        num_scanned = 0;
+      num_scanned += retval;
+    }
+    free(tmp_fmt);
+    return read_count;
+  }
+
+  // with_n -> toggles adding %n on/off; off by default
+  char *build_format_string(bool with_n = false) {
     size_t fmt_size = fmt_cur - fmt_start + 1;
-    char *new_fmt = (char *)malloc(fmt_size + 1);
+    size_t add_size = 0;
+    if (with_n)
+      add_size = 2;
+    char *new_fmt = (char *)malloc(fmt_size + 1 + add_size);
     assert(new_fmt);
     internal_memcpy(new_fmt, fmt_start, fmt_size);
-    new_fmt[fmt_size] = '\0';
+    if (!with_n) {
+      new_fmt[fmt_size] = '\0';
+    } else {
+      new_fmt[fmt_size] = '%';
+      new_fmt[fmt_size + 1] = 'n';
+      new_fmt[fmt_size + 2] = '\0';
+    }
+
     return new_fmt;
   }
 
@@ -2303,6 +2346,7 @@ struct Formatter {
   const char *fmt_start;
   const char *fmt_cur;
   int width;
+  int num_scanned;
 };
 
 // Formats the input and propagates the input labels to the output. The output
@@ -2495,6 +2539,229 @@ static int format_buffer(char *str, size_t size, const char *fmt,
   return formatter.str_off;
 }
 
+// This function is an inverse of format_buffer: we take the input buffer,
+// scan it in search for format strings and store the results in the varargs.
+// The labels are propagated from the input buffer to the varargs.
+static int scan_buffer(char *str, size_t size, const char *fmt,
+                       dfsan_label *va_labels, dfsan_label *ret_label,
+                       dfsan_origin *str_origin, dfsan_origin *ret_origin,
+                       va_list ap) {
+  Formatter formatter(str, fmt, size);
+  while (*formatter.fmt_cur) {
+    formatter.fmt_start = formatter.fmt_cur;
+    formatter.width = -1;
+    int retval = 0;
+    dfsan_label l = 0;
+    void *dst_ptr = 0;
+    size_t write_size = 0;
+    if (*formatter.fmt_cur != '%') {
+      // Ordinary character. Consume all the characters until a '%' or the end
+      // of the string.
+      for (; *(formatter.fmt_cur + 1) && *(formatter.fmt_cur + 1) != '%';
+           ++formatter.fmt_cur) {
+      }
+      retval = formatter.scan();
+      dfsan_set_label(0, formatter.str_cur(),
+                      formatter.num_written_bytes(retval));
+    } else {
+      // Conversion directive. Consume all the characters until a conversion
+      // specifier or the end of the string.
+      bool end_fmt = false;
+      for (; *formatter.fmt_cur && !end_fmt;) {
+        switch (*++formatter.fmt_cur) {
+        case 'd':
+        case 'i':
+        case 'o':
+        case 'u':
+        case 'x':
+        case 'X':
+          switch (*(formatter.fmt_cur - 1)) {
+            case 'h':
+            // Also covers the 'hh' case (since the size of the arg is still
+            // an int).
+            dst_ptr = va_arg(ap, int *);
+            retval = formatter.scan((int *)dst_ptr);
+            write_size = sizeof(int);
+            break;
+            case 'l':
+            if (formatter.fmt_cur - formatter.fmt_start >= 2 &&
+                *(formatter.fmt_cur - 2) == 'l') {
+              dst_ptr = va_arg(ap, long long int *);
+              retval = formatter.scan((long long int *)dst_ptr);
+              write_size = sizeof(long long int);
+            } else {
+              dst_ptr = va_arg(ap, long int *);
+              retval = formatter.scan((long int *)dst_ptr);
+              write_size = sizeof(long int);
+            }
+            break;
+            case 'q':
+            dst_ptr = va_arg(ap, long long int *);
+            retval = formatter.scan((long long int *)dst_ptr);
+            write_size = sizeof(long long int);
+            break;
+            case 'j':
+            dst_ptr = va_arg(ap, intmax_t *);
+            retval = formatter.scan((intmax_t *)dst_ptr);
+            write_size = sizeof(intmax_t);
+            break;
+            case 'z':
+            case 't':
+            dst_ptr = va_arg(ap, size_t *);
+            retval = formatter.scan((size_t *)dst_ptr);
+            write_size = sizeof(size_t);
+            break;
+            default:
+            dst_ptr = va_arg(ap, int *);
+            retval = formatter.scan((int *)dst_ptr);
+            write_size = sizeof(int);
+          }
+          // get the label associated with the string at the corresponding
+          // place
+          l = dfsan_read_label(formatter.str_cur(),
+                               formatter.num_written_bytes(retval));
+          if (str_origin == nullptr)
+            dfsan_set_label(l, dst_ptr, write_size);
+          else {
+            dfsan_set_label(l, dst_ptr, write_size);
+            size_t scan_count = formatter.num_written_bytes(retval);
+            size_t size = scan_count > write_size ? write_size : scan_count;
+            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+          }
+          end_fmt = true;
+
+          break;
+
+        case 'a':
+        case 'A':
+        case 'e':
+        case 'E':
+        case 'f':
+        case 'F':
+        case 'g':
+        case 'G':
+          if (*(formatter.fmt_cur - 1) == 'L') {
+            dst_ptr = va_arg(ap, long double *);
+            retval = formatter.scan((long double *)dst_ptr);
+            write_size = sizeof(long double);
+          } else if (*(formatter.fmt_cur - 1) == 'l') {
+            dst_ptr = va_arg(ap, double *);
+            retval = formatter.scan((double *)dst_ptr);
+            write_size = sizeof(double);
+          } else {
+            dst_ptr = va_arg(ap, float *);
+            retval = formatter.scan((float *)dst_ptr);
+            write_size = sizeof(float);
+          }
+          l = dfsan_read_label(formatter.str_cur(),
+                               formatter.num_written_bytes(retval));
+          if (str_origin == nullptr)
+            dfsan_set_label(l, dst_ptr, write_size);
+          else {
+            dfsan_set_label(l, dst_ptr, write_size);
+            size_t scan_count = formatter.num_written_bytes(retval);
+            size_t size = scan_count > write_size ? write_size : scan_count;
+            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+          }
+          end_fmt = true;
+          break;
+
+        case 'c':
+          dst_ptr = va_arg(ap, char *);
+          retval = formatter.scan((char *)dst_ptr);
+          write_size = sizeof(char);
+          l = dfsan_read_label(formatter.str_cur(),
+                               formatter.num_written_bytes(retval));
+          if (str_origin == nullptr)
+            dfsan_set_label(l, dst_ptr, write_size);
+          else {
+            dfsan_set_label(l, dst_ptr, write_size);
+            size_t scan_count = formatter.num_written_bytes(retval);
+            size_t size = scan_count > write_size ? write_size : scan_count;
+            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+          }
+          end_fmt = true;
+          break;
+
+        case 's': {
+          dst_ptr = va_arg(ap, char *);
+          retval = formatter.scan((char *)dst_ptr);
+          if (1 == retval) {
+            // special case: we have parsed a single string and we need to
+            // update retval with the string size
+            retval = strlen((char *)dst_ptr);
+          }
+          if (str_origin)
+            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(),
+                                      formatter.num_written_bytes(retval));
+          va_labels++;
+          dfsan_mem_shadow_transfer(dst_ptr, formatter.str_cur(),
+                                    formatter.num_written_bytes(retval));
+          end_fmt = true;
+          break;
+        }
+
+        case 'p':
+          dst_ptr = va_arg(ap, void *);
+          retval =
+              formatter.scan((int *)dst_ptr);  // note: changing void* to int*
+                                               // since we need to call sizeof
+          write_size = sizeof(int);
+
+          l = dfsan_read_label(formatter.str_cur(),
+                               formatter.num_written_bytes(retval));
+          if (str_origin == nullptr)
+            dfsan_set_label(l, dst_ptr, write_size);
+          else {
+            dfsan_set_label(l, dst_ptr, write_size);
+            size_t scan_count = formatter.num_written_bytes(retval);
+            size_t size = scan_count > write_size ? write_size : scan_count;
+            dfsan_mem_origin_transfer(dst_ptr, formatter.str_cur(), size);
+          }
+          end_fmt = true;
+          break;
+
+        case 'n': {
+          int *ptr = va_arg(ap, int *);
+          *ptr = (int)formatter.str_off;
+          va_labels++;
+          dfsan_set_label(0, ptr, sizeof(*ptr));
+          end_fmt = true;
+          break;
+        }
+
+        case '%':
+          retval = formatter.scan();
+          end_fmt = true;
+          break;
+
+        case '*':
+          formatter.width = va_arg(ap, int);
+          va_labels++;
+          break;
+
+        default:
+          break;
+        }
+      }
+    }
+
+    if (retval < 0) {
+      return retval;
+    }
+
+    formatter.fmt_cur++;
+    formatter.str_off += retval;
+  }
+
+  *ret_label = 0;
+  if (ret_origin)
+    *ret_origin = 0;
+
+  // Number of items scanned in total.
+  return formatter.num_scanned;
+}
+
 extern "C" {
 SANITIZER_INTERFACE_ATTRIBUTE
 int __dfsw_sprintf(char *str, const char *format, dfsan_label str_label,
@@ -2502,6 +2769,7 @@ int __dfsw_sprintf(char *str, const char *format, dfsan_label str_label,
                    dfsan_label *ret_label, ...) {
   va_list ap;
   va_start(ap, ret_label);
+
   int ret = format_buffer(str, ~0ul, format, va_labels, ret_label, nullptr,
                           nullptr, ap);
   va_end(ap);
@@ -2550,6 +2818,58 @@ int __dfso_snprintf(char *str, size_t size, const char *format,
   return ret;
 }
 
+SANITIZER_INTERFACE_ATTRIBUTE
+int __dfsw_sscanf(char *str, const char *format, dfsan_label str_label,
+                  dfsan_label format_label, dfsan_label *va_labels,
+                  dfsan_label *ret_label, ...) {
+  va_list ap;
+  va_start(ap, ret_label);
+  int ret = scan_buffer(str, ~0ul, format, va_labels, ret_label, nullptr,
+                        nullptr, ap);
+  va_end(ap);
+  return ret;
+}
+
+SANITIZER_INTERFACE_ATTRIBUTE
+int __dfso_sscanf(char *str, const char *format, dfsan_label str_label,
+                  dfsan_label format_label, dfsan_label *va_labels,
+                  dfsan_label *ret_label, dfsan_origin str_origin,
+                  dfsan_origin format_origin, dfsan_origin *va_origins,
+                  dfsan_origin *ret_origin, ...) {
+  va_list ap;
+  va_start(ap, ret_origin);
+  int ret = scan_buffer(str, ~0ul, format, va_labels, ret_label, &str_origin,
+                        ret_origin, ap);
+  va_end(ap);
+  return ret;
+}
+
+SANITIZER_INTERFACE_ATTRIBUTE
+int __dfsw___isoc99_sscanf(char *str, const char *format, dfsan_label str_label,
+                           dfsan_label format_label, dfsan_label *va_labels,
+                           dfsan_label *ret_label, ...) {
+  va_list ap;
+  va_start(ap, ret_label);
+  int ret = scan_buffer(str, ~0ul, format, va_labels, ret_label, nullptr,
+                        nullptr, ap);
+  va_end(ap);
+  return ret;
+}
+
+SANITIZER_INTERFACE_ATTRIBUTE
+int __dfso___isoc99_sscanf(char *str, const char *format, dfsan_label str_label,
+                           dfsan_label format_label, dfsan_label *va_labels,
+                           dfsan_label *ret_label, dfsan_origin str_origin,
+                           dfsan_origin format_origin, dfsan_origin *va_origins,
+                           dfsan_origin *ret_origin, ...) {
+  va_list ap;
+  va_start(ap, ret_origin);
+  int ret = scan_buffer(str, ~0ul, format, va_labels, ret_label, &str_origin,
+                        ret_origin, ap);
+  va_end(ap);
+  return ret;
+}
+
 static void BeforeFork() {
   StackDepotLockAll();
   GetChainedOriginDepot()->LockAll();

diff --git a/compiler-rt/lib/dfsan/done_abilist.txt b/compiler-rt/lib/dfsan/done_abilist.txt
index 84d1b518840134..c582584d77e45f 100644
--- a/compiler-rt/lib/dfsan/done_abilist.txt
+++ b/compiler-rt/lib/dfsan/done_abilist.txt
@@ -308,6 +308,10 @@ fun:gettimeofday=custom
 fun:sprintf=custom
 fun:snprintf=custom
 
+# scanf-like
+fun:sscanf=custom
+fun:__isoc99_sscanf=custom
+
 # TODO: custom
 fun:asprintf=discard
 fun:qsort=discard

diff --git a/compiler-rt/test/dfsan/custom.cpp b/compiler-rt/test/dfsan/custom.cpp
index c67602d4538e63..dfc24ee3019efb 100644
--- a/compiler-rt/test/dfsan/custom.cpp
+++ b/compiler-rt/test/dfsan/custom.cpp
@@ -2095,6 +2095,154 @@ void test_snprintf() {
   ASSERT_LABEL(r, 0);
 }
 
+template <class T>
+void test_sscanf_chunk(T expected, const char *format, char *input,
+                       int items_num) {
+  char padded_input[512];
+  strcpy(padded_input, "foo ");
+  strcat(padded_input, input);
+  strcat(padded_input, " bar");
+
+  char padded_format[512];
+  strcpy(padded_format, "foo ");
+  strcat(padded_format, format);
+  strcat(padded_format, " bar");
+
+  char *s = padded_input + 4;
+  T arg;
+  memset(&arg, 0, sizeof(arg));
+  dfsan_set_label(i_label, (void *)(s), strlen(input));
+  dfsan_set_label(j_label, (void *)(padded_format + 4), strlen(format));
+  dfsan_origin a_o = dfsan_get_origin((long)(*s));
+#ifndef ORIGIN_TRACKING
+  (void)a_o;
+#else
+  assert(a_o != 0);
+#endif
+  int rv = sscanf(padded_input, padded_format, &arg);
+  assert(rv == items_num);
+  assert(arg == expected);
+  ASSERT_READ_LABEL(&arg, sizeof(arg), i_label);
+  ASSERT_INIT_ORIGINS(&arg, 1, a_o);
+}
+
+void test_sscanf() {
+  char buf[2048];
+  char buf_out[2048];
+  memset(buf, 'a', sizeof(buf));
+  memset(buf_out, 'a', sizeof(buf_out));
+
+  // Test formatting
+  strcpy(buf, "Hello world!");
+  assert(sscanf(buf, "%s", buf_out) == 1);
+  assert(strcmp(buf, "Hello world!") == 0);
+  assert(strcmp(buf_out, "Hello") == 0);
+  ASSERT_READ_LABEL(buf, sizeof(buf), 0);
+  ASSERT_READ_LABEL(buf_out, sizeof(buf_out), 0);
+
+  // Test for extra arguments.
+  assert(sscanf(buf, "%s", buf_out, 42, "hello") == 1);
+  assert(strcmp(buf, "Hello world!") == 0);
+  assert(strcmp(buf_out, "Hello") == 0);
+  ASSERT_READ_LABEL(buf, sizeof(buf), 0);
+  ASSERT_READ_LABEL(buf_out, sizeof(buf_out), 0);
+
+  // Test formatting & label propagation (multiple conversion specifiers): %s,
+  // %d, %n, %f, and %%.
+  int n;
+  strcpy(buf, "hello world, 2014/8/27 12345.678123 % 1000");
+  char *s = buf + 6; //starts with world
+  int y = 0;
+  int m = 0;
+  int d = 0;
+  float fval;
+  int val = 0;
+  dfsan_set_label(k_label, (void *)(s + 1), 2); // buf[7]-b[9]
+  dfsan_origin s_o = dfsan_get_origin((long)(s[1]));
+  dfsan_set_label(i_label, (void *)(s + 12), 1);
+  dfsan_origin m_o = dfsan_get_origin((long)s[12]); // buf[18]
+  dfsan_set_label(j_label, (void *)(s + 14), 2);    // buf[20]
+  dfsan_origin d_o = dfsan_get_origin((long)s[14]);
+  dfsan_set_label(m_label, (void *)(s + 18), 4); //buf[24]
+  dfsan_origin f_o = dfsan_get_origin((long)s[18]);
+
+#ifndef ORIGIN_TRACKING
+  (void)s_o;
+  (void)m_o;
+  (void)d_o;
+  (void)f_o;
+#else
+  assert(s_o != 0);
+  assert(m_o != 0);
+  assert(d_o != 0);
+  assert(f_o != 0);
+#endif
+  int r = sscanf(buf, "hello %s %d/%d/%d %f %% %n%d", buf_out, &y, &m, &d,
+                 &fval, &n, &val);
+  assert(r == 6);
+  assert(strcmp(buf_out, "world,") == 0);
+  ASSERT_READ_LABEL(buf_out, 1, 0);
+  ASSERT_READ_LABEL(buf_out + 1, 2, k_label);
+  ASSERT_INIT_ORIGINS(buf_out + 1, 2, s_o);
+  ASSERT_READ_LABEL(buf + 9, 9, 0);
+  ASSERT_READ_LABEL(&m, 1, i_label);
+  ASSERT_INIT_ORIGINS(&m, 1, m_o);
+  ASSERT_READ_LABEL(&d, 4, j_label);
+  ASSERT_INIT_ORIGINS(&d, 2, d_o);
+  ASSERT_READ_LABEL(&fval, sizeof(fval), m_label);
+  ASSERT_INIT_ORIGINS(&fval, sizeof(fval), f_o);
+  ASSERT_READ_LABEL(&val, 4, 0);
+  ASSERT_LABEL(r, 0);
+  assert(n == 38);
+  assert(val == 1000);
+
+  // Test formatting & label propagation (single conversion specifier, with
+  // additional length and precision modifiers).
+  char input_buf[512];
+  char *input_ptr = input_buf;
+  strcpy(input_buf, "-559038737");
+  test_sscanf_chunk(-559038737, "%d", input_ptr, 1);
+  strcpy(input_buf, "3735928559");
+  test_sscanf_chunk(3735928559, "%u", input_ptr, 1);
+  strcpy(input_buf, "12345");
+  test_sscanf_chunk(12345, "%i", input_ptr, 1);
+  strcpy(input_buf, "0751");
+  test_sscanf_chunk(489, "%o", input_ptr, 1);
+  strcpy(input_buf, "0xbabe");
+  test_sscanf_chunk(47806, "%x", input_ptr, 1);
+  strcpy(input_buf, "0x0000BABE");
+  test_sscanf_chunk(47806, "%10X", input_ptr, 1);
+  strcpy(input_buf, "3735928559");
+  test_sscanf_chunk((char)-17, "%hhd", input_ptr, 1);
+  strcpy(input_buf, "3735928559");
+  test_sscanf_chunk((short)-16657, "%hd", input_ptr, 1);
+  strcpy(input_buf, "0xdeadbeefdeadbeef");
+  test_sscanf_chunk(0xdeadbeefdeadbeefL, "%lx", input_buf, 1);
+  test_sscanf_chunk((void *)0xdeadbeefdeadbeefL, "%p", input_buf, 1);
+
+  intmax_t _x = (intmax_t)-1;
+  char _buf[256];
+  memset(_buf, 0, sizeof(_buf));
+  sprintf(_buf, "%ju", _x);
+  test_sscanf_chunk((intmax_t)18446744073709551615, "%ju", _buf, 1);
+  memset(_buf, 0, sizeof(_buf));
+  size_t _y = (size_t)-1;
+  sprintf(_buf, "%zu", _y);
+  test_sscanf_chunk((size_t)18446744073709551615, "%zu", _buf, 1);
+  memset(_buf, 0, sizeof(_buf));
+  ptrdiff_t _z = (size_t)-1;
+  sprintf(_buf, "%tu", _z);
+  test_sscanf_chunk((ptrdiff_t)18446744073709551615, "%tu", _buf, 1);
+
+  strcpy(input_buf, "0.123456");
+  test_sscanf_chunk((float)0.123456, "%8f", input_ptr, 1);
+  test_sscanf_chunk((float)0.123456, "%g", input_ptr, 1);
+  test_sscanf_chunk((float)1.234560e-01, "%e", input_ptr, 1);
+  test_sscanf_chunk((char)'z', "%c", "z", 1);
+
+  // %n, %s, %d, %f, and %% already tested
+}
+
 // Tested by a seperate source file.  This empty function is here to appease the
 // check-wrappers script.
 void test_fork() {}
@@ -2154,6 +2302,7 @@ int main(void) {
   test_sigaltstack();
   test_sigemptyset();
   test_snprintf();
+  test_sscanf();
   test_socketpair();
   test_sprintf();
   test_stat();


        


More information about the llvm-commits mailing list