Skip to content

Commit

Permalink
feat(chmigrate): support distributed migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Sep 7, 2023
1 parent faaf5d3 commit a4c3d24
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 71 deletions.
5 changes: 5 additions & 0 deletions ch/query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func (q *InsertQuery) TableExpr(query string, args ...any) *InsertQuery {
return q
}

func (q *InsertQuery) ModelTable(table string) *InsertQuery {
q.modelTableName = chschema.UnsafeIdent(table)
return q
}

func (q *InsertQuery) ModelTableExpr(query string, args ...any) *InsertQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
Expand Down
5 changes: 5 additions & 0 deletions ch/query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,11 @@ func (q *SelectQuery) TableExpr(query string, args ...any) *SelectQuery {
return q
}

func (q *SelectQuery) ModelTable(table string) *SelectQuery {
q.modelTableName = chschema.UnsafeIdent(table)
return q
}

func (q *SelectQuery) ModelTableExpr(query string, args ...any) *SelectQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
Expand Down
95 changes: 57 additions & 38 deletions ch/query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type CreateTableQuery struct {
baseQuery

ifNotExists bool
as chschema.QueryWithArgs
onCluster chschema.QueryWithArgs
engine chschema.QueryWithArgs
ttl chschema.QueryWithArgs
Expand Down Expand Up @@ -52,11 +53,21 @@ func (q *CreateTableQuery) TableExpr(query string, args ...any) *CreateTableQuer
return q
}

func (q *CreateTableQuery) ModelTable(table string) *CreateTableQuery {
q.modelTableName = chschema.UnsafeIdent(table)
return q
}

func (q *CreateTableQuery) ModelTableExpr(query string, args ...any) *CreateTableQuery {
q.modelTableName = chschema.SafeQuery(query, args)
return q
}

func (q *CreateTableQuery) As(table string) *CreateTableQuery {
q.as = chschema.UnsafeIdent(table)
return q
}

func (q *CreateTableQuery) ColumnExpr(query string, args ...any) *CreateTableQuery {
q.addColumn(chschema.SafeQuery(query, args))
return q
Expand Down Expand Up @@ -111,10 +122,6 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []
if q.err != nil {
return nil, q.err
}
if q.table == nil {
return nil, errNilModel
}

b = append(b, "CREATE TABLE "...)
if q.ifNotExists {
b = append(b, "IF NOT EXISTS "...)
Expand All @@ -133,36 +140,46 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []
}
}

b = append(b, " ("...)

