diff --git a/README.md b/README.md index b8642704..d233696f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +__[v3.0 in the making](https://github.com/mattes/migrate/tree/v3.0-prev)__ + +--- + # migrate [![Build Status](https://travis-ci.org/mattes/migrate.svg?branch=master)](https://travis-ci.org/mattes/migrate) diff --git a/driver/bash/bash.go b/driver/bash/bash.go index 031f9bb4..187a5dcf 100644 --- a/driver/bash/bash.go +++ b/driver/bash/bash.go @@ -32,5 +32,5 @@ func (driver *Driver) Version() (uint64, error) { } func init() { - driver.RegisterDriver("bash", &Driver{}) + driver.RegisterDriver("bash", Driver{}) } diff --git a/driver/cassandra/cassandra.go b/driver/cassandra/cassandra.go index b91053c3..f776b037 100644 --- a/driver/cassandra/cassandra.go +++ b/driver/cassandra/cassandra.go @@ -14,44 +14,16 @@ import ( "github.com/mattes/migrate/migrate/direction" ) +// Driver implements migrate Driver interface type Driver struct { session *gocql.Session } const ( - tableName = "schema_migrations" - versionRow = 1 + tableName = "schema_migrations" ) -type counterStmt bool - -func (c counterStmt) Exec(session *gocql.Session) error { - var version int64 - if err := session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version); err != nil { - return err - } - - if bool(c) { - version++ - } else { - version-- - } - - return session.Query("UPDATE "+tableName+" SET version = ? WHERE versionRow = ?", version, versionRow).Exec() -} - -const ( - up counterStmt = true - down counterStmt = false -) - -// Cassandra Driver URL format: -// cassandra://host:port/keyspace?protocol=version&consistency=level -// -// Examples: -// cassandra://localhost/SpaceOfKeys?protocol=4 -// cassandra://localhost/SpaceOfKeys?protocol=4&consistency=all -// cassandra://localhost/SpaceOfKeys?consistency=quorum +// Initialize will be called first func (driver *Driver) Initialize(rawurl string) error { u, err := url.Parse(rawurl) if err != nil { @@ -68,7 +40,8 @@ func (driver *Driver) Initialize(rawurl string) error { cluster.Timeout = 1 * time.Minute if len(u.Query().Get("consistency")) > 0 { - consistency, err := parseConsistency(u.Query().Get("consistency")) + var consistency gocql.Consistency + consistency, err = parseConsistency(u.Query().Get("consistency")) if err != nil { return err } @@ -77,7 +50,8 @@ func (driver *Driver) Initialize(rawurl string) error { } if len(u.Query().Get("protocol")) > 0 { - protoversion, err := strconv.Atoi(u.Query().Get("protocol")) + var protoversion int + protoversion, err = strconv.Atoi(u.Query().Get("protocol")) if err != nil { return err } @@ -90,7 +64,7 @@ func (driver *Driver) Initialize(rawurl string) error { password, passwordSet := u.User.Password() if passwordSet == false { - return fmt.Errorf("Missing password. Please provide password.") + return fmt.Errorf("Missing password. Please provide password") } cluster.Authenticator = gocql.PasswordAuthenticator{ @@ -112,53 +86,63 @@ func (driver *Driver) Initialize(rawurl string) error { return nil } +// Close last function to be called. Closes cassandra session func (driver *Driver) Close() error { driver.session.Close() return nil } func (driver *Driver) ensureVersionTableExists() error { - err := driver.session.Query("CREATE TABLE IF NOT EXISTS " + tableName + " (version int, versionRow bigint primary key);").Exec() + err := driver.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id uuid primary key, version bigint)", tableName)).Exec() if err != nil { return err } - _, err = driver.Version() - if err != nil { - if err.Error() == "not found" { - return driver.session.Query("UPDATE "+tableName+" SET version = ? WHERE versionRow = ?", 1, versionRow).Exec() - } + if _, err = driver.Version(); err != nil { return err } return nil } +// FilenameExtension return extension of migrations files func (driver *Driver) FilenameExtension() string { return "cql" } -func (driver *Driver) version(d direction.Direction, invert bool) error { - var stmt counterStmt - switch d { - case direction.Up: - stmt = up - case direction.Down: - stmt = down +func (driver *Driver) updateVersion(version uint64, dir direction.Direction) error { + var ids []string + var id string + var err error + iter := driver.session.Query(fmt.Sprintf("SELECT id FROM %s WHERE version >= ? ALLOW FILTERING", tableName), version).Iter() + for iter.Scan(&id) { + ids = append(ids, id) + } + if len(ids) > 0 { + err = driver.session.Query(fmt.Sprintf("DELETE FROM %s WHERE id IN ?", tableName), ids).Exec() + if err != nil { + return err + } } - if invert { - stmt = !stmt + if dir == direction.Up { + return driver.session.Query(fmt.Sprintf("INSERT INTO %s (id, version) VALUES (uuid(), ?)", tableName), version).Exec() } - return stmt.Exec(driver.session) + return nil } +// Migrate run migration file. Restore previous version in case of fail func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { var err error + previousVersion, err := driver.Version() + if err != nil { + close(pipe) + return + } defer func() { if err != nil { // Invert version direction if we couldn't apply the changes for some reason. - if err := driver.version(f.Direction, true); err != nil { - pipe <- err + if updErr := driver.updateVersion(previousVersion, direction.Up); updErr != nil { + pipe <- updErr } pipe <- err } @@ -166,7 +150,7 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { }() pipe <- f - if err = driver.version(f.Direction, false); err != nil { + if err = driver.updateVersion(f.Version, f.Direction); err != nil { return } @@ -186,14 +170,15 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { } } +// Version return current version func (driver *Driver) Version() (uint64, error) { var version int64 - err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version) - return uint64(version) - 1, err + err := driver.session.Query(fmt.Sprintf("SELECT max(version) FROM %s", tableName)).Scan(&version) + return uint64(version), err } func init() { - driver.RegisterDriver("cassandra", &Driver{}) + driver.RegisterDriver("cassandra", Driver{}) } // ParseConsistency wraps gocql.ParseConsistency to return an error diff --git a/driver/cassandra/cassandra_test.go b/driver/cassandra/cassandra_test.go index 49024522..7218ba23 100644 --- a/driver/cassandra/cassandra_test.go +++ b/driver/cassandra/cassandra_test.go @@ -21,10 +21,10 @@ func TestMigrate(t *testing.T) { host := os.Getenv("CASSANDRA_PORT_9042_TCP_ADDR") port := os.Getenv("CASSANDRA_PORT_9042_TCP_PORT") - driverUrl := "cassandra://" + host + ":" + port + "/system" + driverURL := "cassandra://" + host + ":" + port + "/system" // prepare a clean test database - u, err := url.Parse(driverUrl) + u, err := url.Parse(driverURL) if err != nil { t.Fatal(err) } @@ -35,23 +35,25 @@ func TestMigrate(t *testing.T) { cluster.Timeout = 1 * time.Minute session, err = cluster.CreateSession() - if err != nil { t.Fatal(err) } - if err := session.Query(`DROP KEYSPACE IF EXISTS migrate;`).Exec(); err != nil { + if err = session.Query(`DROP KEYSPACE IF EXISTS migrate;`).Exec(); err != nil { t.Fatal(err) } - if err := session.Query(`CREATE KEYSPACE IF NOT EXISTS migrate WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};`).Exec(); err != nil { + if err = session.Query(`CREATE KEYSPACE IF NOT EXISTS migrate WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1};`).Exec(); err != nil { t.Fatal(err) } cluster.Keyspace = "migrate" session, err = cluster.CreateSession() - driverUrl = "cassandra://" + host + ":" + port + "/migrate" + if err != nil { + t.Fatal(err) + } + driverURL = "cassandra://" + host + ":" + port + "/migrate" d := &Driver{} - if err := d.Initialize(driverUrl); err != nil { + if err := d.Initialize(driverURL); err != nil { t.Fatal(err) } diff --git a/driver/crate/crate.go b/driver/crate/crate.go index eb30853a..7e43d5c5 100644 --- a/driver/crate/crate.go +++ b/driver/crate/crate.go @@ -13,7 +13,7 @@ import ( ) func init() { - driver.RegisterDriver("crate", &Driver{}) + driver.RegisterDriver("crate", Driver{}) } type Driver struct { @@ -97,8 +97,8 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { func splitContent(content string) []string { lines := strings.Split(content, ";") resultLines := make([]string, 0, len(lines)) - for i, line := range lines { - line = strings.Replace(lines[i], ";", "", -1) + for i := range lines { + line := strings.Replace(lines[i], ";", "", -1) line = strings.TrimSpace(line) if line != "" { resultLines = append(resultLines, line) diff --git a/driver/driver.go b/driver/driver.go index e4ecb783..ddf4cb45 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -43,7 +43,7 @@ func New(url string) (Driver, error) { d := GetDriver(u.Scheme) if d == nil { - return nil, fmt.Errorf("Driver '%s' not found.", u.Scheme) + return nil, fmt.Errorf("Driver '%s' not found", u.Scheme) } verifyFilenameExtension(u.Scheme, d) if err := d.Initialize(url); err != nil { @@ -53,6 +53,22 @@ func New(url string) (Driver, error) { return d, nil } +// FilenameExtensionFromURL return extension for migration files. Used for create migrations +func FilenameExtensionFromURL(url string) (string, error) { + u, err := neturl.Parse(url) + if err != nil { + return "", err + } + + d := GetDriver(u.Scheme) + if d == nil { + return "", fmt.Errorf("Driver '%s' not found", u.Scheme) + } + verifyFilenameExtension(u.Scheme, d) + + return d.FilenameExtension(), nil +} + // verifyFilenameExtension panics if the driver's filename extension // is not correct or empty. func verifyFilenameExtension(driverName string, d Driver) { diff --git a/driver/mongodb/example/mongodb_test.go b/driver/mongodb/example/mongodb_test.go index 86c37c00..6070ed73 100644 --- a/driver/mongodb/example/mongodb_test.go +++ b/driver/mongodb/example/mongodb_test.go @@ -6,13 +6,14 @@ import ( "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" + "os" + "reflect" + "time" + "github.com/mattes/migrate/driver" "github.com/mattes/migrate/driver/mongodb" "github.com/mattes/migrate/driver/mongodb/gomethods" pipep "github.com/mattes/migrate/pipe" - "os" - "reflect" - "time" ) type ExpectedMigrationResult struct { diff --git a/driver/mongodb/example/sample_mongdb_migrator.go b/driver/mongodb/example/sample_mongdb_migrator.go index 1ab1440a..227874e7 100644 --- a/driver/mongodb/example/sample_mongdb_migrator.go +++ b/driver/mongodb/example/sample_mongdb_migrator.go @@ -1,11 +1,12 @@ package example import ( + "time" + "github.com/mattes/migrate/driver/mongodb/gomethods" _ "github.com/mattes/migrate/driver/mongodb/gomethods" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" - "time" "github.com/mattes/migrate/driver/mongodb" ) diff --git a/driver/mongodb/gomethods/gomethods_migrator.go b/driver/mongodb/gomethods/gomethods_migrator.go index 97057ab5..d53f7d91 100644 --- a/driver/mongodb/gomethods/gomethods_migrator.go +++ b/driver/mongodb/gomethods/gomethods_migrator.go @@ -3,11 +3,12 @@ package gomethods import ( "bufio" "fmt" - "github.com/mattes/migrate/driver" - "github.com/mattes/migrate/file" "os" "path" "strings" + + "github.com/mattes/migrate/driver" + "github.com/mattes/migrate/file" ) type MethodNotFoundError string diff --git a/driver/mongodb/gomethods/gomethods_registry.go b/driver/mongodb/gomethods/gomethods_registry.go index 418256fd..d9a19eed 100644 --- a/driver/mongodb/gomethods/gomethods_registry.go +++ b/driver/mongodb/gomethods/gomethods_registry.go @@ -2,8 +2,9 @@ package gomethods import ( "fmt" - "github.com/mattes/migrate/driver" "sync" + + "github.com/mattes/migrate/driver" ) var methodsReceiversMu sync.Mutex diff --git a/driver/mongodb/mongodb.go b/driver/mongodb/mongodb.go index fcfae505..55b73855 100644 --- a/driver/mongodb/mongodb.go +++ b/driver/mongodb/mongodb.go @@ -2,14 +2,15 @@ package mongodb import ( "errors" + "reflect" + "strings" + "github.com/mattes/migrate/driver" "github.com/mattes/migrate/driver/mongodb/gomethods" "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" - "reflect" - "strings" ) type UnregisteredMethodsReceiverError string @@ -55,7 +56,7 @@ func (d *Driver) SetMethodsReceiver(r interface{}) error { } func init() { - driver.RegisterDriver("mongodb", &Driver{}) + driver.RegisterDriver("mongodb", Driver{}) } type DbMigration struct { diff --git a/driver/mysql/mysql.go b/driver/mysql/mysql.go index 2bbbc131..eabbf2a3 100644 --- a/driver/mysql/mysql.go +++ b/driver/mysql/mysql.go @@ -181,5 +181,5 @@ func (driver *Driver) Version() (uint64, error) { } func init() { - driver.RegisterDriver("mysql", &Driver{}) + driver.RegisterDriver("mysql", Driver{}) } diff --git a/driver/neo4j/neo4j.go b/driver/neo4j/neo4j.go index 9670abeb..f623391f 100644 --- a/driver/neo4j/neo4j.go +++ b/driver/neo4j/neo4j.go @@ -2,156 +2,165 @@ package neo4j import ( - "fmt" - "bytes" - "strings" - "errors" - "github.com/jmcvetta/neoism" - "github.com/mattes/migrate/driver" - "github.com/mattes/migrate/file" - "github.com/mattes/migrate/migrate/direction" + "bytes" + "errors" + "fmt" + "github.com/jmcvetta/neoism" + "github.com/mattes/migrate/driver" + "github.com/mattes/migrate/file" + "github.com/mattes/migrate/migrate/direction" + "strings" ) type Driver struct { - db *neoism.Database + db *neoism.Database } const labelName = "SchemaMigration" func (driver *Driver) Initialize(url string) error { - url = strings.Replace(url,"neo4j","http",1) + url = strings.Replace(url, "neo4j", "http", 1) - db, err := neoism.Connect(url) - if err != nil { - return err - } + db, err := neoism.Connect(url) + if err != nil { + return err + } - driver.db = db + driver.db = db - if err := driver.ensureVersionConstraintExists(); err != nil { - return err - } - return nil + if err := driver.ensureVersionConstraintExists(); err != nil { + return err + } + return nil } func (driver *Driver) Close() error { - driver.db = nil - return nil + driver.db = nil + return nil } func (driver *Driver) FilenameExtension() string { - return "cql" + return "cql" } func (driver *Driver) ensureVersionConstraintExists() error { - uc, _ := driver.db.UniqueConstraints("SchemaMigration", "version") - if len(uc) == 0 { - _, err := driver.db.CreateUniqueConstraint("SchemaMigration", "version") - return err - } - return nil + uc, _ := driver.db.UniqueConstraints("SchemaMigration", "version") + if len(uc) == 0 { + _, err := driver.db.CreateUniqueConstraint("SchemaMigration", "version") + return err + } + return nil } func (driver *Driver) setVersion(d direction.Direction, v uint64, invert bool) error { - cqUp := neoism.CypherQuery { - Statement: `CREATE (n:SchemaMigration {version: {Version}}) RETURN n`, - Parameters: neoism.Props{"Version": v}, - } - - cqDown := neoism.CypherQuery { - Statement: `MATCH (n:SchemaMigration {version: {Version}}) DELETE n`, - Parameters: neoism.Props{"Version": v}, - } - - var cq neoism.CypherQuery - switch d { - case direction.Up: - if invert { cq = cqDown } else { cq = cqUp } - case direction.Down: - if invert { cq = cqUp } else { cq = cqDown } - } - return driver.db.Cypher(&cq) + cqUp := neoism.CypherQuery{ + Statement: `CREATE (n:SchemaMigration {version: {Version}}) RETURN n`, + Parameters: neoism.Props{"Version": v}, + } + + cqDown := neoism.CypherQuery{ + Statement: `MATCH (n:SchemaMigration {version: {Version}}) DELETE n`, + Parameters: neoism.Props{"Version": v}, + } + + var cq neoism.CypherQuery + switch d { + case direction.Up: + if invert { + cq = cqDown + } else { + cq = cqUp + } + case direction.Down: + if invert { + cq = cqUp + } else { + cq = cqDown + } + } + return driver.db.Cypher(&cq) } func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { - var err error - - defer func() { - if err != nil { - // Invert version direction if we couldn't apply the changes for some reason. - if err := driver.setVersion(f.Direction, f.Version, true); err != nil { - pipe <- err - } - pipe <- err - } - close(pipe) - }() - - pipe <- f - - - if err = driver.setVersion(f.Direction, f.Version, false); err != nil { - pipe <- err - return - } - - if err = f.ReadContent(); err != nil { - pipe <- err - return - } - - cQueries := []*neoism.CypherQuery{} - - // Neoism doesn't support multiple statements per query. - cqlStmts := bytes.Split(f.Content, []byte(";")) - - for _, cqlStmt := range cqlStmts { - cqlStmt = bytes.TrimSpace(cqlStmt) - if len(cqlStmt) > 0 { - cq := neoism.CypherQuery{Statement: string(cqlStmt)} - cQueries = append( cQueries, &cq ) - } - } - - var tx *neoism.Tx - - tx, err = driver.db.Begin(cQueries) - if err != nil { - pipe <- err - for _, err := range tx.Errors { - pipe <- errors.New(fmt.Sprintf("%v", err.Message)) - } - if err = tx.Rollback(); err != nil { - pipe <- err - } - return - } - - if err = tx.Commit(); err != nil { - pipe <- err - for _, err := range tx.Errors { - pipe <- errors.New(fmt.Sprintf("%v", err.Message)) - } - return - } + var err error + + defer func() { + if err != nil { + // Invert version direction if we couldn't apply the changes for some reason. + if err := driver.setVersion(f.Direction, f.Version, true); err != nil { + pipe <- err + } + pipe <- err + } + close(pipe) + }() + + pipe <- f + + if err = driver.setVersion(f.Direction, f.Version, false); err != nil { + pipe <- err + return + } + + if err = f.ReadContent(); err != nil { + pipe <- err + return + } + + cQueries := []*neoism.CypherQuery{} + + // Neoism doesn't support multiple statements per query. + cqlStmts := bytes.Split(f.Content, []byte(";")) + + for _, cqlStmt := range cqlStmts { + cqlStmt = bytes.TrimSpace(cqlStmt) + if len(cqlStmt) > 0 { + cq := neoism.CypherQuery{Statement: string(cqlStmt)} + cQueries = append(cQueries, &cq) + } + } + + var tx *neoism.Tx + + tx, err = driver.db.Begin(cQueries) + if err != nil { + pipe <- err + for _, err := range tx.Errors { + pipe <- errors.New(fmt.Sprintf("%v", err.Message)) + } + if err = tx.Rollback(); err != nil { + pipe <- err + } + return + } + + if err = tx.Commit(); err != nil { + pipe <- err + for _, err := range tx.Errors { + pipe <- errors.New(fmt.Sprintf("%v", err.Message)) + } + return + } } func (driver *Driver) Version() (uint64, error) { - res := []struct {Version uint64 `json:"n.version"`}{} + res := []struct { + Version uint64 `json:"n.version"` + }{} - cq := neoism.CypherQuery{ - Statement: `MATCH (n:SchemaMigration) + cq := neoism.CypherQuery{ + Statement: `MATCH (n:SchemaMigration) RETURN n.version ORDER BY n.version DESC LIMIT 1`, - Result: &res, - } + Result: &res, + } - if err := driver.db.Cypher(&cq); err != nil || len(res) == 0 { - return 0, err - } - return res[0].Version, nil + if err := driver.db.Cypher(&cq); err != nil || len(res) == 0 { + return 0, err + } + return res[0].Version, nil } func init() { - driver.RegisterDriver("neo4j", &Driver{}) + driver.RegisterDriver("neo4j", Driver{}) } diff --git a/driver/postgres/postgres.go b/driver/postgres/postgres.go index df38fde4..d5b25075 100644 --- a/driver/postgres/postgres.go +++ b/driver/postgres/postgres.go @@ -43,7 +43,7 @@ func (driver *Driver) Close() error { } func (driver *Driver) ensureVersionTableExists() error { - r := driver.db.QueryRow("SELECT count(*) FROM information_schema.tables WHERE table_name = $1;", tableName) + r := driver.db.QueryRow("SELECT count(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema());", tableName) c := 0 if err := r.Scan(&c); err != nil { return err @@ -139,5 +139,7 @@ func (driver *Driver) Version() (uint64, error) { } func init() { - driver.RegisterDriver("postgres", &Driver{}) + drv := Driver{} + driver.RegisterDriver("postgres", drv) + driver.RegisterDriver("postgresql", drv) } diff --git a/driver/postgres/postgres_test.go b/driver/postgres/postgres_test.go index df3d3586..1a5c4d7c 100644 --- a/driver/postgres/postgres_test.go +++ b/driver/postgres/postgres_test.go @@ -19,10 +19,10 @@ func TestMigrate(t *testing.T) { host := os.Getenv("POSTGRES_PORT_5432_TCP_ADDR") port := os.Getenv("POSTGRES_PORT_5432_TCP_PORT") - driverUrl := "postgres://postgres@" + host + ":" + port + "/template1?sslmode=disable" + driverURL := "postgres://postgres@" + host + ":" + port + "/template1?sslmode=disable" // prepare clean database - connection, err := sql.Open("postgres", driverUrl) + connection, err := sql.Open("postgres", driverURL) if err != nil { t.Fatal(err) } @@ -33,12 +33,12 @@ func TestMigrate(t *testing.T) { } d := &Driver{} - if err := d.Initialize(driverUrl); err != nil { + if err := d.Initialize(driverURL); err != nil { t.Fatal(err) } // testing idempotency: second call should be a no-op, since table already exists - if err := d.Initialize(driverUrl); err != nil { + if err := d.Initialize(driverURL); err != nil { t.Fatal(err) } diff --git a/driver/ql/ql.go b/driver/ql/ql.go index 1c75fe70..1ee226ca 100644 --- a/driver/ql/ql.go +++ b/driver/ql/ql.go @@ -3,13 +3,14 @@ package ql import ( "database/sql" - "github.com/mattes/migrate/file" "github.com/mattes/migrate/driver" + "github.com/mattes/migrate/file" "github.com/mattes/migrate/migrate/direction" - _ "github.com/cznic/ql/driver" - "strings" "fmt" + "strings" + + _ "github.com/cznic/ql/driver" ) const tableName = "schema_migration" @@ -125,6 +126,6 @@ func (d *Driver) ensureVersionTableExists() error { } func init() { - driver.RegisterDriver("ql+file", &Driver{}) - driver.RegisterDriver("ql+memory", &Driver{}) -} \ No newline at end of file + driver.RegisterDriver("ql+file", Driver{}) + driver.RegisterDriver("ql+memory", Driver{}) +} diff --git a/driver/registry.go b/driver/registry.go index 63889184..c8a297b9 100644 --- a/driver/registry.go +++ b/driver/registry.go @@ -1,17 +1,22 @@ package driver import ( + "reflect" "sort" "sync" ) +type DriverRegistration struct { + Template interface{} +} + var driversMu sync.Mutex -var drivers = make(map[string]Driver) +var drivers = make(map[string]DriverRegistration) -// Registers a driver so it can be created from its name. Drivers should +// RegisterDriver register a driver so it can be created from its name. Drivers should // call this from an init() function so that they registers themselves on // import -func RegisterDriver(name string, driver Driver) { +func RegisterDriver(name string, driver interface{}) { driversMu.Lock() defer driversMu.Unlock() if driver == nil { @@ -20,15 +25,18 @@ func RegisterDriver(name string, driver Driver) { if _, dup := drivers[name]; dup { panic("sql: Register called twice for driver " + name) } - drivers[name] = driver + drivers[name] = DriverRegistration{ + Template: driver, + } } -// Retrieves a registered driver by name +// GetDriver retrieves a registered driver by name func GetDriver(name string) Driver { driversMu.Lock() defer driversMu.Unlock() - driver := drivers[name] - return driver + registration := drivers[name] + driver := reflect.New(reflect.TypeOf(registration.Template)).Interface() + return driver.(Driver) } // Drivers returns a sorted list of the names of the registered drivers. diff --git a/driver/sqlite3/sqlite3.go b/driver/sqlite3/sqlite3.go index a92b14d8..257dfddc 100644 --- a/driver/sqlite3/sqlite3.go +++ b/driver/sqlite3/sqlite3.go @@ -96,9 +96,9 @@ func (driver *Driver) Migrate(f file.File, pipe chan interface{}) { if isErr { // The sqlite3 library only provides error codes, not position information. Output what we do know - pipe <- errors.New(fmt.Sprintf("SQLite Error (%s); Extended (%s)\nError: %s", sqliteErr.Code.Error(), sqliteErr.ExtendedCode.Error(), sqliteErr.Error())) + pipe <- fmt.Errorf("SQLite Error (%s); Extended (%s)\nError: %s", sqliteErr.Code.Error(), sqliteErr.ExtendedCode.Error(), sqliteErr.Error()) } else { - pipe <- errors.New(fmt.Sprintf("An error occurred: %s", err.Error())) + pipe <- fmt.Errorf("An error occurred: %s", err.Error()) } if err := tx.Rollback(); err != nil { @@ -127,5 +127,5 @@ func (driver *Driver) Version() (uint64, error) { } func init() { - driver.RegisterDriver("sqlite3", &Driver{}) + driver.RegisterDriver("sqlite3", Driver{}) } diff --git a/driver/sqlite3/sqlite3_test.go b/driver/sqlite3/sqlite3_test.go index 7f50095a..57dc7792 100644 --- a/driver/sqlite3/sqlite3_test.go +++ b/driver/sqlite3/sqlite3_test.go @@ -17,7 +17,7 @@ func TestMigrate(t *testing.T) { } driverFile := ":memory:" - driverUrl := "sqlite3://" + driverFile + driverURL := "sqlite3://" + driverFile // prepare clean database connection, err := sql.Open("sqlite3", driverFile) @@ -31,7 +31,7 @@ func TestMigrate(t *testing.T) { } d := &Driver{} - if err := d.Initialize(driverUrl); err != nil { + if err := d.Initialize(driverURL); err != nil { t.Fatal(err) } diff --git a/file/file.go b/file/file.go index 54992a48..ccda7d6b 100644 --- a/file/file.go +++ b/file/file.go @@ -5,7 +5,6 @@ import ( "bytes" "errors" "fmt" - "github.com/mattes/migrate/migrate/direction" "go/token" "io/ioutil" "path" @@ -13,6 +12,8 @@ import ( "sort" "strconv" "strings" + + "github.com/mattes/migrate/migrate/direction" ) var filenameRegex = `^([0-9]+)_(.*)\.(up|down)\.%s$` @@ -137,10 +138,10 @@ func (mf *MigrationFiles) From(version uint64, relativeN int) (Files, error) { if d == direction.Up && migrationFile.Version > version && migrationFile.UpFile != nil { files = append(files, *migrationFile.UpFile) - counter -= 1 + counter-- } else if d == direction.Down && migrationFile.Version <= version && migrationFile.DownFile != nil { files = append(files, *migrationFile.DownFile) - counter -= 1 + counter-- } } else { break @@ -258,7 +259,7 @@ func parseFilenameSchema(filename string, filenameRegex *regexp.Regexp) (version version, err = strconv.ParseUint(matches[1], 10, 0) if err != nil { - return 0, "", 0, errors.New(fmt.Sprintf("Unable to parse version '%v' in filename schema", matches[0])) + return 0, "", 0, fmt.Errorf("Unable to parse version '%v' in filename schema", matches[0]) } if matches[3] == "up" { @@ -266,7 +267,7 @@ func parseFilenameSchema(filename string, filenameRegex *regexp.Regexp) (version } else if matches[3] == "down" { d = direction.Down } else { - return 0, "", 0, errors.New(fmt.Sprintf("Unable to parse up|down '%v' in filename schema", matches[3])) + return 0, "", 0, fmt.Errorf("Unable to parse up|down '%v' in filename schema", matches[3]) } return version, matches[2], d, nil @@ -333,7 +334,7 @@ func LinesBeforeAndAfter(data []byte, line, before, after int, lineNumbers bool) lNew = append([]byte(lineCounterStr+": "), lNew...) } newLines = append(newLines, lNew) - lineCounter += 1 + lineCounter++ } return bytes.Join(newLines, []byte("\n")) diff --git a/file/file_test.go b/file/file_test.go index b9ddeabd..b55f05c8 100644 --- a/file/file_test.go +++ b/file/file_test.go @@ -1,11 +1,12 @@ package file import ( - "github.com/mattes/migrate/migrate/direction" "io/ioutil" "os" "path" "testing" + + "github.com/mattes/migrate/migrate/direction" ) func TestParseFilenameSchema(t *testing.T) { @@ -61,7 +62,7 @@ func TestFiles(t *testing.T) { } defer os.RemoveAll(tmpdir) - if err := ioutil.WriteFile(path.Join(tmpdir, "nonsense.txt"), nil, 0755); err != nil { + if err = ioutil.WriteFile(path.Join(tmpdir, "nonsense.txt"), nil, 0755); err != nil { t.Fatal("Unable to write files in tmpdir", err) } ioutil.WriteFile(path.Join(tmpdir, "002_migrationfile.up.sql"), nil, 0755) @@ -133,7 +134,7 @@ func TestFiles(t *testing.T) { } // test read - if err := files[4].DownFile.ReadContent(); err != nil { + if err = files[4].DownFile.ReadContent(); err != nil { t.Error("Unable to read file", err) } if files[4].DownFile.Content == nil { @@ -169,7 +170,8 @@ func TestFiles(t *testing.T) { } for _, test := range tests { - rangeFiles, err := files.From(test.from, test.relative) + var rangeFiles Files + rangeFiles, err = files.From(test.from, test.relative) if err != nil { t.Error("Unable to fetch range:", err) } diff --git a/main.go b/main.go index 67310264..06cbfac0 100644 --- a/main.go +++ b/main.go @@ -176,32 +176,31 @@ func writePipe(pipe chan interface{}) (ok bool) { case item, more := <-pipe: if !more { return okFlag - } else { - switch item.(type) { - - case string: - fmt.Println(item.(string)) - - case error: + } + switch item.(type) { + + case string: + fmt.Println(item.(string)) + + case error: + c := color.New(color.FgRed) + c.Printf("%s\n\n", item.(error).Error()) + okFlag = false + + case file.File: + f := item.(file.File) + if f.Direction == direction.Up { + c := color.New(color.FgGreen) + c.Print(">") + } else if f.Direction == direction.Down { c := color.New(color.FgRed) - c.Println(item.(error).Error(), "\n") - okFlag = false - - case file.File: - f := item.(file.File) - if f.Direction == direction.Up { - c := color.New(color.FgGreen) - c.Print(">") - } else if f.Direction == direction.Down { - c := color.New(color.FgRed) - c.Print("<") - } - fmt.Printf(" %s\n", f.FileName) - - default: - text := fmt.Sprint(item) - fmt.Println(text) + c.Print("<") } + fmt.Printf(" %s\n", f.FileName) + + default: + text := fmt.Sprint(item) + fmt.Println(text) } } } diff --git a/migrate/direction/direction.go b/migrate/direction/direction.go index ed5e5ece..9516636e 100644 --- a/migrate/direction/direction.go +++ b/migrate/direction/direction.go @@ -1,9 +1,12 @@ // Package direction just holds convenience constants for Up and Down migrations. package direction +// Direction - type that indicates direction of migration(up or down) type Direction int const ( - Up Direction = +1 - Down = -1 + // Up - up migration + Up Direction = +1 + // Down - down migration + Down Direction = -1 ) diff --git a/migrate/migrate.go b/migrate/migrate.go index 0c976be9..4fdef5d4 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -51,13 +51,12 @@ func Up(pipe chan interface{}, url, migrationsPath string) { } go pipep.Close(pipe, nil) return - } else { - if err := d.Close(); err != nil { - pipe <- err - } - go pipep.Close(pipe, nil) - return } + if err := d.Close(); err != nil { + pipe <- err + } + go pipep.Close(pipe, nil) + return } // UpSync is synchronous version of Up @@ -101,13 +100,12 @@ func Down(pipe chan interface{}, url, migrationsPath string) { } go pipep.Close(pipe, nil) return - } else { - if err2 := d.Close(); err2 != nil { - pipe <- err2 - } - go pipep.Close(pipe, nil) - return } + if err2 := d.Close(); err2 != nil { + pipe <- err2 + } + go pipep.Close(pipe, nil) + return } // DownSync is synchronous version of Down @@ -129,9 +127,8 @@ func Redo(pipe chan interface{}, url, migrationsPath string) { if ok := pipep.WaitAndRedirect(pipe1, pipe, signals); !ok { go pipep.Close(pipe, nil) return - } else { - go Migrate(pipe, url, migrationsPath, +1) } + go Migrate(pipe, url, migrationsPath, +1) } // RedoSync is synchronous version of Redo @@ -153,9 +150,8 @@ func Reset(pipe chan interface{}, url, migrationsPath string) { if ok := pipep.WaitAndRedirect(pipe1, pipe, signals); !ok { go pipep.Close(pipe, nil) return - } else { - go Up(pipe, url, migrationsPath) } + go Up(pipe, url, migrationsPath) } // ResetSync is synchronous version of Reset @@ -221,17 +217,20 @@ func Version(url, migrationsPath string) (version uint64, err error) { if err != nil { return 0, err } + defer func() { + err = d.Close() + }() return d.Version() } // Create creates new migration files on disk func Create(url, migrationsPath, name string) (*file.MigrationFile, error) { - d, err := driver.New(url) + ext, err := driver.FilenameExtensionFromURL(url) if err != nil { return nil, err } - files, err := file.ReadMigrationFiles(migrationsPath, file.FilenameRegex(d.FilenameExtension())) + files, err := file.ReadMigrationFiles(migrationsPath, file.FilenameRegex(ext)) if err != nil { return nil, err } @@ -258,14 +257,14 @@ func Create(url, migrationsPath, name string) (*file.MigrationFile, error) { Version: version, UpFile: &file.File{ Path: migrationsPath, - FileName: fmt.Sprintf(filenamef, versionStr, name, "up", d.FilenameExtension()), + FileName: fmt.Sprintf(filenamef, versionStr, name, "up", ext), Name: name, Content: []byte(""), Direction: direction.Up, }, DownFile: &file.File{ Path: migrationsPath, - FileName: fmt.Sprintf(filenamef, versionStr, name, "down", d.FilenameExtension()), + FileName: fmt.Sprintf(filenamef, versionStr, name, "down", ext), Name: name, Content: []byte(""), Direction: direction.Down, diff --git a/migrate/migrate_test.go b/migrate/migrate_test.go index c7b43adb..8801ff81 100644 --- a/migrate/migrate_test.go +++ b/migrate/migrate_test.go @@ -9,18 +9,18 @@ import ( // Ensure imports for each driver we wish to test _ "github.com/mattes/migrate/driver/postgres" - _ "github.com/mattes/migrate/driver/sqlite3" _ "github.com/mattes/migrate/driver/ql" + _ "github.com/mattes/migrate/driver/sqlite3" ) // Add Driver URLs here to test basic Up, Down, .. functions. -var driverUrls = []string{ +var driverURLs = []string{ "postgres://postgres@" + os.Getenv("POSTGRES_PORT_5432_TCP_ADDR") + ":" + os.Getenv("POSTGRES_PORT_5432_TCP_PORT") + "/template1?sslmode=disable", "ql+file://./test.db", } -func tearDown(driverUrl, tmpdir string) { - DownSync(driverUrl, tmpdir) +func tearDown(driverURL, tmpdir string) { + DownSync(driverURL, tmpdir) os.RemoveAll(tmpdir) } @@ -28,18 +28,18 @@ func TestCreate(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } - for _, driverUrl := range driverUrls { - t.Logf("Test driver: %s", driverUrl) + for _, driverURL := range driverURLs { + t.Logf("Test driver: %s", driverURL) tmpdir, err := ioutil.TempDir("/tmp", "migrate-test") if err != nil { t.Fatal(err) } - defer tearDown(driverUrl, tmpdir) + defer tearDown(driverURL, tmpdir) - if _, err := Create(driverUrl, tmpdir, "test_migration"); err != nil { + if _, err = Create(driverURL, tmpdir, "test_migration"); err != nil { t.Fatal(err) } - if _, err := Create(driverUrl, tmpdir, "another migration"); err != nil { + if _, err = Create(driverURL, tmpdir, "another migration"); err != nil { t.Fatal(err) } @@ -83,25 +83,25 @@ func TestReset(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } - for _, driverUrl := range driverUrls { - t.Logf("Test driver: %s", driverUrl) + for _, driverURL := range driverURLs { + t.Logf("Test driver: %s", driverURL) tmpdir, err := ioutil.TempDir("/tmp", "migrate-test") if err != nil { t.Fatal(err) } - defer tearDown(driverUrl, tmpdir) + defer tearDown(driverURL, tmpdir) - Create(driverUrl, tmpdir, "migration1") - f, err := Create(driverUrl, tmpdir, "migration2") + Create(driverURL, tmpdir, "migration1") + f, err := Create(driverURL, tmpdir, "migration2") if err != nil { t.Fatal(err) } - if err, ok := ResetSync(driverUrl, tmpdir); !ok { + if err, ok := ResetSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != f.Version { t.Fatalf("Expected version %v, got %v", version, f.Version) @@ -113,36 +113,36 @@ func TestDown(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } - for _, driverUrl := range driverUrls { - t.Logf("Test driver: %s", driverUrl) + for _, driverURL := range driverURLs { + t.Logf("Test driver: %s", driverURL) tmpdir, err := ioutil.TempDir("/tmp", "migrate-test") if err != nil { t.Fatal(err) } - defer tearDown(driverUrl, tmpdir) + defer tearDown(driverURL, tmpdir) - initVersion, _ := Version(driverUrl, tmpdir) + initVersion, _ := Version(driverURL, tmpdir) - firstMigration, _ := Create(driverUrl, tmpdir, "migration1") - secondMigration, _ := Create(driverUrl, tmpdir, "migration2") + firstMigration, _ := Create(driverURL, tmpdir, "migration1") + secondMigration, _ := Create(driverURL, tmpdir, "migration2") t.Logf("init %v first %v second %v", initVersion, firstMigration.Version, secondMigration.Version) - if err, ok := ResetSync(driverUrl, tmpdir); !ok { + if err, ok := ResetSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != secondMigration.Version { t.Fatalf("Expected version %v, got %v", version, secondMigration.Version) } - if err, ok := DownSync(driverUrl, tmpdir); !ok { + if err, ok := DownSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != 0 { t.Fatalf("Expected 0, got %v", version) @@ -154,36 +154,36 @@ func TestUp(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } - for _, driverUrl := range driverUrls { - t.Logf("Test driver: %s", driverUrl) + for _, driverURL := range driverURLs { + t.Logf("Test driver: %s", driverURL) tmpdir, err := ioutil.TempDir("/tmp", "migrate-test") if err != nil { t.Fatal(err) } - defer tearDown(driverUrl, tmpdir) + defer tearDown(driverURL, tmpdir) - initVersion, _ := Version(driverUrl, tmpdir) + initVersion, _ := Version(driverURL, tmpdir) - firstMigration, _ := Create(driverUrl, tmpdir, "migration1") - secondMigration, _ := Create(driverUrl, tmpdir, "migration2") + firstMigration, _ := Create(driverURL, tmpdir, "migration1") + secondMigration, _ := Create(driverURL, tmpdir, "migration2") t.Logf("init %v first %v second %v", initVersion, firstMigration.Version, secondMigration.Version) - if err, ok := DownSync(driverUrl, tmpdir); !ok { + if err, ok := DownSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != initVersion { t.Fatalf("Expected initial version %v, got %v", initVersion, version) } - if err, ok := UpSync(driverUrl, tmpdir); !ok { + if err, ok := UpSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != secondMigration.Version { t.Fatalf("Expected migrated version %v, got %v", secondMigration.Version, version) @@ -195,36 +195,36 @@ func TestRedo(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } - for _, driverUrl := range driverUrls { - t.Logf("Test driver: %s", driverUrl) + for _, driverURL := range driverURLs { + t.Logf("Test driver: %s", driverURL) tmpdir, err := ioutil.TempDir("/tmp", "migrate-test") if err != nil { t.Fatal(err) } - defer tearDown(driverUrl, tmpdir) + defer tearDown(driverURL, tmpdir) - initVersion, _ := Version(driverUrl, tmpdir) + initVersion, _ := Version(driverURL, tmpdir) - firstMigration, _ := Create(driverUrl, tmpdir, "migration1") - secondMigration, _ := Create(driverUrl, tmpdir, "migration2") + firstMigration, _ := Create(driverURL, tmpdir, "migration1") + secondMigration, _ := Create(driverURL, tmpdir, "migration2") t.Logf("init %v first %v second %v", initVersion, firstMigration.Version, secondMigration.Version) - if err, ok := ResetSync(driverUrl, tmpdir); !ok { + if err, ok := ResetSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != secondMigration.Version { t.Fatalf("Expected migrated version %v, got %v", secondMigration.Version, version) } - if err, ok := RedoSync(driverUrl, tmpdir); !ok { + if err, ok := RedoSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != secondMigration.Version { t.Fatalf("Expected migrated version %v, got %v", secondMigration.Version, version) @@ -236,46 +236,46 @@ func TestMigrate(t *testing.T) { if testing.Short() { t.Skip("skipping test in short mode.") } - for _, driverUrl := range driverUrls { - t.Logf("Test driver: %s", driverUrl) + for _, driverURL := range driverURLs { + t.Logf("Test driver: %s", driverURL) tmpdir, err := ioutil.TempDir("/tmp", "migrate-test") if err != nil { t.Fatal(err) } - defer tearDown(driverUrl, tmpdir) + defer tearDown(driverURL, tmpdir) - initVersion, _ := Version(driverUrl, tmpdir) + initVersion, _ := Version(driverURL, tmpdir) - firstMigration, _ := Create(driverUrl, tmpdir, "migration1") - secondMigration, _ := Create(driverUrl, tmpdir, "migration2") + firstMigration, _ := Create(driverURL, tmpdir, "migration1") + secondMigration, _ := Create(driverURL, tmpdir, "migration2") t.Logf("init %v first %v second %v", initVersion, firstMigration.Version, secondMigration.Version) - if err, ok := ResetSync(driverUrl, tmpdir); !ok { + if err, ok := ResetSync(driverURL, tmpdir); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != secondMigration.Version { t.Fatalf("Expected migrated version %v, got %v", secondMigration.Version, version) } - if err, ok := MigrateSync(driverUrl, tmpdir, -2); !ok { + if err, ok := MigrateSync(driverURL, tmpdir, -2); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != 0 { t.Fatalf("Expected 0, got %v", version) } - if err, ok := MigrateSync(driverUrl, tmpdir, +1); !ok { + if err, ok := MigrateSync(driverURL, tmpdir, +1); !ok { t.Fatal(err) } - if version, err := Version(driverUrl, tmpdir); err != nil { + if version, err := Version(driverURL, tmpdir); err != nil { t.Fatal(err) } else if version != firstMigration.Version { t.Fatalf("Expected first version %v, got %v", firstMigration.Version, version) diff --git a/pipe/pipe.go b/pipe/pipe.go index 02979fc9..3888b352 100644 --- a/pipe/pipe.go +++ b/pipe/pipe.go @@ -32,7 +32,7 @@ func WaitAndRedirect(pipe, redirectPipe chan interface{}, interrupt chan os.Sign select { case <-interrupt: - interruptsReceived += 1 + interruptsReceived++ if interruptsReceived > 1 { os.Exit(5) } else { @@ -43,12 +43,11 @@ func WaitAndRedirect(pipe, redirectPipe chan interface{}, interrupt chan os.Sign case item, ok := <-pipe: if !ok { return !errorReceived && interruptsReceived == 0 - } else { - redirectPipe <- item - switch item.(type) { - case error: - errorReceived = true - } + } + redirectPipe <- item + switch item.(type) { + case error: + errorReceived = true } } } @@ -66,11 +65,10 @@ func ReadErrors(pipe chan interface{}) []error { case item, ok := <-pipe: if !ok { return err - } else { - switch item.(type) { - case error: - err = append(err, item.(error)) - } + } + switch item.(type) { + case error: + err = append(err, item.(error)) } } } diff --git a/version.go b/version.go index 082af71a..cc9748fe 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,4 @@ package main +// Version - version of tool const Version string = "1.3.0"