From 8b76bf2eb030962bd16d389a0e9db230fcfc3b42 Mon Sep 17 00:00:00 2001 From: chenzhe07 Date: Thu, 2 Mar 2017 10:15:56 +0800 Subject: [PATCH] escape sql query and insert query to mysql table --- conf.cnf | 3 ++ conf.go | 22 +++++++++ db.go | 50 ++++++++++++++++++++ logsql.go | 134 ++++++++++++++++++++++++++++++++++++++++++++++++------ main.go | 21 +++++++++ test.sql | 15 ++++++ 6 files changed, 232 insertions(+), 13 deletions(-) create mode 100644 conf.cnf create mode 100644 conf.go create mode 100644 db.go create mode 100644 test.sql diff --git a/conf.cnf b/conf.cnf new file mode 100644 index 0000000..5055b2a --- /dev/null +++ b/conf.cnf @@ -0,0 +1,3 @@ +[backend] +dsn = user_test:Aksj@qop@tcp(127.0.0.1:3306)/test?charset=utf8 + diff --git a/conf.go b/conf.go new file mode 100644 index 0000000..4234bcd --- /dev/null +++ b/conf.go @@ -0,0 +1,22 @@ +/*config read to verify normal user*/ +package main + +import ( + "github.com/chenzhe07/goconfig" +) + +func get_config(conf string) (c *goconfig.ConfigFile, err error) { + c, err = goconfig.ReadConfigFile(conf) + if err != nil { + return c, err + } + return c, nil +} + +func get_backend_dsn(c *goconfig.ConfigFile) (dsn string, err error) { + dsn, err = c.GetString("backend", "dsn") + if err != nil { + return dsn, err + } + return dsn, nil +} diff --git a/db.go b/db.go new file mode 100644 index 0000000..1b2b9ec --- /dev/null +++ b/db.go @@ -0,0 +1,50 @@ +/*config read to verify normal user*/ +package main + +import ( + "database/sql" + "fmt" + _ "github.com/go-sql-driver/mysql" + "log" +) + +func dbh(dsn string) (db *sql.DB, err error) { + db, err = sql.Open("mysql", dsn) + if err != nil { + return db, err + } + return db, nil +} + +func Query(db *sql.DB, q string) (*sql.Rows, error) { + if Verbose { + log.Printf("Query: %s\n", q) + } + return db.Query(q) +} + +func QueryRow(db *sql.DB, q string) *sql.Row { + if Verbose { + log.Printf("Query: %s", q) + } + return db.QueryRow(q) +} + +func ExecQuery(db *sql.DB, q string) (sql.Result, error) { + if Verbose { + log.Printf("ExecQuery: %s\n", q) + } + return db.Exec(q) +} + +func insertlog(db *sql.DB, t *query) bool { + insertSql := ` + insert into query_log(bindport, client, client_port, server, server_port, sql_type, + sql_string, create_time) values (%d, '%s', %d, '%s', %d, '%s', '%s', now()) + ` + _, err := ExecQuery(db, fmt.Sprintf(insertSql, t.bindPort, t.client, t.cport, t.server, t.sport, t.sqlType, t.sqlString)) + if err != nil { + return false + } + return true +} diff --git a/logsql.go b/logsql.go index e6e08f8..f69c71a 100644 --- a/logsql.go +++ b/logsql.go @@ -1,7 +1,10 @@ package main import ( + "fmt" "log" + "strconv" + "strings" ) //read more client-server protocol from http://dev.mysql.com/doc/internals/en/text-protocol.html @@ -36,36 +39,141 @@ const ( comStmtFetch ) +type query struct { + bindPort int64 + client string + cport int64 + server string + sport int64 + sqlType string + sqlString string +} + +func ipPortFromNetAddr(s string) (ip string, port int64) { + addrInfo := strings.SplitN(s, ":", 2) + ip = addrInfo[0] + port, _ = strconv.ParseInt(addrInfo[1], 10, 64) + return +} + +func converToUnixLine(sql string) string { + sql = strings.Replace(sql, "\r\n", "\n", -1) + sql = strings.Replace(sql, "\r", "\n", -1) + return sql +} + +func sql_escape(s string) string { + var j int = 0 + if len(s) == 0 { + return "" + } + + tempStr := s[:] + desc := make([]byte, len(tempStr)*2) + for i := 0; i < len(tempStr); i++ { + flag := false + var escape byte + switch tempStr[i] { + case '\r': + flag = true + escape = '\r' + break + case '\n': + flag = true + escape = '\n' + break + case '\\': + flag = true + escape = '\\' + break + case '\'': + flag = true + escape = '\'' + break + case '"': + flag = true + escape = '"' + break + case '\032': + flag = true + escape = 'Z' + break + default: + } + if flag { + desc[j] = '\\' + desc[j+1] = escape + j = j + 2 + } else { + desc[j] = tempStr[i] + j = j + 1 + } + } + return string(desc[0:j]) +} + func proxyLog(src, dst *Conn) { buffer := make([]byte, Bsize) - clientIp := src.conn.RemoteAddr().String() - serverIp := dst.conn.RemoteAddr().String() + var sqlInfo query + sqlInfo.client, sqlInfo.cport = ipPortFromNetAddr(src.conn.RemoteAddr().String()) + sqlInfo.server, sqlInfo.sport = ipPortFromNetAddr(dst.conn.RemoteAddr().String()) + _, sqlInfo.bindPort = ipPortFromNetAddr(src.conn.LocalAddr().String()) + for { n, err := src.Read(buffer) if err != nil { return } if n >= 5 { + var verboseStr string switch buffer[4] { case comQuit: - log.Printf("From %s To %s; Quit: %s\n", clientIp, serverIp, "user quit") + verboseStr = fmt.Sprintf("From %s To %s; Quit: %s\n", sqlInfo.client, sqlInfo.server, "user quit") + sqlInfo.sqlType = "Quit" case comInitDB: - log.Printf("From %s To %s; schema: use %s\n", clientIp, serverIp, string(buffer[5:n])) + verboseStr = fmt.Sprintf("From %s To %s; schema: use %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "Schema" case comQuery: - log.Printf("From %s To %s; Query: %s\n", clientIp, serverIp, string(buffer[5:n])) - case comFieldList: - log.Printf("From %s To %s; Table columns list: %s\n", clientIp, serverIp, string(buffer[5:n])) - case comConnect: - log.Printf("Internal: internal command in the server\n") + verboseStr = fmt.Sprintf("From %s To %s; Query: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "Query" + //case comFieldList: + // verboseStr = log.Printf("From %s To %s; Table columns list: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + // sqlInfo.sqlType = "Table columns list" + case comCreateDB: + verboseStr = fmt.Sprintf("From %s To %s; CreateDB: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "CreateDB" + case comDropDB: + verboseStr = fmt.Sprintf("From %s To %s; DropDB: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "DropDB" case comRefresh: - log.Printf("From %s To %s; Refresh: command: %s\n", clientIp, serverIp, string(buffer[5:n])) + verboseStr = fmt.Sprintf("From %s To %s; Refresh: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "Refresh" case comStmtPrepare: - log.Printf("From %s To %s; Prepare Query: %s\n", clientIp, serverIp, string(buffer[5:n])) + verboseStr = fmt.Sprintf("From %s To %s; Prepare Query: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "Prepare Query" case comStmtExecute: - log.Printf("From %s To %s; Prepare Args: %s\n", clientIp, serverIp, string(buffer[5:n])) + verboseStr = fmt.Sprintf("From %s To %s; Prepare Args: %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "Prepare Args" case comProcessKill: - log.Printf("From %s To %s; Kill: kill conntion %s\n", clientIp, serverIp, string(buffer[5:n])) + verboseStr = fmt.Sprintf("From %s To %s; Kill: kill conntion %s\n", sqlInfo.client, sqlInfo.server, string(buffer[5:n])) + sqlInfo.sqlType = "Kill" + default: + } + + if Verbose { + log.Print(verboseStr) } + + if strings.EqualFold(sqlInfo.sqlType, "Quit") { + sqlInfo.sqlString = "user quit" + } else { + sqlInfo.sqlString = converToUnixLine(sql_escape(string(buffer[5:n]))) + } + + if !strings.EqualFold(sqlInfo.sqlType, "") { + insertlog(Dbh, &sqlInfo) + } + } _, err = dst.Write(buffer[0:n]) diff --git a/main.go b/main.go index 0c9ece0..2ce038b 100644 --- a/main.go +++ b/main.go @@ -8,6 +8,7 @@ cz-20151119 */ import ( + "database/sql" "flag" "github.com/VividCortex/godaemon" "log" @@ -35,19 +36,39 @@ func waitSignal() { const timeout = time.Second * 2 var Bsize uint +var Verbose bool +var Dbh *sql.DB func main() { // options var bind, backend, logTo string var buffer uint var daemon bool + var verbose bool + var conf string + flag.StringVar(&bind, "bind", ":8002", "locate ip and port") flag.StringVar(&backend, "backend", "127.0.0.1:8003", "backend server ip and port") flag.StringVar(&logTo, "logTo", "stdout", "stdout or syslog") flag.UintVar(&buffer, "buffer", 4096, "buffer size") flag.BoolVar(&daemon, "daemon", false, "run as daemon process") + flag.BoolVar(&verbose, "verbose", false, "print verbose message") + flag.StringVar(&conf, "conf", "", "config file to verify database and firewall info") flag.Parse() Bsize = buffer + Verbose = verbose + + conf_fh, err := get_config(conf) + if err != nil { + log.Printf("Can't get config info, skip insert log to mysql...\n") + } + + backend_dsn, _ := get_backend_dsn(conf_fh) + Dbh, err = dbh(backend_dsn) + if err != nil { + log.Printf("Can't get database handle, skip insert log to mysql...\n") + } + defer Dbh.Close() log.SetOutput(os.Stdout) if logTo == "syslog" { diff --git a/test.sql b/test.sql new file mode 100644 index 0000000..36ef101 --- /dev/null +++ b/test.sql @@ -0,0 +1,15 @@ +CREATE TABLE `query_log` ( + `id` int(10) unsigned NOT NULL AUTO_INCREMENT, + `bindport` smallint(5) unsigned NOT NULL, + `client` char(15) NOT NULL DEFAULT '', + `client_port` smallint(5) unsigned NOT NULL, + `server` char(15) NOT NULL DEFAULT '', + `server_port` smallint(5) unsigned NOT NULL, + `sql_type` varchar(30) NOT NULL DEFAULT 'Query', + `sql_string` text, + `create_time` datetime NOT NULL, + PRIMARY KEY (`id`), + KEY `idx_client` (`client`), + KEY `idx_server` (`server`), + KEY `idx_cretime` (`create_time`) +) ENGINE=InnoDB AUTO_INCREMENT=9945 DEFAULT CHARSET=utf8