Skip to content

Commit

Permalink
Exec: bind values; Fix 'INSERT INTO ...SELECT' parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
kshvakov committed Aug 1, 2017
1 parent b964252 commit 3d7bd11
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 29 deletions.
5 changes: 1 addition & 4 deletions clickhouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,7 @@ func (ch *clickhouse) Rollback() error {
}
ch.data = nil
ch.inTransaction = false
if err := ch.cancel(); err != nil {
return err
}
return driver.ErrBadConn
return nil
}

func (ch *clickhouse) Close() error {
Expand Down
3 changes: 1 addition & 2 deletions clickhouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package clickhouse_test

import (
"database/sql"
"database/sql/driver"
"fmt"
"strings"
"testing"
Expand Down Expand Up @@ -704,7 +703,7 @@ func Test_Tx(t *testing.T) {
if tx, err := connect.Begin(); assert.NoError(t, err) {
_, err = tx.Query("SELECT 1")
if assert.NoError(t, err) {
if !assert.Equal(t, driver.ErrBadConn, tx.Rollback()) {
if !assert.NoError(t, tx.Rollback()) {
return
}
}
Expand Down
5 changes: 4 additions & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"database/sql/driver"
"fmt"
"regexp"
"strings"
"time"
)
Expand Down Expand Up @@ -96,9 +97,11 @@ func paramParser(reader *bytes.Reader) string {
return name.String()
}

var selectRe = regexp.MustCompile(`\s+SELECT\s+`)

func isInsert(query string) bool {
if f := strings.Fields(query); len(f) > 2 {
return strings.EqualFold("INSERT", f[0]) && strings.EqualFold("INTO", f[1]) && strings.Index(strings.ToUpper(query), " SELECT ") == -1
return strings.EqualFold("INSERT", f[0]) && strings.EqualFold("INTO", f[1]) && !selectRe.MatchString(strings.ToUpper(query))
}
return false
}
Expand Down
47 changes: 25 additions & 22 deletions stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ func (stmt *stmt) execContext(ctx context.Context, args []driver.Value) (driver.
}
return emptyResult, nil
}
if err := stmt.ch.sendQuery(stmt.query); err != nil {

if err := stmt.ch.sendQuery(stmt.bind(convertOldArgs(args))); err != nil {
return nil, err
}
if _, err := stmt.ch.receiveData(); err != nil {
Expand All @@ -57,16 +58,33 @@ func (stmt *stmt) Query(args []driver.Value) (driver.Rows, error) {
}

func (stmt *stmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
if finish := stmt.ch.watchCancel(ctx); finish != nil {
defer finish()
}

if err := stmt.ch.sendQuery(stmt.bind(args)); err != nil {
return nil, err
}

rows, err := stmt.ch.receiveData()
if err != nil {
return nil, err
}

return rows, nil
}

func (stmt *stmt) Close() error {
stmt.ch.logf("[stmt] close")
return nil
}

func (stmt *stmt) bind(args []namedValue) string {
var (
buf bytes.Buffer
index int
keyword bool
)

if finish := stmt.ch.watchCancel(ctx); finish != nil {
defer finish()
}

switch {
case stmt.NumInput() != 0:
reader := bytes.NewReader([]byte(stmt.query))
Expand Down Expand Up @@ -110,22 +128,7 @@ func (stmt *stmt) queryContext(ctx context.Context, args []namedValue) (driver.R
default:
buf.WriteString(stmt.query)
}

if err := stmt.ch.sendQuery(buf.String()); err != nil {
return nil, err
}

rows, err := stmt.ch.receiveData()
if err != nil {
return nil, err
}

return rows, nil
}

func (stmt *stmt) Close() error {
stmt.ch.logf("[stmt] close")
return nil
return buf.String()
}

type namedValue struct {
Expand Down

0 comments on commit 3d7bd11

Please sign in to comment.