for i, field := range q.table.Fields {
if i > 0 {
b = append(b, ", "...)
}

b = append(b, field.CHName...)
b = append(b, " "...)
b = append(b, field.CHType...)
if field.NotNull {
b = append(b, " NOT NULL"...)
}
if field.CHDefault != "" {
b = append(b, " DEFAULT "...)
b = append(b, field.CHDefault...)
if !q.as.IsEmpty() {
b = append(b, " AS "...)
b, err = q.as.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}

for i, col := range q.columns {
if i > 0 || len(q.table.Fields) > 0 {
b = append(b, ", "...)
if q.table != nil {
b = append(b, " ("...)

for i, field := range q.table.Fields {
if i > 0 {
b = append(b, ", "...)
}

b = append(b, field.CHName...)
b = append(b, " "...)
b = append(b, field.CHType...)
if field.NotNull {
b = append(b, " NOT NULL"...)
}
if field.CHDefault != "" {
b = append(b, " DEFAULT "...)
b = append(b, field.CHDefault...)
}
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err

for i, col := range q.columns {
if i > 0 || len(q.table.Fields) > 0 {
b = append(b, ", "...)
}
b, err = col.AppendQuery(fmter, b)
if err != nil {
return nil, err
}
}
}

b = append(b, ")"...)
b = append(b, ")"...)
}

b = append(b, " Engine = "...)

Expand All @@ -189,17 +206,19 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []
return nil, err
}
b = append(b, ')')
} else if len(q.table.PKs) > 0 {
b = append(b, " ORDER BY ("...)
for i, pk := range q.table.PKs {
if i > 0 {
b = append(b, ", "...)
} else if q.table != nil {
if len(q.table.PKs) > 0 {
b = append(b, " ORDER BY ("...)
for i, pk := range q.table.PKs {
if i > 0 {
b = append(b, ", "...)
}
b = append(b, pk.CHName...)
}
b = append(b, pk.CHName...)
b = append(b, ')')
} else if q.table.CHEngine == "" {
b = append(b, " ORDER BY tuple()"...)
}
b = append(b, ')')
} else if q.table.CHEngine == "" {
b = append(b, " ORDER BY tuple()"...)
}

if !q.ttl.IsZero() {
Expand All @@ -219,7 +238,7 @@ func (q *CreateTableQuery) AppendQuery(fmter chschema.Formatter, b []byte) (_ []
}

func (q *CreateTableQuery) appendPartition(fmter chschema.Formatter, b []byte) ([]byte, error) {
if q.partition.IsZero() && q.table.CHPartition == "" {
if q.partition.IsZero() && (q.table == nil || q.table.CHPartition == "") {
return b, nil
}

Expand Down
9 changes: 9 additions & 0 deletions ch/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ func TestQuery(t *testing.T) {
q2 := db.NewSelect().Model(new(Model))
return q1.UnionAll(q2)
},
func(db *ch.DB) chschema.QueryAppender {
return db.NewCreateTable().
Table("my-table_dist").
As("my-table").
Engine("Distributed(?, currentDatabase(), ?, rand())",
ch.Ident("my-cluster"), ch.Ident("my-table")).
OnCluster("my-cluster").
IfNotExists()
},
}

db := chDB()
Expand Down
1 change: 1 addition & 0 deletions ch/testdata/snapshots/TestQuery-19
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE TABLE IF NOT EXISTS "my-table_dist" AS "my-table" ON CLUSTER "my-cluster" Engine = Distributed("my-cluster", currentDatabase(), "my-table", rand())
99 changes: 66 additions & 33 deletions chmigrate/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type MigratorOption func(m *Migrator)

func WithTableName(table string) MigratorOption {
return func(m *Migrator) {
m.table = table
m.migrationsTable = table
}
}

Expand All @@ -33,9 +33,15 @@ func WithReplicated(on bool) MigratorOption {
}
}

func WithDistributed(on bool) MigratorOption {
return func(m *Migrator) {
m.distributed = on
}
}

func WithOnCluster(cluster string) MigratorOption {
return func(m *Migrator) {
m.onCluster = cluster
m.cluster = cluster
}
}

Expand All @@ -53,10 +59,11 @@ type Migrator struct {

ms MigrationSlice

table string
migrationsTable string
locksTable string
replicated bool
onCluster string
distributed bool
cluster string
markAppliedOnSuccess bool
}

Expand All @@ -67,8 +74,8 @@ func NewMigrator(db *ch.DB, migrations *Migrations, opts ...MigratorOption) *Mig

ms: migrations.ms,

table: "ch_migrations",
locksTable: "ch_migration_locks",
migrationsTable: "ch_migrations",
locksTable: "ch_migration_locks",
}
for _, opt := range opts {
opt(m)
Expand Down Expand Up @@ -107,6 +114,12 @@ func (m *Migrator) migrationsWithStatus(ctx context.Context) (MigrationSlice, in
}

func (m *Migrator) Init(ctx context.Context) error {
if m.distributed {
if m.cluster == "" {
return errors.New("chmigrate: distributed requires a cluster name")
}
}

if _, err := m.db.NewCreateTable().
Model((*Migration)(nil)).
Apply(func(q *ch.CreateTableQuery) *ch.CreateTableQuery {
Expand All @@ -115,12 +128,13 @@ func (m *Migrator) Init(ctx context.Context) error {
}
return q.Engine("CollapsingMergeTree(sign)")
}).
ModelTableExpr(m.table).
OnCluster(m.onCluster).
ModelTable(m.migrationsTable).
OnCluster(m.cluster).
IfNotExists().
Exec(ctx); err != nil {
return err
}

if _, err := m.db.NewCreateTable().
Model((*migrationLock)(nil)).
Apply(func(q *ch.CreateTableQuery) *ch.CreateTableQuery {
Expand All @@ -129,31 +143,47 @@ func (m *Migrator) Init(ctx context.Context) error {
}
return q.Engine("MergeTree")
}).
ModelTableExpr(m.locksTable).
OnCluster(m.onCluster).
ModelTable(m.locksTable).
OnCluster(m.cluster).
IfNotExists().
Exec(ctx); err != nil {
return err
}

if m.distributed {
if _, err := m.db.NewCreateTable().
Table(m.distTable(m.migrationsTable)).
As(m.migrationsTable).
Engine("Distributed(?, currentDatabase(), ?, rand())",
ch.Ident(m.cluster), ch.Ident(m.migrationsTable)).
OnCluster(m.cluster).
IfNotExists().
Exec(ctx); err != nil {
return err
}
}

return nil
}

func (m *Migrator) Reset(ctx context.Context) error {
if _, err := m.db.NewDropTable().
Model((*Migration)(nil)).
ModelTableExpr(m.table).
OnCluster(m.onCluster).
IfExists().
Exec(ctx); err != nil {
return err
}
if _, err := m.db.NewDropTable().
Model((*migrationLock)(nil)).
ModelTableExpr(m.locksTable).
OnCluster(m.onCluster).
IfExists().
Exec(ctx); err != nil {
return err
tables := []string{
m.migrationsTable,
m.locksTable,
}
if m.distributed {
tables = append(tables,
m.distTable(m.migrationsTable),
)
}
for _, tableName := range tables {
if _, err := m.db.NewDropTable().
Table(tableName).
OnCluster(m.cluster).
IfExists().
Exec(ctx); err != nil {
return err
}
}
return m.Init(ctx)
}
Expand Down Expand Up @@ -363,7 +393,7 @@ func (m *Migrator) MarkApplied(ctx context.Context, migration *Migration) error
migration.MigratedAt = time.Now()
_, err := m.db.NewInsert().
Model(migration).
ModelTableExpr(m.table).
ModelTable(m.distTable(m.migrationsTable)).
Exec(ctx)
return err
}
Expand All @@ -373,13 +403,13 @@ func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) erro
migration.Sign = -1
_, err := m.db.NewInsert().
Model(migration).
ModelTableExpr(m.table).
ModelTable(m.distTable(m.migrationsTable)).
Exec(ctx)
return err
}

func (m *Migrator) TruncateTable(ctx context.Context) error {
_, err := m.db.Exec("TRUNCATE TABLE ?", ch.Ident(m.table))
_, err := m.db.Exec("TRUNCATE TABLE ?", ch.Ident(m.distTable(m.migrationsTable)))
return err
}

Expand Down Expand Up @@ -407,25 +437,28 @@ func (m *Migrator) AppliedMigrations(ctx context.Context) (MigrationSlice, error
if err := m.db.NewSelect().
ColumnExpr("*").
Model(&ms).
ModelTableExpr(m.table).
ModelTable(m.distTable(m.migrationsTable)).
Final().
Scan(ctx); err != nil {
return nil, err
}
return ms, nil
}

func (m *Migrator) formattedTableName(db *ch.DB) string {
return db.Formatter().FormatQuery(m.table)
}

func (m *Migrator) validate() error {
if len(m.ms) == 0 {
return errors.New("chmigrate: there are no any migrations")
}
return nil
}

func (m *Migrator) distTable(table string) string {
if m.distributed {
return table + "_dist"
}
return table
}

//------------------------------------------------------------------------------

type migrationLock struct {
Expand Down

0 comments on commit a4c3d24

Please sign in to comment.