Skip to content

Commit

Permalink
kkk
Browse files Browse the repository at this point in the history
  • Loading branch information
wjzhangq committed Feb 1, 2016
1 parent 5ae38ab commit 48a4a47
Show file tree
Hide file tree
Showing 3 changed files with 377 additions and 0 deletions.
Empty file modified dnsproxy.go
100644 → 100755
Empty file.
375 changes: 375 additions & 0 deletions dnsproxy1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,375 @@
package main

import (
"bufio"
"crypto/md5"
"encoding/hex"
"flag"
"fmt"
"github.com/miekg/dns"
"github.com/pmylund/go-cache"
"io"
"log"
"net"
"os"
"os/signal"
"path"
"path/filepath"
"runtime"
"strings"
"syscall"
"time"
)

var (
dnss = flag.String("dns", "192.168.2.1:53:udp,8.8.8.8:53:udp,8.8.4.4:53:udp,8.8.8.8:53:tcp,8.8.4.4:53:tcp", "dns address, use `,` as sep")
local = flag.String("local", ":53", "local listen address")
debug = flag.Int("debug", 0, "debug level 0 1 2")
encache = flag.Bool("cache", true, "enable go-cache")
expire = flag.Int64("expire", 3600, "default cache expire seconds, -1 means use doamin ttl time")
file = flag.String("file", filepath.Join(path.Dir(os.Args[0]), "cache.dat"), "cached file")
cfg = flag.String("cfg", filepath.Join(path.Dir(os.Args[0]), "hosts.cfg"), "local host file")
ipv6 = flag.Bool("6", false, "skip ipv6 record query AAAA")
timeout = flag.Int("timeout", 200, "read/write timeout")

clientTCP *dns.Client
clientUDP *dns.Client

DEBUG int
ENCACHE bool

localMap map[string]string

DNS [][]string

conn *cache.Cache

saveSig = make(chan os.Signal)
)

func toMd5(data string) string {
m := md5.New()
m.Write([]byte(data))
return hex.EncodeToString(m.Sum(nil))
}

func intervalSaveCache() {
save := func() {
err := conn.SaveFile(*file)
if err == nil {
log.Printf("cache save: %s\n", *file)
} else {
log.Printf("cache save failed: %s, %s\n", *file, err)
}
}

go func() {
for {
select {
case sig := <-saveSig:
save()
switch sig {
case syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT:
os.Exit(0)
case syscall.SIGHUP:
log.Println("recv SIGHUP clear cache")
conn.Flush()
}
case <-time.After(time.Second * 10):
save()
loadCfg()
}
}
}()
}


func loadCfg() {
MyMap := make(map[string]string)

f, err := os.Open(*cfg)
if err != nil {
log.Printf("cfg %s load failure!", *cfg)
return
}
defer f.Close()

r := bufio.NewReader(f)
for {
b, _, err := r.ReadLine()
if err != nil {
if err == io.EOF {
break
}
log.Printf("cfg read failure! %v", err)
break
}

s := strings.Split(strings.TrimSpace(string(b)), " ")
n := len(s)
if n < 2 {
log.Printf("cfg error line %v", b)
continue
}

my_ip := s[0]

for i:=1; i < n; i++ {
tmp_row := strings.TrimSpace(s[i])
if len(tmp_row) == 0 {
continue
}
MyMap[tmp_row] = my_ip
}


log.Println(MyMap)
}

localMap = MyMap
}

func init() {
flag.Parse()

ENCACHE = *encache
DEBUG = *debug

runtime.GOMAXPROCS(runtime.NumCPU()*2 - 1)

clientTCP = new(dns.Client)
clientTCP.Net = "tcp"
clientTCP.ReadTimeout = time.Duration(*timeout) * time.Millisecond
clientTCP.WriteTimeout = time.Duration(*timeout) * time.Millisecond

clientUDP = new(dns.Client)
clientUDP.Net = "udp"
clientUDP.ReadTimeout = time.Duration(*timeout) * time.Millisecond
clientUDP.WriteTimeout = time.Duration(*timeout) * time.Millisecond

loadCfg()

if ENCACHE {
conn = cache.New(time.Second*time.Duration(*expire), time.Second*60)
conn.LoadFile(*file)
intervalSaveCache()
}

for _, s := range strings.Split(*dnss, ",") {
s = strings.TrimSpace(s)
if s == "" {
continue
}

dns := s
proto := "udp"
parts := strings.Split(s, ":")

if len(parts) > 2 {
dns = strings.Join(parts[:2], ":")
if parts[2] == "tcp" {
proto = "tcp"
}
}

_, err := net.ResolveTCPAddr("tcp", dns)
if err != nil {
log.Fatalf("wrong dns address %s\n", dns)
}

DNS = append(DNS, []string{dns, proto})
}

if len(DNS) == 0 {
log.Fatalln("dns address must be not empty")
}

signal.Notify(saveSig, syscall.SIGINT, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGQUIT)
}

