Skip to content

Commit

Permalink
⚙️ Simplify DNS forwarding, respect user-options
Browse files Browse the repository at this point in the history
The forward flag was ignored before. Simplifies overall DNS Querying
removing redundant client code. Also fix a bug while the message
response was parsed for forwarding requests
  • Loading branch information
mudler committed Mar 2, 2022
1 parent 2e8c6d5 commit 9ed2716
Showing 1 changed file with 33 additions and 36 deletions.
69 changes: 33 additions & 36 deletions pkg/services/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package services
import (
"context"
"fmt"
"net"
"regexp"
"time"

Expand Down Expand Up @@ -81,7 +80,8 @@ type dnsHandler struct {
cache *lru.Cache
}

func (d dnsHandler) parseQuery(m *dns.Msg) {
func (d dnsHandler) parseQuery(m *dns.Msg, forward bool) *dns.Msg {
response := m.Copy()
if len(m.Question) > 0 {
q := m.Question[0]
// Resolve the entry to an IP from the blockchain data
Expand All @@ -93,51 +93,60 @@ func (d dnsHandler) parseQuery(m *dns.Msg) {
if val, exists := res[dns.Type(q.Qtype)]; exists {
rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, dns.TypeToString[q.Qtype], val))
if err == nil {
m.Answer = append(m.Answer, rr)
return
response.Answer = append(m.Answer, rr)
return response
}
}
}
}
r, err := d.forwardQuery(m)
if err == nil {
m.Answer = r.Answer
if forward {
r, err := d.forwardQuery(m)
if err == nil {
response.Answer = r.Answer
}
}
}
return response
}

func (d dnsHandler) handleDNSRequest() func(w dns.ResponseWriter, r *dns.Msg) {
return func(w dns.ResponseWriter, r *dns.Msg) {
m := new(dns.Msg)
m.SetReply(r)
m.Compress = false

var resp *dns.Msg
switch r.Opcode {
case dns.OpcodeQuery:
d.parseQuery(m)
resp = d.parseQuery(r, d.forwarder)
}

w.WriteMsg(m)
w.WriteMsg(resp)
}
}

func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
mess := new(dns.Msg)
mess.Question = dnsMessage.Copy().Question
if len(mess.Question) > 0 {
if v, ok := d.cache.Get(mess.Question[0].String()); ok {
reqCopy := dnsMessage.Copy()
if len(reqCopy.Question) > 0 {
if v, ok := d.cache.Get(reqCopy.Question[0].String()); ok {
q := v.(*dns.Msg)
return q, nil
}
}

for _, server := range d.forward {
r, err := QueryDNS(d.ctx, mess, server)
r, err := QueryDNS(d.ctx, reqCopy, server)
if r != nil && len(r.Answer) == 0 && !r.MsgHdr.Truncated {
continue
}

if err != nil {
return nil, err
}
if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess {
d.cache.Add(mess.Question[0].String(), r)

if r.Rcode == dns.RcodeSuccess {
d.cache.Add(reqCopy.Question[0].String(), r)
}

if r == nil || r.Rcode == dns.RcodeNameError || r.Rcode == dns.RcodeSuccess || err == nil {
return r, err
}
}
Expand All @@ -147,22 +156,10 @@ func (d dnsHandler) forwardQuery(dnsMessage *dns.Msg) (*dns.Msg, error) {
// QueryDNS queries a dns server with a dns message and return the answer
// it is blocking.
func QueryDNS(ctx context.Context, msg *dns.Msg, dnsServer string) (*dns.Msg, error) {
c := new(dns.Conn)
cc, _ := (&net.Dialer{Timeout: 35 * time.Second}).DialContext(ctx, "udp", dnsServer)
c.Conn = cc
defer c.Close()

err := c.SetWriteDeadline(time.Now().Add(30 * time.Second))
if err != nil {
return nil, err
}
err = c.WriteMsg(msg)
if err != nil {
return nil, err
}
err = c.SetReadDeadline(time.Now().Add(30 * time.Second))
if err != nil {
return nil, err
}
return c.ReadMsg()
client := &dns.Client{
Net: "udp",
Timeout: 30 * time.Second,
SingleInflight: true}
r, _, err := client.Exchange(msg, dnsServer)
return r, err
}

0 comments on commit 9ed2716

Please sign in to comment.