diff --git a/dnsproxy.go b/dnsproxy.go old mode 100644 new mode 100755 diff --git a/dnsproxy1.go b/dnsproxy1.go new file mode 100755 index 0000000..f827c74 --- /dev/null +++ b/dnsproxy1.go @@ -0,0 +1,375 @@ +package main + +import ( + "bufio" + "crypto/md5" + "encoding/hex" + "flag" + "fmt" + "github.com/miekg/dns" + "github.com/pmylund/go-cache" + "io" + "log" + "net" + "os" + "os/signal" + "path" + "path/filepath" + "runtime" + "strings" + "syscall" + "time" +) + +var ( + dnss = flag.String("dns", "192.168.2.1:53:udp,8.8.8.8:53:udp,8.8.4.4:53:udp,8.8.8.8:53:tcp,8.8.4.4:53:tcp", "dns address, use `,` as sep") + local = flag.String("local", ":53", "local listen address") + debug = flag.Int("debug", 0, "debug level 0 1 2") + encache = flag.Bool("cache", true, "enable go-cache") + expire = flag.Int64("expire", 3600, "default cache expire seconds, -1 means use doamin ttl time") + file = flag.String("file", filepath.Join(path.Dir(os.Args[0]), "cache.dat"), "cached file") + cfg = flag.String("cfg", filepath.Join(path.Dir(os.Args[0]), "hosts.cfg"), "local host file") + ipv6 = flag.Bool("6", false, "skip ipv6 record query AAAA") + timeout = flag.Int("timeout", 200, "read/write timeout") + + clientTCP *dns.Client + clientUDP *dns.Client + + DEBUG int + ENCACHE bool + + localMap map[string]string + + DNS [][]string + + conn *cache.Cache + + saveSig = make(chan os.Signal) +) + +func toMd5(data string) string { + m := md5.New() + m.Write([]byte(data)) + return hex.EncodeToString(m.Sum(nil)) +} + +func intervalSaveCache() { + save := func() { + err := conn.SaveFile(*file) + if err == nil { + log.Printf("cache save: %s\n", *file) + } else { + log.Printf("cache save failed: %s, %s\n", *file, err) + } + } + + go func() { + for { + select { + case sig := <-saveSig: + save() + switch sig { + case syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT: + os.Exit(0) + case syscall.SIGHUP: + log.Println("recv SIGHUP clear cache") + conn.Flush() + } + case <-time.After(time.Second * 10): + save() + loadCfg() + } + } + }() +} + + +func loadCfg() { + MyMap := make(map[string]string) + + f, err := os.Open(*cfg) + if err != nil { + log.Printf("cfg %s load failure!", *cfg) + return + } + defer f.Close() + + r := bufio.NewReader(f) + for { + b, _, err := r.ReadLine() + if err != nil { + if err == io.EOF { + break + } + log.Printf("cfg read failure! %v", err) + break + } + + s := strings.Split(strings.TrimSpace(string(b)), " ") + n := len(s) + if n < 2 { + log.Printf("cfg error line %v", b) + continue + } + + my_ip := s[0] + + for i:=1; i < n; i++ { + tmp_row := strings.TrimSpace(s[i]) + if len(tmp_row) == 0 { + continue + } + MyMap[tmp_row] = my_ip + } + + + log.Println(MyMap) + } + + localMap = MyMap +} + +func init() { + flag.Parse() + + ENCACHE = *encache + DEBUG = *debug + + runtime.GOMAXPROCS(runtime.NumCPU()*2 - 1) + + clientTCP = new(dns.Client) + clientTCP.Net = "tcp" + clientTCP.ReadTimeout = time.Duration(*timeout) * time.Millisecond + clientTCP.WriteTimeout = time.Duration(*timeout) * time.Millisecond + + clientUDP = new(dns.Client) + clientUDP.Net = "udp" + clientUDP.ReadTimeout = time.Duration(*timeout) * time.Millisecond + clientUDP.WriteTimeout = time.Duration(*timeout) * time.Millisecond + + loadCfg() + + if ENCACHE { + conn = cache.New(time.Second*time.Duration(*expire), time.Second*60) + conn.LoadFile(*file) + intervalSaveCache() + } + + for _, s := range strings.Split(*dnss, ",") { + s = strings.TrimSpace(s) + if s == "" { + continue + } + + dns := s + proto := "udp" + parts := strings.Split(s, ":") + + if len(parts) > 2 { + dns = strings.Join(parts[:2], ":") + if parts[2] == "tcp" { + proto = "tcp" + } + } + + _, err := net.ResolveTCPAddr("tcp", dns) + if err != nil { + log.Fatalf("wrong dns address %s\n", dns) + } + + DNS = append(DNS, []string{dns, proto}) + } + + if len(DNS) == 0 { + log.Fatalln("dns address must be not empty") + } + + signal.Notify(saveSig, syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT) +} + +func main() { + dns.HandleFunc(".", proxyServe) + + failure := make(chan error, 1) + + go func(failure chan error) { + failure <- dns.ListenAndServe(*local, "tcp", nil) + }(failure) + + go func(failure chan error) { + failure <- dns.ListenAndServe(*local, "udp", nil) + }(failure) + + log.Printf("ready for accept connection on tcp/upd %s ...\n", *local) + fmt.Println(<-failure) +} + +func proxyServe(w dns.ResponseWriter, req *dns.Msg) { + var ( + key string + m *dns.Msg + err error + tried bool + data []byte + id uint16 + query []string + questions []dns.Question + used string + ) + + defer func() { + if err := recover(); err != nil { + fmt.Println(err) + } + }() + + if req.MsgHdr.Response == true { + return + } + + query = make([]string, len(req.Question)) + + for i, q := range req.Question { + if q.Qtype != dns.TypeAAAA || *ipv6 { + questions = append(questions, q) + } + query[i] = fmt.Sprintf("(%s %s %s)", q.Name, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype]) + } + + if len(questions) == 0 { + return + } + + //check local map + dom := req.Question[0].Name + domain := dom[:-1] + v, ok := localMap[domain] + if ok { + tm := new(dns.Msg) + tm.Id = id + tm.Answer = make([]dns.RR, 1) + dom := req.Question[0].Name + trr := new(dns.A) + trr.Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5} + trr.A = net.IPv4(127, 0, 0, 1) + tm.Answer[0] = trr + + err = w.WriteMsg(tm) + goto end + } + + req.Question = questions + + id = req.Id + + req.Id = 0 + key = toMd5(req.String()) + req.Id = id + + if ENCACHE { + if reply, ok := conn.Get(key); ok { + data, _ = reply.([]byte) + } + + if data != nil && len(data) > 0 { + m = &dns.Msg{} + m.Unpack(data) + m.Id = id + err = w.WriteMsg(m) + + if DEBUG > 0 { + log.Printf("id:%5d cache: HIT %v\n", id, query) + } + + goto end + } else { + if DEBUG > 0 { + log.Printf("id: %5d cache: MISS %v\n", id, query) + } + } + } + + for i, parts := range DNS { + dns := parts[0] + proto := parts[1] + + tried = i > 0 + + if DEBUG > 0 { + if tried { + log.Printf("id: 5%d try: %v %s %s\n", id, query, dns, proto) + } else { + log.Printf("id: 5%d resolve: %v %s %s\n", id, query, dns, proto) + } + } + + client := clientUDP + if proto == "tcp" { + client = clientTCP + } + + m, _, err = client.Exchange(req, dns) + + if err == nil && len(m.Answer) > 0 { + used = dns + break + } + } + + if err == nil { + if DEBUG > 0 { + if tried { + if len(m.Answer) == 0 { + log.Printf("id: %5d failed: %v\n", id, query) + } else { + log.Printf("id: %5d bingo: %v %s\n", id, query, used) + } + } + } + + data, err = m.Pack() + + if err == nil { + _, err = w.Write(data) + + if err == nil { + if ENCACHE { + m.Id = 0 + data, _ = m.Pack() + ttl := 0 + + if len(m.Answer) > 0 { + ttl = int(m.Answer[0].Header().Ttl) + + if ttl < 0 { + ttl = 0 + } + } + conn.Set(key, data, time.Second*time.Duration(ttl)) + + m.Id = id + if DEBUG > 0 { + log.Printf("id: %5d cache: CACHED %v TTL %v\n", id, query, ttl) + } + } + } + } + } + +end: + if DEBUG > 1 { + fmt.Println(req) + + if m != nil { + fmt.Println(m) + } + } + + if err != nil { + log.Printf("id: %5d error: %v %s\n", id, query, err) + } + + if DEBUG > 1 { + fmt.Println("======================================") + } +} diff --git a/hosts.cfg b/hosts.cfg new file mode 100755 index 0000000..5f4c50e --- /dev/null +++ b/hosts.cfg @@ -0,0 +1,2 @@ +127.0.0.1 www.baidu.com sina.com.cn +8.8.8.8 google.com \ No newline at end of file