diff --git a/src/dns_conf.c b/src/dns_conf.c index 8494caf023..2096c3698f 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -763,7 +763,7 @@ static int _config_domain_rule_each_from_list(const char *file, domain_set_rule_ line_no = 0; while (fgets(line, MAX_LINE_LEN, fp)) { line_no++; - filed_num = sscanf(line, "%256s", domain); + filed_num = sscanf(line, "%255s", domain); if (filed_num <= 0) { continue; } @@ -3198,7 +3198,7 @@ static int _conf_dhcp_lease_dnsmasq_add(const char *file) line_no = 0; while (fgets(line, MAX_LINE_LEN, fp)) { line_no++; - filed_num = sscanf(line, "%*s %*s %64s %256s %*s", ip, hostname); + filed_num = sscanf(line, "%*s %*s %63s %255s %*s", ip, hostname); if (filed_num <= 0) { continue; } diff --git a/src/lib/conf.c b/src/lib/conf.c index 73979b6c5d..eb3f6c5a25 100644 --- a/src/lib/conf.c +++ b/src/lib/conf.c @@ -377,7 +377,7 @@ static int load_conf_file(const char *file, struct config_item *items, conf_erro } line_len = 0; - filed_num = sscanf(line, "%63s %8192[^\r\n]s", key, value); + filed_num = sscanf(line, "%63s %8191[^\r\n]s", key, value); if (filed_num <= 0) { continue; } diff --git a/src/smartdns.c b/src/smartdns.c index c3dc58f1ce..983d1c4836 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -459,6 +459,7 @@ static int _smartdns_init(void) int ret = 0; const char *logfile = _smartdns_log_path(); int i = 0; + char logdir[PATH_MAX] = {0}; ret = tlog_init(logfile, dns_conf_log_size, dns_conf_log_num, 0, 0); if (ret != 0) { @@ -466,7 +467,8 @@ static int _smartdns_init(void) goto errout; } - if (verbose_screen != 0 || dns_conf_log_console != 0) { + safe_strncpy(logdir, _smartdns_log_path(), PATH_MAX); + if (verbose_screen != 0 || dns_conf_log_console != 0 || access(dir_name(logdir), W_OK) != 0) { tlog_setlogscreen(1); } @@ -736,9 +738,11 @@ static void smartdns_test_notify_func(int fd_notify, uint64_t retval) } } +#define smartdns_close_allfds() close_all_fd(fd_notify); int smartdns_main(int argc, char *argv[], int fd_notify) #else #define smartdns_test_notify(retval) +#define smartdns_close_allfds() close_all_fd(-1); int main(int argc, char *argv[]) #endif { @@ -750,6 +754,7 @@ int main(int argc, char *argv[]) int signal_ignore = 0; sigset_t empty_sigblock; struct stat sb; + int daemon_ret = 0; safe_strncpy(config_file, SMARTDNS_CONF_FILE, MAX_LINE_LEN); @@ -762,6 +767,7 @@ int main(int argc, char *argv[]) /* patch for Asus router: unblock all signal*/ sigemptyset(&empty_sigblock); sigprocmask(SIG_SETMASK, &empty_sigblock, NULL); + smartdns_close_allfds(); while ((opt = getopt(argc, argv, "fhc:p:SvxN:")) != -1) { switch (opt) { @@ -769,10 +775,14 @@ int main(int argc, char *argv[]) is_foreground = 1; break; case 'c': - snprintf(config_file, sizeof(config_file), "%s", optarg); + if (full_path(config_file, sizeof(config_file), optarg) != 0) { + snprintf(config_file, sizeof(config_file), "%s", optarg); + } break; case 'p': - snprintf(pid_file, sizeof(pid_file), "%s", optarg); + if (strncmp(optarg, "-", 2) == 0 || full_path(pid_file, sizeof(pid_file), optarg) != 0) { + snprintf(pid_file, sizeof(pid_file), "%s", optarg); + } break; case 'S': signal_ignore = 1; @@ -794,16 +804,27 @@ int main(int argc, char *argv[]) } } - if (dns_server_load_conf(config_file) != 0) { + ret = dns_server_load_conf(config_file); + if (ret != 0) { fprintf(stderr, "load config failed.\n"); goto errout; } if (is_foreground == 0) { - if (daemon(0, 0) < 0) { - fprintf(stderr, "run daemon process failed, %s\n", strerror(errno)); + daemon_ret = run_daemon(); + if (daemon_ret < 0) { + char buff[4096]; + char *log_path = realpath(_smartdns_log_path(), buff); + + if (log_path != NULL && access(log_path, F_OK) == 0 && daemon_ret != -2) { + fprintf(stderr, "run daemon failed, please check log at %s\n", log_path); + } return 1; } + + if (daemon_ret == 0) { + return 0; + } } if (signal_ignore == 0) { @@ -811,6 +832,7 @@ int main(int argc, char *argv[]) } if (strncmp(pid_file, "-", 2) != 0 && create_pid_file(pid_file) != 0) { + ret = -2; goto errout; } @@ -818,9 +840,10 @@ int main(int argc, char *argv[]) signal(SIGINT, _sig_exit); signal(SIGTERM, _sig_exit); - if (_smartdns_init_pre() != 0) { + ret = _smartdns_init_pre(); + if (ret != 0) { fprintf(stderr, "init failed.\n"); - return 1; + goto errout; } drop_root_privilege(); @@ -831,11 +854,21 @@ int main(int argc, char *argv[]) goto errout; } + if (daemon_ret > 0) { + ret = daemon_kickoff(daemon_ret, 0); + if (ret != 0) { + goto errout; + } + } + smartdns_test_notify(1); ret = _smartdns_run(); _smartdns_exit(); return ret; errout: + if (daemon_ret > 0) { + daemon_kickoff(daemon_ret, ret); + } smartdns_test_notify(2); return 1; } diff --git a/src/util.c b/src/util.c index ee4f5e66cb..56c131e216 100644 --- a/src/util.c +++ b/src/util.c @@ -25,6 +25,7 @@ #include "util.h" #include #include +#include #include #include #include @@ -38,11 +39,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -806,7 +809,11 @@ int create_pid_file(const char *pid_file) } if (lockf(fd, F_TLOCK, 0) < 0) { - fprintf(stderr, "Server is already running.\n"); + memset(buff, 0, TMP_BUFF_LEN_32); + if (read(fd, buff, TMP_BUFF_LEN_32) <= 0) { + buff[0] = '\0'; + } + fprintf(stderr, "Server is already running, pid is %s", buff); goto errout; } @@ -831,6 +838,27 @@ int create_pid_file(const char *pid_file) return -1; } +int full_path(char *normalized_path, int normalized_path_len, const char *path) +{ + const char *p = path; + + if (path == NULL || normalized_path == NULL) { + return -1; + } + + while (*p == ' ') { + p++; + } + + if (*p == '\0' || *p == '/') { + return -1; + } + + char buf[PATH_MAX]; + snprintf(normalized_path, normalized_path_len, "%s/%s", getcwd(buf, sizeof(buf)), path); + return 0; +} + int generate_cert_key(const char *key_path, const char *cert_path, const char *san, int days) { int ret = -1; @@ -1479,6 +1507,156 @@ int dns_packet_save(const char *dir, const char *type, const char *from, const v return ret; } +static void _close_all_fd_by_res(void) +{ + struct rlimit lim; + int maxfd = 0; + int i = 0; + + getrlimit(RLIMIT_NOFILE, &lim); + + maxfd = lim.rlim_cur; + if (maxfd > 4096) { + maxfd = 4096; + } + + for (i = 3; i < maxfd; i++) { + close(i); + } +} + +void close_all_fd(int keepfd) +{ + DIR *dirp; + int dir_fd = -1; + struct dirent *dentp; + + dirp = opendir("/proc/self/fd"); + if (dirp == NULL) { + goto errout; + } + + dir_fd = dirfd(dirp); + + while ((dentp = readdir(dirp)) != NULL) { + int fd = atol(dentp->d_name); + if (fd < 0) { + continue; + } + + if (fd == dir_fd || fd == STDIN_FILENO || fd == STDOUT_FILENO || fd == STDERR_FILENO || fd == keepfd) { + continue; + } + close(fd); + } + + closedir(dirp); + return; +errout: + if (dirp) { + closedir(dirp); + } + _close_all_fd_by_res(); + return; +} + +int daemon_kickoff(int fd, int status) +{ + if (fd <= 0) { + return -1; + } + + int ret = write(fd, &status, sizeof(status)); + if (ret != sizeof(status)) { + return -1; + } + + int fd_null = open("/dev/null", O_RDWR); + if (fd_null < 0) { + fprintf(stderr, "open /dev/null failed, %s\n", strerror(errno)); + return -1; + } + + dup2(fd_null, STDIN_FILENO); + dup2(fd_null, STDOUT_FILENO); + dup2(fd_null, STDERR_FILENO); + + if (fd_null > 2) { + close(fd_null); + } + + close(fd); + + return 0; +} + +int run_daemon() +{ + pid_t pid = 0; + int fds[2] = {0}; + + if (pipe(fds) != 0) { + fprintf(stderr, "run daemon process failed, pipe failed, %s\n", strerror(errno)); + return -1; + } + + pid = fork(); + if (pid < 0) { + fprintf(stderr, "run daemon process failed, fork failed, %s\n", strerror(errno)); + close(fds[0]); + close(fds[1]); + return -1; + } else if (pid > 0) { + struct pollfd pfd; + int ret = 0; + int status = 0; + + close(fds[1]); + + pfd.fd = fds[0]; + pfd.events = POLLIN; + pfd.revents = 0; + + ret = poll(&pfd, 1, 1000); + if (ret <= 0) { + fprintf(stderr, "run daemon process failed, wait child timeout\n"); + goto errout; + } + + if (!(pfd.revents & POLLIN)) { + goto errout; + } + + ret = read(fds[0], &status, sizeof(status)); + if (ret != sizeof(status)) { + goto errout; + } + + return status; + } + + setsid(); + + pid = fork(); + if (pid < 0) { + fprintf(stderr, "double fork failed, %s\n", strerror(errno)); + _exit(1); + } else if (pid > 0) { + _exit(0); + } + + umask(0); + if (chdir("/") != 0) { + goto errout; + } + close(fds[0]); + return fds[1]; + +errout: + kill(pid, SIGKILL); + return -1; +} + #ifdef DEBUG struct _dns_read_packet_info { int data_len; @@ -1604,7 +1782,7 @@ static int _dns_debug_display(struct dns_packet *packet) int ret = 0; ret = dns_get_HTTPS_svcparm_start(rrs, &p, name, DNS_MAX_CNAME_LEN, &ttl, &priority, target, - DNS_MAX_CNAME_LEN); + DNS_MAX_CNAME_LEN); if (ret != 0) { printf("get HTTPS svcparm failed\n"); break; diff --git a/src/util.h b/src/util.h index b7d952296c..fa5f4d338a 100644 --- a/src/util.h +++ b/src/util.h @@ -105,6 +105,8 @@ int generate_cert_key(const char *key_path, const char *cert_path, const char *s int create_pid_file(const char *pid_file); +int full_path(char *normalized_path, int normalized_path_len, const char *path); + /* Parse a TLS packet for the Server Name Indication extension in the client * hello handshake, returning the first server name found (pointer to static * array) @@ -138,6 +140,12 @@ uint64_t get_free_space(const char *path); void print_stack(void); +void close_all_fd(int keepfd); + +int run_daemon(void); + +int daemon_kickoff(int fd, int status); + int write_file(const char *filename, void *data, int data_len); int dns_packet_save(const char *dir, const char *type, const char *from, const void *packet, int packet_len);