From 7ff6ae3ea074fd5c0765fcb4601f1db2bc131ec5 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Tue, 11 Jul 2023 23:30:39 +0800 Subject: [PATCH] dns_server: fix edns subnet not working issue. --- src/dns.c | 16 +++++ src/dns.h | 3 + src/dns_client.c | 33 ++++----- test/cases/test-subnet.cc | 142 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 175 insertions(+), 19 deletions(-) diff --git a/src/dns.c b/src/dns.c index 7e15659f4a..f70edc09c7 100644 --- a/src/dns.c +++ b/src/dns.c @@ -875,6 +875,22 @@ int dns_get_PTR(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char * return _dns_get_RAW(rrs, domain, maxsize, ttl, cname, &len); } +int dns_add_TXT(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *text) +{ + int rr_len = strnlen(text, DNS_MAX_CNAME_LEN); + char data[DNS_MAX_CNAME_LEN]; + data[0] = rr_len; + rr_len++; + memcpy(data + 1, text, rr_len); + data[rr_len] = 0; + return _dns_add_RAW(packet, type, DNS_T_TXT, domain, ttl, data, rr_len); +} + +int dns_get_TXT(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *text, int txt_size) +{ + return -1; +} + int dns_add_NS(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *cname) { int rr_len = strnlen(cname, DNS_MAX_CNAME_LEN) + 1; diff --git a/src/dns.h b/src/dns.h index e44b0648ff..0ed573a5d3 100644 --- a/src/dns.h +++ b/src/dns.h @@ -262,6 +262,9 @@ int dns_get_A(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned int dns_add_PTR(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *cname); int dns_get_PTR(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *cname, int cname_size); +int dns_add_TXT(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, const char *text); +int dns_get_TXT(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, char *text, int txt_size); + int dns_add_AAAA(struct dns_packet *packet, dns_rr_type type, const char *domain, int ttl, unsigned char addr[DNS_RR_AAAA_LEN]); int dns_get_AAAA(struct dns_rrs *rrs, char *domain, int maxsize, int *ttl, unsigned char addr[DNS_RR_AAAA_LEN]); diff --git a/src/dns_client.c b/src/dns_client.c index 894175f9f3..ed45031bbe 100644 --- a/src/dns_client.c +++ b/src/dns_client.c @@ -3378,16 +3378,7 @@ static int _dns_client_setup_server_packet(struct dns_server_info *server_info, *packet_data = default_packet; *packet_data_len = default_packet_len; - if (query->qtype != DNS_T_AAAA && query->qtype != DNS_T_A) { - /* no need to encode packet */ - return 0; - } - - if (server_info->ecs_ipv4.enable == true && query->qtype == DNS_T_A) { - repack = 1; - } - - if (server_info->ecs_ipv6.enable == true && query->qtype == DNS_T_AAAA) { + if (server_info->ecs_ipv4.enable == true || server_info->ecs_ipv6.enable == true) { repack = 1; } @@ -3429,12 +3420,16 @@ static int _dns_client_setup_server_packet(struct dns_server_info *server_info, dns_set_OPT_payload_size(packet, DNS_IN_PACKSIZE); /* dns_add_OPT_TCP_KEEPALIVE(packet, 600); */ - if ((query->qtype == DNS_T_A && server_info->ecs_ipv4.enable) || - (query->qtype == DNS_T_AAAA && server_info->ecs_ipv6.enable == 0 && server_info->ecs_ipv4.enable)) { + if ((query->qtype == DNS_T_A && server_info->ecs_ipv4.enable)) { dns_add_OPT_ECS(packet, &server_info->ecs_ipv4.ecs); - } else if ((query->qtype == DNS_T_AAAA && server_info->ecs_ipv6.enable) || - (query->qtype == DNS_T_A && server_info->ecs_ipv4.enable == 0 && server_info->ecs_ipv6.enable)) { + } else if ((query->qtype == DNS_T_AAAA && server_info->ecs_ipv6.enable)) { dns_add_OPT_ECS(packet, &server_info->ecs_ipv6.ecs); + } else { + if (server_info->ecs_ipv6.enable) { + dns_add_OPT_ECS(packet, &server_info->ecs_ipv6.ecs); + } else if (server_info->ecs_ipv4.enable) { + dns_add_OPT_ECS(packet, &server_info->ecs_ipv4.ecs); + } } /* encode packet */ @@ -3671,17 +3666,17 @@ static int _dns_client_query_setup_default_ecs(struct dns_query_struct *query) if (client.ecs_ipv4.enable) { add_ipv4_ecs = 1; } else if (client.ecs_ipv6.enable) { - add_ipv4_ecs = 1; + add_ipv6_ecs = 1; } } - if (add_ipv4_ecs) { - memcpy(&query->ecs, &client.ecs_ipv4, sizeof(query->ecs)); + if (add_ipv6_ecs) { + memcpy(&query->ecs, &client.ecs_ipv6, sizeof(query->ecs)); return 0; } - if (add_ipv6_ecs) { - memcpy(&query->ecs, &client.ecs_ipv6, sizeof(query->ecs)); + if (add_ipv4_ecs) { + memcpy(&query->ecs, &client.ecs_ipv4, sizeof(query->ecs)); return 0; } diff --git a/test/cases/test-subnet.cc b/test/cases/test-subnet.cc index 77b5abd191..0ab9f572c8 100644 --- a/test/cases/test-subnet.cc +++ b/test/cases/test-subnet.cc @@ -245,6 +245,148 @@ cache-persist no)"""); EXPECT_EQ(client.GetAnswer()[0].GetData(), "2001:db8::1"); } +TEST_F(SubNet, v4_server_subnet_txt) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_TXT) { + return smartdns::SERVER_REQUEST_SOA; + } + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + + if (has_ecs == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.family != DNS_OPT_ECS_FAMILY_IPV4) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\x08\x08\x08\x00", 4) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 24) { + return smartdns::SERVER_REQUEST_ERROR; + } + + dns_add_TXT(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 6, "hello world"); + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 -subnet 8.8.8.8/24 +log-num 0 +log-console yes +dualstack-ip-selection no +log-level debug +rr-ttl-min 0 +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com TXT", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 6); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "TXT"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "\"hello world\""); +} + +TEST_F(SubNet, v6_default_subnet_txt) +{ + smartdns::MockServer server_upstream; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [&](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_TXT) { + return smartdns::SERVER_REQUEST_SOA; + } + struct dns_opt_ecs ecs; + struct dns_rrs *rrs = NULL; + int rr_count = 0; + int i = 0; + int ret = 0; + int has_ecs = 0; + + rr_count = 0; + rrs = dns_get_rrs_start(request->packet, DNS_RRS_OPT, &rr_count); + if (rr_count <= 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + for (i = 0; i < rr_count && rrs; i++, rrs = dns_get_rrs_next(request->packet, rrs)) { + memset(&ecs, 0, sizeof(ecs)); + ret = dns_get_OPT_ECS(rrs, NULL, NULL, &ecs); + if (ret != 0) { + continue; + } + has_ecs = 1; + break; + } + if (has_ecs == 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.family != DNS_OPT_ECS_FAMILY_IPV6) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (memcmp(ecs.addr, "\xff\xff\xff\xff\xff\xff\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00", 16) != 0) { + return smartdns::SERVER_REQUEST_ERROR; + } + + if (ecs.source_prefix != 64) { + return smartdns::SERVER_REQUEST_ERROR; + } + + dns_add_TXT(request->response_packet, DNS_RRS_AN, request->domain.c_str(), 6, "hello world"); + return smartdns::SERVER_REQUEST_OK; + }); + + server.Start(R"""(bind [::]:60053 +server 127.0.0.1:61053 +log-num 0 +log-console yes +dualstack-ip-selection no +rr-ttl-min 0 +edns-client-subnet ffff:ffff:ffff:ffff:ffff::/64 +log-level debug +cache-persist no)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("a.com TXT", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "a.com"); + EXPECT_EQ(client.GetAnswer()[0].GetTTL(), 6); + EXPECT_EQ(client.GetAnswer()[0].GetType(), "TXT"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "\"hello world\""); +} + TEST_F(SubNet, per_server) { smartdns::MockServer server_upstream1;