func main() {
dns.HandleFunc(".", proxyServe)

failure := make(chan error, 1)

go func(failure chan error) {
failure <- dns.ListenAndServe(*local, "tcp", nil)
}(failure)

go func(failure chan error) {
failure <- dns.ListenAndServe(*local, "udp", nil)
}(failure)

log.Printf("ready for accept connection on tcp/upd %s ...\n", *local)
fmt.Println(<-failure)
}

func proxyServe(w dns.ResponseWriter, req *dns.Msg) {
var (
key string
m *dns.Msg
err error
tried bool
data []byte
id uint16
query []string
questions []dns.Question
used string
)

defer func() {
if err := recover(); err != nil {
fmt.Println(err)
}
}()

if req.MsgHdr.Response == true {
return
}

query = make([]string, len(req.Question))

for i, q := range req.Question {
if q.Qtype != dns.TypeAAAA || *ipv6 {
questions = append(questions, q)
}
query[i] = fmt.Sprintf("(%s %s %s)", q.Name, dns.ClassToString[q.Qclass], dns.TypeToString[q.Qtype])
}

if len(questions) == 0 {
return
}

//check local map
dom := req.Question[0].Name
domain := dom[:-1]
v, ok := localMap[domain]
if ok {
tm := new(dns.Msg)
tm.Id = id
tm.Answer = make([]dns.RR, 1)
dom := req.Question[0].Name
trr := new(dns.A)
trr.Hdr = dns.RR_Header{Name: dom, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 5}
trr.A = net.IPv4(127, 0, 0, 1)
tm.Answer[0] = trr

err = w.WriteMsg(tm)
goto end
}

req.Question = questions

id = req.Id

req.Id = 0
key = toMd5(req.String())
req.Id = id

if ENCACHE {
if reply, ok := conn.Get(key); ok {
data, _ = reply.([]byte)
}

if data != nil && len(data) > 0 {
m = &dns.Msg{}
m.Unpack(data)
m.Id = id
err = w.WriteMsg(m)

if DEBUG > 0 {
log.Printf("id:%5d cache: HIT %v\n", id, query)
}

goto end
} else {
if DEBUG > 0 {
log.Printf("id: %5d cache: MISS %v\n", id, query)
}
}
}

for i, parts := range DNS {
dns := parts[0]
proto := parts[1]

tried = i > 0

if DEBUG > 0 {
if tried {
log.Printf("id: 5%d try: %v %s %s\n", id, query, dns, proto)
} else {
log.Printf("id: 5%d resolve: %v %s %s\n", id, query, dns, proto)
}
}

client := clientUDP
if proto == "tcp" {
client = clientTCP
}

m, _, err = client.Exchange(req, dns)

if err == nil && len(m.Answer) > 0 {
used = dns
break
}
}

if err == nil {
if DEBUG > 0 {
if tried {
if len(m.Answer) == 0 {
log.Printf("id: %5d failed: %v\n", id, query)
} else {
log.Printf("id: %5d bingo: %v %s\n", id, query, used)
}
}
}

data, err = m.Pack()

if err == nil {
_, err = w.Write(data)

if err == nil {
if ENCACHE {
m.Id = 0
data, _ = m.Pack()
ttl := 0

if len(m.Answer) > 0 {
ttl = int(m.Answer[0].Header().Ttl)

if ttl < 0 {
ttl = 0
}
}
conn.Set(key, data, time.Second*time.Duration(ttl))

m.Id = id
if DEBUG > 0 {
log.Printf("id: %5d cache: CACHED %v TTL %v\n", id, query, ttl)
}
}
}
}
}

end:
if DEBUG > 1 {
fmt.Println(req)

if m != nil {
fmt.Println(m)
}
}

if err != nil {
log.Printf("id: %5d error: %v %s\n", id, query, err)
}

if DEBUG > 1 {
fmt.Println("======================================")
}
}
2 changes: 2 additions & 0 deletions hosts.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
127.0.0.1 www.baidu.com sina.com.cn
8.8.8.8 google.com

0 comments on commit 48a4a47

Please sign in to comment.